Skip to content

Commit 923ae3c

Browse files
hexagon: add support for ROPE_NEOX (ggml-org#17458)
1 parent 01ad35e commit 923ae3c

File tree

2 files changed

+81
-8
lines changed

2 files changed

+81
-8
lines changed

ggml/src/ggml-hexagon/ggml-hexagon.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2229,7 +2229,7 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
22292229

22302230
int mode = op_params[2];
22312231

2232-
if ((mode & GGML_ROPE_TYPE_NEOX) || (mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) {
2232+
if ((mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) {
22332233
return false;
22342234
}
22352235
if (mode & 1) {

ggml/src/ggml-hexagon/htp/rope-ops.c

Lines changed: 80 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
#include "hvx-utils.h"
2525
#include "ops-utils.h"
2626

27+
// Local copies of GGML_ROPE_TYPE_NORMAL and GGML_ROPE_TYPE_NEOX, redefined here because ggml.h cannot be included in this file.
// NOTE(review): these values are presumably required to match the ones in ggml.h -- confirm whenever ggml.h changes.
28+
#define HTP_ROPE_TYPE_NORMAL 0
29+
#define HTP_ROPE_TYPE_NEOX 2
30+
2731
#define htp_rope_preamble \
2832
const uint32_t ne00 = src0->ne[0]; \
2933
const uint32_t ne01 = src0->ne[1]; \
@@ -146,6 +150,57 @@ static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context
146150
rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor);
147151
}
148152

153+
static void hvx_calc_rope_neox_f32(const float * restrict src0,
154+
float * restrict dst,
155+
const int num_elems,
156+
const float * restrict theta_cache) {
157+
// for (int i = 0; i < num_elems; i += 2) {
158+
//const float cos_theta = theta_cache[i + 0];
159+
//const float sin_theta = theta_cache[i + 1];
160+
161+
//const float x0 = src[0];
162+
//const float x1 = src[num_elems/2];
163+
164+
//dst[0] = x0*cos_theta - x1*sin_theta;
165+
//dst[num_elems/2] = x0*sin_theta + x1*cos_theta;
166+
167+
//src += 1;
168+
//dst += 1;
169+
// }
170+
171+
const uint8_t * restrict src0_curr = (const uint8_t *) src0;
172+
const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
173+
uint8_t * restrict dst_curr = (uint8_t *) dst;
174+
175+
int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once
176+
int half_size = (sizeof(float) * (num_elems / 2));
177+
178+
for (int i = 0; i < step_of_1; i++) {
179+
HVX_Vector v0 = *(HVX_Vector *) src0_curr;
180+
HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size);
181+
182+
HVX_Vector v2 = *(HVX_Vector *) theta_curr;
183+
HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
184+
185+
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
186+
187+
HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin));
188+
HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin));
189+
HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_lo_W(vcos_sin));
190+
HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_hi_W(vcos_sin));
191+
192+
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
193+
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
194+
195+
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
196+
*(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);
197+
198+
src0_curr += VLEN;
199+
theta_curr += 2 * VLEN;
200+
dst_curr += VLEN;
201+
}
202+
}
203+
149204
static void hvx_calc_rope_f32(const float * restrict src0,
150205
float * restrict dst,
151206
const int num_elems,
@@ -212,6 +267,9 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
212267
const struct htp_tensor * src2 = &octx->src2;
213268
struct htp_tensor * dst = &octx->dst;
214269

270+
const int32_t mode = rope_ctx->mode;
271+
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
272+
215273
htp_rope_preamble;
216274

217275
const int32_t * pos = (const int32_t *) src1->data;
@@ -247,20 +305,35 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
247305
float * dst_data_loc = dst_data;
248306

249307
if (1 == opt_path) {
250-
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
308+
if (is_neox) {
309+
hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
310+
} else {
311+
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
312+
}
251313
} else {
252314
for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
253315
const float cos_theta = wp0[i0 + 0];
254316
const float sin_theta = wp0[i0 + 1];
255317

256-
const float x0 = src_loc[0];
257-
const float x1 = src_loc[1];
318+
if (is_neox) {
319+
const float x0 = src_loc[0];
320+
const float x1 = src_loc[rope_ctx->n_dims/2];
321+
322+
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
323+
dst_data_loc[rope_ctx->n_dims/2] = x0 * sin_theta + x1 * cos_theta;
324+
325+
src_loc += 1;
326+
dst_data_loc += 1;
327+
} else {
328+
const float x0 = src_loc[0];
329+
const float x1 = src_loc[1];
258330

259-
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
260-
dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;
331+
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
332+
dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;
261333

262-
src_loc += 2;
263-
dst_data_loc += 2;
334+
src_loc += 2;
335+
dst_data_loc += 2;
336+
}
264337
}
265338
}
266339

0 commit comments

Comments
 (0)