|
24 | 24 | #include "hvx-utils.h" |
25 | 25 | #include "ops-utils.h" |
26 | 26 |
|
| 27 | +// Local copies of GGML_ROPE_TYPE_NORMAL and GGML_ROPE_TYPE_NEOX, redefined here because we can't include ggml.h (keep the values in sync with ggml.h) |
| 28 | +#define HTP_ROPE_TYPE_NORMAL 0 |
| 29 | +#define HTP_ROPE_TYPE_NEOX 2 |
| 30 | + |
27 | 31 | #define htp_rope_preamble \ |
28 | 32 | const uint32_t ne00 = src0->ne[0]; \ |
29 | 33 | const uint32_t ne01 = src0->ne[1]; \ |
@@ -146,6 +150,57 @@ static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context |
146 | 150 | rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor); |
147 | 151 | } |
148 | 152 |
|
| 153 | +static void hvx_calc_rope_neox_f32(const float * restrict src0, |
| 154 | + float * restrict dst, |
| 155 | + const int num_elems, |
| 156 | + const float * restrict theta_cache) { |
| 157 | + // Scalar reference for the vector loop below: for (int i = 0; i < num_elems; i += 2) { |
| 158 | + //   const float cos_theta = theta_cache[i + 0]; // theta_cache holds interleaved (cos, sin) pairs |
| 159 | + //   const float sin_theta = theta_cache[i + 1]; |
| 160 | + |
| 161 | + //   const float x0 = src0[0]; |
| 162 | + //   const float x1 = src0[num_elems/2]; // the partner element sits half a row (num_elems/2) further on |
| 163 | + |
| 164 | + //   dst[0] = x0*cos_theta - x1*sin_theta; |
| 165 | + //   dst[num_elems/2] = x0*sin_theta + x1*cos_theta; |
| 166 | + |
| 167 | + //   src0 += 1; // advance by one element, not two: each pair spans both halves |
| 168 | + //   dst += 1; |
| 169 | + // } |
| 170 | + |
| 171 | + const uint8_t * restrict src0_curr = (const uint8_t *) src0; |
| 172 | + const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache; |
| 173 | + uint8_t * restrict dst_curr = (uint8_t *) dst; |
| 174 | + |
| 175 | + int step_of_1 = num_elems >> 6; // number of passes, 64 elements each: one vector from the first half plus one from the second half per pass (32 floats per vector, assuming VLEN is 128 bytes - TODO confirm). NOTE(review): a remainder of num_elems % 64 elements is not processed by this loop - confirm the caller only takes this path for multiples of 64 |
| 176 | + int half_size = (sizeof(float) * (num_elems / 2)); |
| 177 | + |
| 178 | + for (int i = 0; i < step_of_1; i++) { |
| 179 | + HVX_Vector v0 = *(HVX_Vector *) src0_curr; |
| 180 | + HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size); |
| 181 | + |
| 182 | + HVX_Vector v2 = *(HVX_Vector *) theta_curr; |
| 183 | + HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN); |
| 184 | + |
| 185 | + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // de-interleave the (cos, sin) pairs loaded from theta_cache: lo(vcos_sin) is used as cos_theta, hi(vcos_sin) as sin_theta |
| 186 | + |
| 187 | + HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin)); |
| 188 | + HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin)); |
| 189 | + HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_lo_W(vcos_sin)); |
| 190 | + HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_hi_W(vcos_sin)); |
| 191 | + |
| 192 | + HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); |
| 193 | + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); |
| 194 | + |
| 195 | + *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4); |
| 196 | + *(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5); |
| 197 | + |
| 198 | + src0_curr += VLEN; |
| 199 | + theta_curr += 2 * VLEN; |
| 200 | + dst_curr += VLEN; |
| 201 | + } |
| 202 | +} |
| 203 | + |
149 | 204 | static void hvx_calc_rope_f32(const float * restrict src0, |
150 | 205 | float * restrict dst, |
151 | 206 | const int num_elems, |
@@ -212,6 +267,9 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx, |
212 | 267 | const struct htp_tensor * src2 = &octx->src2; |
213 | 268 | struct htp_tensor * dst = &octx->dst; |
214 | 269 |
|
| 270 | + const int32_t mode = rope_ctx->mode; |
| 271 | + const bool is_neox = mode & HTP_ROPE_TYPE_NEOX; |
| 272 | + |
215 | 273 | htp_rope_preamble; |
216 | 274 |
|
217 | 275 | const int32_t * pos = (const int32_t *) src1->data; |
@@ -247,20 +305,35 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx, |
247 | 305 | float * dst_data_loc = dst_data; |
248 | 306 |
|
249 | 307 | if (1 == opt_path) { |
250 | | - hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); |
| 308 | + if (is_neox) { |
| 309 | + hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); |
| 310 | + } else { |
| 311 | + hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); |
| 312 | + } |
251 | 313 | } else { |
252 | 314 | for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) { |
253 | 315 | const float cos_theta = wp0[i0 + 0]; |
254 | 316 | const float sin_theta = wp0[i0 + 1]; |
255 | 317 |
|
256 | | - const float x0 = src_loc[0]; |
257 | | - const float x1 = src_loc[1]; |
| 318 | + if (is_neox) { |
| 319 | + const float x0 = src_loc[0]; |
| 320 | + const float x1 = src_loc[rope_ctx->n_dims/2]; |
| 321 | + |
| 322 | + dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; |
| 323 | + dst_data_loc[rope_ctx->n_dims/2] = x0 * sin_theta + x1 * cos_theta; |
| 324 | + |
| 325 | + src_loc += 1; |
| 326 | + dst_data_loc += 1; |
| 327 | + } else { |
| 328 | + const float x0 = src_loc[0]; |
| 329 | + const float x1 = src_loc[1]; |
258 | 330 |
|
259 | | - dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; |
260 | | - dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta; |
| 331 | + dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; |
| 332 | + dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta; |
261 | 333 |
|
262 | | - src_loc += 2; |
263 | | - dst_data_loc += 2; |
| 334 | + src_loc += 2; |
| 335 | + dst_data_loc += 2; |
| 336 | + } |
264 | 337 | } |
265 | 338 | } |
266 | 339 |
|
|
0 commit comments