Skip to content

Commit 4c44340

Browse files
hexagon: reduce number of vector stores in matmul output
1 parent fbb01ef commit 4c44340

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

ggml/src/ggml-hexagon/htp/matmul-ops.c

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -464,12 +464,12 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
464464
r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
465465
}
466466

467-
// Reduce and convert into fp32
468-
r0_sum = hvx_vec_qf32_reduce_sum(r0_sum);
469-
r1_sum = hvx_vec_qf32_reduce_sum(r1_sum);
467+
// Convert into fp32 and reduce
468+
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
469+
r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
470+
HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
470471

471-
hvx_vec_store_u(&s[0], 4, Q6_Vsf_equals_Vqf32(r0_sum));
472-
hvx_vec_store_u(&s[1], 4, Q6_Vsf_equals_Vqf32(r1_sum));
472+
hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
473473
}
474474

475475
static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@@ -637,12 +637,12 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
637637
r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
638638
}
639639

640-
// Reduce and convert into fp32
641-
r0_sum = hvx_vec_qf32_reduce_sum(r0_sum);
642-
r1_sum = hvx_vec_qf32_reduce_sum(r1_sum);
640+
// Convert into fp32 and reduce
641+
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
642+
r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
643+
HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
643644

644-
hvx_vec_store_u(&s[0], 4, Q6_Vsf_equals_Vqf32(r0_sum));
645-
hvx_vec_store_u(&s[1], 4, Q6_Vsf_equals_Vqf32(r1_sum));
645+
hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
646646
}
647647

648648
static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
@@ -879,12 +879,12 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
879879
r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
880880
}
881881

882-
// Reduce and convert into fp32
883-
r0_sum = hvx_vec_qf32_reduce_sum(r0_sum);
884-
r1_sum = hvx_vec_qf32_reduce_sum(r1_sum);
882+
// Convert into fp32 and reduce
883+
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
884+
r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
885+
HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
885886

886-
hvx_vec_store_u(&s[0], 4, Q6_Vsf_equals_Vqf32(r0_sum));
887-
hvx_vec_store_u(&s[1], 4, Q6_Vsf_equals_Vqf32(r1_sum));
887+
hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
888888
}
889889

890890
#if 1

0 commit comments

Comments
 (0)