
Commit f440d3e

hexagon: some more matmul optimizations and comments
Optimize cases where tensor dims are not a multiple of 1024 (e.g. in Qwen models). Those cases were already handled, but at a higher overhead.
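In short: a row whose length is not a multiple of 1024 is processed as full 1024-element chunks plus one leftover chunk, and the leftover count is rounded up to the nearest 256-element tier so the reduce-multiply helper still gets a compile-time element count. A minimal standalone sketch of that split (the helper and the driver below are illustrative only, not part of this commit; only nloe and the 256/512/768/1024 tiers correspond to the hvx_vec_rmpy_x8_nloe() helper added in the diff):

    // Hedged sketch, not committed code: split a row of n int8 elements into
    // full 1024-element chunks plus one rounded-up tail chunk.
    #include <stdio.h>

    static unsigned int round_up_tier(unsigned int nloe) {
        if (nloe <= 256) return 256;
        if (nloe <= 512) return 512;
        if (nloe <= 768) return 768;
        return 1024;
    }

    int main(void) {
        const unsigned int n    = 1536;       // e.g. a Qwen-style dim that is not a multiple of 1024
        const unsigned int nb   = n / 1024;   // full 1024-element (8x HVX vector) chunks
        const unsigned int nloe = n % 1024;   // leftover elements in the tail chunk

        printf("full chunks: %u, tail: %u, tail processed as: %u\n",
               nb, nloe, nloe ? round_up_tier(nloe) : 0);
        // prints: full chunks: 1, tail: 512, tail processed as: 512
        return 0;
    }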
1 parent 55ef9c8 · commit f440d3e

File tree

1 file changed: +78 / -44 lines

ggml/src/ggml-hexagon/htp/matmul-ops.c

Lines changed: 78 additions & 44 deletions
@@ -243,36 +243,70 @@ static inline HVX_Vector_x4 hvx_vec_load_x4_f32_as_f16(const uint8_t * restrict
     return r;
 }

-static inline HVX_Vector hvx_vec_rmpy_x8(HVX_Vector_x8 x, HVX_Vector_x8 y) {
-    HVX_Vector r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]);
-    HVX_Vector r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]);
-    HVX_Vector r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]);
-    HVX_Vector r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]);
-    HVX_Vector r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]);
-    HVX_Vector r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]);
-    HVX_Vector r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]);
-    HVX_Vector r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]);
-
-    HVX_VectorPair p0 = Q6_W_vdeal_VVR(r1, r0, -4);
-    HVX_VectorPair p1 = Q6_W_vdeal_VVR(r3, r2, -4);
-    HVX_VectorPair p2 = Q6_W_vdeal_VVR(r5, r4, -4);
-    HVX_VectorPair p3 = Q6_W_vdeal_VVR(r7, r6, -4);
-    r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));
-    r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1));
-    r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2));
-    r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3));
-
-    p0 = Q6_W_vdeal_VVR(r1, r0, -4);
-    p1 = Q6_W_vdeal_VVR(r3, r2, -4);
-    r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));
-    r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1));
-
-    p0 = Q6_W_vdeal_VVR(r1, r0, -4);
-    r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));
+// Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors).
+// Accumulate each block into a single int32 value.
+// Return a single HVX vector with 32x int32 accumulators.
+// This version is parameterized to support less than 1024 elements.
+// if() checks are optimized out at compile time -- make sure to pass N as a constexpr.
+
+static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
+    HVX_Vector r0 = Q6_V_vsplat_R(0);
+    HVX_Vector r1 = Q6_V_vsplat_R(0);
+    HVX_Vector r2 = Q6_V_vsplat_R(0);
+    HVX_Vector r3 = Q6_V_vsplat_R(0);
+    HVX_Vector r4 = Q6_V_vsplat_R(0);
+    HVX_Vector r5 = Q6_V_vsplat_R(0);
+    HVX_Vector r6 = Q6_V_vsplat_R(0);
+    HVX_Vector r7 = Q6_V_vsplat_R(0);
+
+    HVX_VectorPair p3;
+    HVX_VectorPair p2;
+    HVX_VectorPair p1;
+    HVX_VectorPair p0;
+
+    if (n >= 128) { r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]); }
+    if (n >= 256) { r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]); }
+    if (n >= 384) { r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]); }
+    if (n >= 512) { r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]); }
+    if (n >= 640) { r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]); }
+    if (n >= 768) { r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]); }
+    if (n >= 896) { r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]); }
+    if (n >= 1024) { r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]); }
+
+    if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
+    if (n >= 384) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); }
+    if (n >= 640) { p2 = Q6_W_vdeal_VVR(r5, r4, -4); }
+    if (n >= 896) { p3 = Q6_W_vdeal_VVR(r7, r6, -4); }
+
+    if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }
+    if (n >= 384) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); }
+    if (n >= 640) { r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2)); }
+    if (n >= 896) { r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3)); }
+
+    if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
+    if (n >= 640) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); }
+
+    if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }
+    if (n >= 640) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); }
+
+    if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
+    if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }

     return r0;
 }

+static inline HVX_Vector hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) {
+    return hvx_vec_rmpy_x8_n(x, y, 1024);
+}
+
+// Handle most common cases of tensors not multiple of 1024.
+static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
+    if (n <= 256) { return hvx_vec_rmpy_x8_n(x, y, 256); };
+    if (n <= 512) { return hvx_vec_rmpy_x8_n(x, y, 512); };
+    if (n <= 768) { return hvx_vec_rmpy_x8_n(x, y, 768); };
+    return hvx_vec_rmpy_x8_n(x, y, 1024);
+}
+
 static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     assert(n % 32 == 0); // min sub-block size
     assert((unsigned long) vx % 128 == 0);
@@ -309,7 +343,7 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void *
         HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
         HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);

-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));

         HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
@@ -326,7 +360,7 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void *
         HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
         HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);

-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));

         HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
@@ -393,8 +427,8 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
         HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
         HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);

-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
-        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r1_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));

         HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
@@ -416,8 +450,8 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
         HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
         HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);

-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
-        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r1_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));

         HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
@@ -482,7 +516,7 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
         HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
         HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);

-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));

         HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
@@ -499,7 +533,7 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
         HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
         HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);

-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));

         HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
@@ -566,8 +600,8 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
         HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
         HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);

-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
-        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r1_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));

         HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
@@ -589,8 +623,8 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
         HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
         HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);

-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
-        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r1_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));

         HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
@@ -658,7 +692,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
         HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
         HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);

-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));

         HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
         HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
@@ -690,7 +724,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
         HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
         HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);

-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));

         HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
         HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
@@ -772,8 +806,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
         HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
         HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);

-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
-        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r1_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));

         HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
         HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
@@ -813,8 +847,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
         HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
         HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);

-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
-        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r1_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));

         HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
         HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
