@@ -3748,7 +3748,7 @@ kernel void kernel_rope_norm(
37483748
37493749 const float theta = theta_base * pow (args.freq_base , inv_ndims*i0);
37503750
3751- const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1 .0f ;
3751+ const float freq_factor = args. src2 ? ((device const float *) src2)[ic] : 1 .0f ;
37523752
37533753 rope_yarn (theta/freq_factor, args.freq_scale , corr_dims, i0, args.ext_factor , args.attn_factor , &cos_theta, &sin_theta);
37543754
@@ -3801,7 +3801,7 @@ kernel void kernel_rope_neox(
38013801
38023802 const float theta = theta_base * pow (args.freq_base , inv_ndims*i0);
38033803
3804- const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1 .0f ;
3804+ const float freq_factor = args. src2 ? ((device const float *) src2)[ic] : 1 .0f ;
38053805
38063806 rope_yarn (theta/freq_factor, args.freq_scale , corr_dims, i0, args.ext_factor , args.attn_factor , &cos_theta, &sin_theta);
38073807
@@ -3872,7 +3872,7 @@ kernel void kernel_rope_multi(
38723872
38733873 const float theta = theta_base * pow (args.freq_base , inv_ndims*i0);
38743874
3875- const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1 .0f ;
3875+ const float freq_factor = args. src2 ? ((device const float *) src2)[ic] : 1 .0f ;
38763876
38773877 rope_yarn (theta/freq_factor, args.freq_scale , corr_dims, i0, args.ext_factor , args.attn_factor , &cos_theta, &sin_theta);
38783878
@@ -3939,7 +3939,7 @@ kernel void kernel_rope_vision(
39393939 const float theta = theta_base * pow (args.freq_base , 2 .0f * inv_ndims * p);
39403940 // end of mrope
39413941
3942- const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1 .0f ;
3942+ const float freq_factor = args. src2 ? ((device const float *) src2)[ic] : 1 .0f ;
39433943
39443944 rope_yarn (theta/freq_factor, args.freq_scale , corr_dims, i0, args.ext_factor , args.attn_factor , &cos_theta, &sin_theta);
39453945
0 commit comments