@@ -73,15 +73,15 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
7373 return (1 - MIN (1 , MAX (0 , y )));
7474}
7575
76- static void rope_cache_init (const float theta_base ,
77- float freq_scale ,
78- const float * freq_factors ,
79- float * corr_dims ,
80- uint32_t ne0 ,
81- float ext_factor ,
82- float mscale ,
83- float * cache ,
84- float theta_scale ) {
76+ static void rope_cache_init (const float theta_base ,
77+ const float freq_scale ,
78+ const float * freq_factors ,
79+ float * corr_dims ,
80+ const uint32_t ne0 ,
81+ const float ext_factor ,
82+ const float mscale ,
83+ float * cache ,
84+ const float theta_scale ) {
8585 // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
8686 float theta = theta_base ;
8787
@@ -92,18 +92,19 @@ static void rope_cache_init(const float theta_base,
9292
9393 // Get n-d rotational scaling corrected for extrapolation
9494 float theta_interp = freq_scale * theta_extrap ;
95- float theta2 = theta_interp ;
95+ float theta_final = theta_interp ;
96+ float mscale_final = mscale ;
9697
9798 if (ext_factor != 0.0f ) {
9899 float ramp_mix = rope_yarn_ramp (corr_dims [0 ], corr_dims [1 ], i0 ) * ext_factor ;
99- theta2 = theta_interp * (1 - ramp_mix ) + theta_extrap * ramp_mix ;
100+ theta_final = theta_interp * (1 - ramp_mix ) + theta_extrap * ramp_mix ;
100101
101102 // Get n-d magnitude scaling corrected for interpolation
102- mscale *= 1.0f + 0.1f * logf (1.0f / freq_scale );
103+ mscale_final *= 1.0f + 0.1f * logf (1.0f / freq_scale );
103104 }
104105
105- cache [i0 + 0 ] = cosf (theta2 ) * mscale ;
106- cache [i0 + 1 ] = sinf (theta2 ) * mscale ;
106+ cache [i0 + 0 ] = cosf (theta_final ) * mscale_final ;
107+ cache [i0 + 1 ] = sinf (theta_final ) * mscale_final ;
107108
108109 theta *= theta_scale ;
109110 }
@@ -151,9 +152,9 @@ static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context
151152}
152153
153154static void hvx_calc_rope_neox_f32 (const float * restrict src0 ,
154- float * restrict dst ,
155- const int num_elems ,
156- const float * restrict theta_cache ) {
155+ float * restrict dst ,
156+ const int num_elems ,
157+ const float * restrict theta_cache ) {
157158 // for (int i = 0; i < num_elems; i += 2) {
158159 //const float cos_theta = theta_cache[i + 0];
159160 //const float sin_theta = theta_cache[i + 1];
@@ -192,7 +193,7 @@ static void hvx_calc_rope_neox_f32(const float * restrict src0,
192193 HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32 (vx0_c , vx1_s );
193194 HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32 (vx0_s , vx1_c );
194195
195- * (HVX_Vector * ) dst_curr = Q6_Vsf_equals_Vqf32 (v4 );
196+ * (HVX_Vector * ) dst_curr = Q6_Vsf_equals_Vqf32 (v4 );
196197 * (HVX_Vector * ) (dst_curr + half_size ) = Q6_Vsf_equals_Vqf32 (v5 );
197198
198199 src0_curr += VLEN ;
@@ -259,16 +260,16 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
259260 const uint32_t ir1 ,
260261 int nth ,
261262 int ith ,
262- int opt_path ) {
263+ const int opt_path ) {
263264 struct htp_ops_context * octx = rope_ctx -> octx ;
264265
265266 const struct htp_tensor * src0 = & octx -> src0 ;
266267 const struct htp_tensor * src1 = & octx -> src1 ;
267268 const struct htp_tensor * src2 = & octx -> src2 ;
268269 struct htp_tensor * dst = & octx -> dst ;
269270
270- const int32_t mode = rope_ctx -> mode ;
271- const bool is_neox = mode & HTP_ROPE_TYPE_NEOX ;
271+ const int32_t mode = rope_ctx -> mode ;
272+ const bool is_neox = mode & HTP_ROPE_TYPE_NEOX ;
272273
273274 htp_rope_preamble ;
274275
@@ -281,23 +282,17 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
281282 freq_factors = (const float * ) src2 -> data ;
282283 }
283284
284- int ir = 0 ;
285-
285+ const uint32_t i1_end = MIN (ir1 , ne1 );
286+ const int32_t half_dims = rope_ctx -> n_dims / 2 ;
287+ const size_t remain_bytes = (ne0 - rope_ctx -> n_dims ) * sizeof (float );
286288 for (uint32_t i3 = 0 ; i3 < ne3 ; i3 ++ ) { // batch
287289 for (uint32_t i2 = 0 ; i2 < ne2 ; i2 ++ ) { // seq-len
288290 const int32_t p = pos [i2 ];
289291
290292 rope_cache_init (p , rope_ctx -> freq_scale , freq_factors , rope_ctx -> corr_dims , ne0 , rope_ctx -> ext_factor ,
291293 rope_ctx -> attn_factor , wp0 , rope_ctx -> theta_scale );
292294
293- for (uint32_t i1 = 0 ; i1 < ne1 ; i1 ++ ) { // attn-heads
294- if (ir ++ < ir0 ) {
295- continue ;
296- }
297- if (ir > ir1 ) {
298- break ;
299- }
300-
295+ for (uint32_t i1 = ir0 ; i1 < i1_end ; i1 ++ ) { // attn-heads
301296 const float * src = (float * ) ((char * ) src0 -> data + i3 * nb03 + i2 * nb02 + i1 * nb01 );
302297 float * dst_data = (float * ) ((char * ) dst -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 );
303298
@@ -310,17 +305,20 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
310305 } else {
311306 hvx_calc_rope_f32 (src_loc , dst_data_loc , rope_ctx -> n_dims , wp0 );
312307 }
308+
309+ src_loc += rope_ctx -> n_dims ;
310+ dst_data_loc += rope_ctx -> n_dims ;
313311 } else {
314312 for (uint32_t i0 = 0 ; i0 < rope_ctx -> n_dims ; i0 += 2 ) {
315313 const float cos_theta = wp0 [i0 + 0 ];
316314 const float sin_theta = wp0 [i0 + 1 ];
317315
318316 if (is_neox ) {
319317 const float x0 = src_loc [0 ];
320- const float x1 = src_loc [rope_ctx -> n_dims / 2 ];
318+ const float x1 = src_loc [half_dims ];
321319
322- dst_data_loc [0 ] = x0 * cos_theta - x1 * sin_theta ;
323- dst_data_loc [rope_ctx -> n_dims / 2 ] = x0 * sin_theta + x1 * cos_theta ;
320+ dst_data_loc [0 ] = x0 * cos_theta - x1 * sin_theta ;
321+ dst_data_loc [half_dims ] = x0 * sin_theta + x1 * cos_theta ;
324322
325323 src_loc += 1 ;
326324 dst_data_loc += 1 ;
@@ -335,15 +333,13 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
335333 dst_data_loc += 2 ;
336334 }
337335 }
338- }
339-
340- for (uint32_t i0 = rope_ctx -> n_dims ; i0 < ne0 ; i0 += 2 ) {
341- dst_data_loc [0 ] = src_loc [0 ];
342- dst_data_loc [1 ] = src_loc [1 ];
343336
344- src_loc += 2 ;
345- dst_data_loc += 2 ;
337+ src_loc += ( is_neox ? half_dims : 0 ) ;
338+ dst_data_loc += ( is_neox ? half_dims : 0 ) ;
346339 }
340+
341+ // TODO: use simd to speed up the remaining elements copy
342+ memcpy (dst_data_loc , src_loc , remain_bytes );
347343 }
348344 }
349345 }
0 commit comments