Skip to content

Commit 34ce48d

Browse files
authored
ggml-hexagon: fix rope failure at test-backend-ops (#17565)
* fix test failure * fix: correct scaling calculations in rope_cache_init * fix: optimize element copying in rope_hex_f32 using memcpy * fix: optimize loop boundaries in rope_hex_f32 for better performance * feat: add profiling macros for performance measurement in operations
1 parent 45e350e commit 34ce48d

File tree

1 file changed

+37
-41
lines changed

1 file changed

+37
-41
lines changed

ggml/src/ggml-hexagon/htp/rope-ops.c

Lines changed: 37 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,15 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
7373
return (1 - MIN(1, MAX(0, y)));
7474
}
7575

76-
static void rope_cache_init(const float theta_base,
77-
float freq_scale,
78-
const float * freq_factors,
79-
float * corr_dims,
80-
uint32_t ne0,
81-
float ext_factor,
82-
float mscale,
83-
float * cache,
84-
float theta_scale) {
76+
static void rope_cache_init(const float theta_base,
77+
const float freq_scale,
78+
const float * freq_factors,
79+
float * corr_dims,
80+
const uint32_t ne0,
81+
const float ext_factor,
82+
const float mscale,
83+
float * cache,
84+
const float theta_scale) {
8585
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
8686
float theta = theta_base;
8787

@@ -92,18 +92,19 @@ static void rope_cache_init(const float theta_base,
9292

9393
// Get n-d rotational scaling corrected for extrapolation
9494
float theta_interp = freq_scale * theta_extrap;
95-
float theta2 = theta_interp;
95+
float theta_final = theta_interp;
96+
float mscale_final = mscale;
9697

9798
if (ext_factor != 0.0f) {
9899
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
99-
theta2 = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
100+
theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
100101

101102
// Get n-d magnitude scaling corrected for interpolation
102-
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
103+
mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale);
103104
}
104105

105-
cache[i0 + 0] = cosf(theta2) * mscale;
106-
cache[i0 + 1] = sinf(theta2) * mscale;
106+
cache[i0 + 0] = cosf(theta_final) * mscale_final;
107+
cache[i0 + 1] = sinf(theta_final) * mscale_final;
107108

108109
theta *= theta_scale;
109110
}
@@ -151,9 +152,9 @@ static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context
151152
}
152153

153154
static void hvx_calc_rope_neox_f32(const float * restrict src0,
154-
float * restrict dst,
155-
const int num_elems,
156-
const float * restrict theta_cache) {
155+
float * restrict dst,
156+
const int num_elems,
157+
const float * restrict theta_cache) {
157158
// for (int i = 0; i < num_elems; i += 2) {
158159
//const float cos_theta = theta_cache[i + 0];
159160
//const float sin_theta = theta_cache[i + 1];
@@ -192,7 +193,7 @@ static void hvx_calc_rope_neox_f32(const float * restrict src0,
192193
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
193194
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
194195

195-
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
196+
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
196197
*(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);
197198

198199
src0_curr += VLEN;
@@ -259,16 +260,16 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
259260
const uint32_t ir1,
260261
int nth,
261262
int ith,
262-
int opt_path) {
263+
const int opt_path) {
263264
struct htp_ops_context * octx = rope_ctx->octx;
264265

265266
const struct htp_tensor * src0 = &octx->src0;
266267
const struct htp_tensor * src1 = &octx->src1;
267268
const struct htp_tensor * src2 = &octx->src2;
268269
struct htp_tensor * dst = &octx->dst;
269270

270-
const int32_t mode = rope_ctx->mode;
271-
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
271+
const int32_t mode = rope_ctx->mode;
272+
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
272273

273274
htp_rope_preamble;
274275

@@ -281,23 +282,17 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
281282
freq_factors = (const float *) src2->data;
282283
}
283284

284-
int ir = 0;
285-
285+
const uint32_t i1_end = MIN(ir1, ne1);
286+
const int32_t half_dims = rope_ctx->n_dims / 2;
287+
const size_t remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float);
286288
for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
287289
for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
288290
const int32_t p = pos[i2];
289291

290292
rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor,
291293
rope_ctx->attn_factor, wp0, rope_ctx->theta_scale);
292294

293-
for (uint32_t i1 = 0; i1 < ne1; i1++) { // attn-heads
294-
if (ir++ < ir0) {
295-
continue;
296-
}
297-
if (ir > ir1) {
298-
break;
299-
}
300-
295+
for (uint32_t i1 = ir0; i1 < i1_end; i1++) { // attn-heads
301296
const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01);
302297
float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1);
303298

@@ -310,17 +305,20 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
310305
} else {
311306
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
312307
}
308+
309+
src_loc += rope_ctx->n_dims;
310+
dst_data_loc += rope_ctx->n_dims;
313311
} else {
314312
for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
315313
const float cos_theta = wp0[i0 + 0];
316314
const float sin_theta = wp0[i0 + 1];
317315

318316
if (is_neox) {
319317
const float x0 = src_loc[0];
320-
const float x1 = src_loc[rope_ctx->n_dims/2];
318+
const float x1 = src_loc[half_dims];
321319

322-
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
323-
dst_data_loc[rope_ctx->n_dims/2] = x0 * sin_theta + x1 * cos_theta;
320+
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
321+
dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta;
324322

325323
src_loc += 1;
326324
dst_data_loc += 1;
@@ -335,15 +333,13 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
335333
dst_data_loc += 2;
336334
}
337335
}
338-
}
339-
340-
for (uint32_t i0 = rope_ctx->n_dims; i0 < ne0; i0 += 2) {
341-
dst_data_loc[0] = src_loc[0];
342-
dst_data_loc[1] = src_loc[1];
343336

344-
src_loc += 2;
345-
dst_data_loc += 2;
337+
src_loc += (is_neox ? half_dims : 0);
338+
dst_data_loc += (is_neox ? half_dims : 0);
346339
}
340+
341+
// TODO: use simd to speed up the remaining elements copy
342+
memcpy(dst_data_loc, src_loc, remain_bytes);
347343
}
348344
}
349345
}

0 commit comments

Comments
 (0)