@@ -667,14 +667,31 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
667667 float * res_ptr = s ;
668668
669669 for (int x = 0 ; x < nc / ncols_interleaved ; x ++ ) {
670+ // %x[nc] : loop control
671+
670672 const block_q4_0x4 * b_ptr = (const block_q4_0x4 * ) vx + (x * nb );
671673
672674 float32x4_t sumf = vdupq_n_f32 (0 );
675+ // v29 = sumf
676+
673677 for (int l = 0 ; l < nb ; l ++ ) {
678+ // x21 : loop control
679+
680+ // x22 = a_ptr[l].qs
681+ // %x[b_ptr] = b_ptr[l].qs
682+
683+ int8x16_t a_0 = vld1q_s8 (a_ptr [l ].qs + 0 );
684+ int8x16_t a_1 = vld1q_s8 (a_ptr [l ].qs + 16 );
685+ // (v27, v25) = (a_0, a_1)
686+
674687 uint8x16_t b_0 = vld1q_u8 (b_ptr [l ].qs + 0 );
675688 uint8x16_t b_1 = vld1q_u8 (b_ptr [l ].qs + 16 );
676689 uint8x16_t b_2 = vld1q_u8 (b_ptr [l ].qs + 32 );
677690 uint8x16_t b_3 = vld1q_u8 (b_ptr [l ].qs + 48 );
691+ // (v28, v24, v23, v22) = (b_0, b_1, b_2, b_3)
692+
693+ float16x4_t b_d_half = vld1_f16 ((const float16_t * )b_ptr [l ].d );
694+ // v20 = b_d_half
678695
679696 int8x16_t b_0_hi = vreinterpretq_s8_u8 (b_0 & 0xF0 );
680697 int8x16_t b_0_lo = vreinterpretq_s8_u8 (b_0 << 4 );
@@ -684,11 +701,13 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
684701 int8x16_t b_2_lo = vreinterpretq_s8_u8 (b_2 << 4 );
685702 int8x16_t b_3_hi = vreinterpretq_s8_u8 (b_3 & 0xF0 );
686703 int8x16_t b_3_lo = vreinterpretq_s8_u8 (b_3 << 4 );
687-
688- int8x16_t a_0 = vld1q_s8 (a_ptr [l ].qs + 0 );
689- int8x16_t a_1 = vld1q_s8 (a_ptr [l ].qs + 16 );
704+ // (v16, v28) = (b_0_lo, b_0_hi)
705+ // (v19, v24) = (b_1_lo, b_1_hi)
706+ // (v18, v23) = (b_2_lo, b_2_hi)
707+ // (v17, v22) = (b_3_lo, b_3_hi)
690708
691709 int32x4_t sumi = vdupq_n_s32 (0 );
710+ // v26 = sumi
692711 sumi = vdotq_laneq_s32 (sumi , b_0_lo , a_0 , 0 );
693712 sumi = vdotq_laneq_s32 (sumi , b_0_hi , a_1 , 0 );
694713 sumi = vdotq_laneq_s32 (sumi , b_1_lo , a_0 , 1 );
@@ -697,15 +716,21 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
697716 sumi = vdotq_laneq_s32 (sumi , b_2_hi , a_1 , 2 );
698717 sumi = vdotq_laneq_s32 (sumi , b_3_lo , a_0 , 3 );
699718 sumi = vdotq_laneq_s32 (sumi , b_3_hi , a_1 , 3 );
700-
719+
701720 float32x4_t a_d = vcvt_f32_f16 (vld1_dup_f16 ((const float16_t * )& a_ptr [l ].d ));
702- float32x4_t b_d = vcvt_f32_f16 (vld1_f16 ((const float16_t * )b_ptr [l ].d ));
721+ // v21 = a_d
722+
723+ float32x4_t b_d = vcvt_f32_f16 (b_d_half );
724+ // v16 = b_d
725+
703726 float32x4_t d = a_d * b_d ;
727+ // v16 = d
704728
705729 sumf = vmlaq_f32 (sumf , d , vcvtq_n_f32_s32 (sumi , 4 ));
706730 }
707731
708732 vst1q_f32 (res_ptr + x * 4 , sumf );
733+ // %x[res_ptr] = res_ptr + x * 4
709734 }
710735 return ;
711736 }
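Side note (not part of the commit; a minimal scalar sketch with made-up values): the `& 0xF0` / `<< 4` pair above leaves each 4-bit weight in the upper nibble of a signed byte, i.e. at 16 times its value, and `vcvtq_n_f32_s32(sumi, 4)` later converts the accumulator to float while dividing by 2^4, cancelling that factor:

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        uint8_t packed = 0x3A;                // hypothetical packed byte: low nibble 0xA, high nibble 0x3
        int8_t  lo = (int8_t)(packed << 4);   // low nibble moved to the top 4 bits -> signed nibble * 16
        int8_t  hi = (int8_t)(packed & 0xF0); // high nibble kept in place          -> signed nibble * 16
        int8_t  a0 = 5, a1 = -7;              // hypothetical int8 activations
        int32_t sumi = lo * a0 + hi * a1;     // accumulated at 16x scale, like the vdot accumulators
        printf("%g\n", (float)sumi / 16.0f);  // the /16 mirrors vcvtq_n_f32_s32(sumi, 4); prints -51
        return 0;
    }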
@@ -1174,7 +1199,7 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
11741199 sumi = vdotq_laneq_s32 (sumi , b_2_hi , a_1 , 2 );
11751200 sumi = vdotq_laneq_s32 (sumi , b_3_lo , a_0 , 3 );
11761201 sumi = vdotq_laneq_s32 (sumi , b_3_hi , a_1 , 3 );
1177-
1202+
11781203 float32x4_t a_d = vcvt_f32_f16 (vld1_dup_f16 ((const float16_t * )& a_ptr [l ].d ));
11791204 float32x4_t b_d = vcvt_f32_f16 (vld1_f16 ((const float16_t * )b_ptr [l ].d ));
11801205 float32x4_t d = a_d * b_d ;
@@ -1236,7 +1261,97 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
12361261
12371262#if ! ((defined(_MSC_VER )) && ! defined(__clang__ )) && defined(__aarch64__ ) && defined(__ARM_NEON )
12381263 if (ggml_cpu_has_neon ()) {
1239- for (int y = 0 ; y < nr / 4 ; y ++ ) {
1264+ #define UNROLL_FACTOR 4
1265+ int y = 0 ;
1266+ for (; y + UNROLL_FACTOR <= nr / 4 ; y += UNROLL_FACTOR ) {
1267+ const block_q8_0x4 * a_ptr [UNROLL_FACTOR ];
1268+ for (int z = 0 ; z < UNROLL_FACTOR ; z ++ ) {
1269+ a_ptr [z ] = (const block_q8_0x4 * ) vy + ((y + z ) * nb );
1270+ }
1271+
1272+ for (int x = 0 ; x < nc / ncols_interleaved ; x ++ ) {
1273+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 * ) vx + (x * nb );
1274+
1275+ float32x4_t sumf [UNROLL_FACTOR ][4 ];
1276+ for (int z = 0 ; z < UNROLL_FACTOR ; z ++ ) {
1277+ for (int m = 0 ; m < 4 ; m ++ ) {
1278+ sumf [z ][m ] = vdupq_n_f32 (0 );
1279+ }
1280+ }
1281+ // (v15, v19, v18, v14) = sumf[0][0, 1, 2, 3]
1282+ // (v11, v13, v23, v16) = sumf[1][0, 1, 2, 3]
1283+ // (v27, v7, v0, v4 ) = sumf[2][0, 1, 2, 3]
1284+ // (v5, v21, v8, v1 ) = sumf[3][0, 1, 2, 3]
1285+
1286+ for (int l = 0 ; l < nb ; l ++ ) {
1287+ // x24 : loop control
1288+
1289+ // x28 = b_ptr[l].qs
1290+ // (x25, x23, x22, x21) = a_ptr[0, 1, 2, 3][l].qs
1291+
1292+ int8x16_t b_hi [4 ], b_lo [4 ];
1293+ for (int k = 0 ; k < 4 ; k ++ ) {
1294+ uint8x16_t b = vld1q_u8 (b_ptr [l ].qs + 16 * k );
1295+ b_hi [k ] = vreinterpretq_s8_u8 (b & 0xF0 );
1296+ b_lo [k ] = vreinterpretq_s8_u8 (b << 4 );
1297+ }
1298+ // (v12, v3) = (b_lo[0], b_hi[0])
1299+ // (v31, v22) = (b_lo[1], b_hi[1])
1300+ // (v6, v27) = (b_lo[2], b_hi[2])
1301+ // (v28, v30) = (b_lo[3], b_hi[3])
1302+
1303+ float32x4_t b_d = vcvt_f32_f16 (vld1_f16 ((const float16_t * )b_ptr [l ].d ));
1304+ // v17 = b_d
1305+
1306+ // unroll in ASM
1307+ for (int z = 0 ; z < UNROLL_FACTOR ; z ++ ) {
1308+ int32x4_t sumi [4 ];
1309+ for (int m = 0 ; m < 4 ; m ++ ) {
1310+ sumi [m ] = vdupq_n_s32 (0 );
1311+ }
1312+ // (v10, v29, v9, v20) = sumi[0, 1, 2, 3] (z = 0)
1313+ // (v9, v29, v20, v2) = sumi[0, 1, 2, 3] (z = 1)
1314+ // (v20, v10, v26, v2) = sumi[0, 1, 2, 3] (z = 2)
1315+ // (v26, v10, v2, v19) = sumi[0, 1, 2, 3] (z = 3)
1316+
1317+ for (int k = 0 ; k < 4 ; k ++ ) {
1318+ int8x16_t a0 = vld1q_s8 (a_ptr [z ][l ].qs + 16 * k + 0 );
1319+ sumi [0 ] = vdotq_laneq_s32 (sumi [0 ], b_lo [k ], a0 , 0 );
1320+ sumi [1 ] = vdotq_laneq_s32 (sumi [1 ], b_lo [k ], a0 , 1 );
1321+ sumi [2 ] = vdotq_laneq_s32 (sumi [2 ], b_lo [k ], a0 , 2 );
1322+ sumi [3 ] = vdotq_laneq_s32 (sumi [3 ], b_lo [k ], a0 , 3 );
1323+ }
1324+ for (int k = 0 ; k < 4 ; k ++ ) {
1325+ int8x16_t a1 = vld1q_s8 (a_ptr [z ][l ].qs + 16 * k + 64 );
1326+ sumi [0 ] = vdotq_laneq_s32 (sumi [0 ], b_hi [k ], a1 , 0 );
1327+ sumi [1 ] = vdotq_laneq_s32 (sumi [1 ], b_hi [k ], a1 , 1 );
1328+ sumi [2 ] = vdotq_laneq_s32 (sumi [2 ], b_hi [k ], a1 , 2 );
1329+ sumi [3 ] = vdotq_laneq_s32 (sumi [3 ], b_hi [k ], a1 , 3 );
1330+ }
1331+
1332+ float32x4_t a_d = vcvt_f32_f16 (vld1_f16 ((const float16_t * )a_ptr [z ][l ].d ));
1333+ // (v2, v26, v29, v20) = a_d (z = 0, 1, 2, 3)
1334+
1335+ sumf [z ][0 ] = vmlaq_f32 (sumf [z ][0 ], vmulq_laneq_f32 (b_d , a_d , 0 ), vcvtq_n_f32_s32 (sumi [0 ], 4 ));
1336+ sumf [z ][1 ] = vmlaq_f32 (sumf [z ][1 ], vmulq_laneq_f32 (b_d , a_d , 1 ), vcvtq_n_f32_s32 (sumi [1 ], 4 ));
1337+ sumf [z ][2 ] = vmlaq_f32 (sumf [z ][2 ], vmulq_laneq_f32 (b_d , a_d , 2 ), vcvtq_n_f32_s32 (sumi [2 ], 4 ));
1338+ sumf [z ][3 ] = vmlaq_f32 (sumf [z ][3 ], vmulq_laneq_f32 (b_d , a_d , 3 ), vcvtq_n_f32_s32 (sumi [3 ], 4 ));
1339+ }
1340+
1341+ }
1342+
1343+ for (int z = 0 ; z < UNROLL_FACTOR ; z ++ ) {
1344+ for (int m = 0 ; m < 4 ; m ++ ) {
1345+ vst1q_f32 (s + ((y + z ) * 4 + m ) * bs + x * 4 , sumf [z ][m ]);
1346+ }
1347+ }
1348+ }
1349+ }
1350+ #undef UNROLL_FACTOR
1351+
1352+ for (; y < nr / 4 ; y ++ ) {
1353+ // x10 : loop control
1354+
12401355 const block_q8_0x4 * a_ptr = (const block_q8_0x4 * ) vy + (y * nb );
12411356 for (int x = 0 ; x < nc / ncols_interleaved ; x ++ ) {
12421357 const block_q4_0x4 * b_ptr = (const block_q4_0x4 * ) vx + (x * nb );
@@ -1245,32 +1360,68 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
12451360 for (int m = 0 ; m < 4 ; m ++ ) {
12461361 sumf [m ] = vdupq_n_f32 (0 );
12471362 }
1248-
1363+ // (v15, v19, v18, v14) = sumf[0, 1, 2, 3]
1364+
12491365 for (int l = 0 ; l < nb ; l ++ ) {
1366+ // x21 : loop control
1367+
1368+ // x25 = a_ptr[l].qs
1369+ // x24 = b_ptr[l].qs
1370+
1371+ int8x16_t a_0 [4 ], a_1 [4 ];
1372+ a_0 [0 ] = vld1q_s8 (a_ptr [l ].qs + 0 );
1373+ a_0 [1 ] = vld1q_s8 (a_ptr [l ].qs + 16 );
1374+ a_0 [2 ] = vld1q_s8 (a_ptr [l ].qs + 32 );
1375+ a_0 [3 ] = vld1q_s8 (a_ptr [l ].qs + 48 );
1376+ a_1 [0 ] = vld1q_s8 (a_ptr [l ].qs + 64 );
1377+ a_1 [1 ] = vld1q_s8 (a_ptr [l ].qs + 80 );
1378+ a_1 [2 ] = vld1q_s8 (a_ptr [l ].qs + 96 );
1379+ a_1 [3 ] = vld1q_s8 (a_ptr [l ].qs + 112 );
1380+ // (v5, v26) = (a_0[0], a_1[0])
1381+ // (v2, v25) = (a_0[1], a_1[1])
1382+ // (v31, v24) = (a_0[2], a_1[2])
1383+ // (v27, v16) = (a_0[3], a_1[3])
1384+
1385+ uint8x16_t b_0 = vld1q_u8 (b_ptr [l ].qs + 0 );
1386+ uint8x16_t b_1 = vld1q_u8 (b_ptr [l ].qs + 16 );
1387+ uint8x16_t b_2 = vld1q_u8 (b_ptr [l ].qs + 32 );
1388+ uint8x16_t b_3 = vld1q_u8 (b_ptr [l ].qs + 48 );
1389+ // (v7, v3, v13, v28) = (b_0, b_1, b_2, b_3)
1390+
1391+ int8x16_t b_lo [4 ], b_hi [4 ];
1392+ b_hi [0 ] = vreinterpretq_s8_u8 (b_0 & 0xF0 );
1393+ b_lo [0 ] = vreinterpretq_s8_u8 (b_0 << 4 );
1394+ b_hi [1 ] = vreinterpretq_s8_u8 (b_1 & 0xF0 );
1395+ b_lo [1 ] = vreinterpretq_s8_u8 (b_1 << 4 );
1396+ b_hi [2 ] = vreinterpretq_s8_u8 (b_2 & 0xF0 );
1397+ b_lo [2 ] = vreinterpretq_s8_u8 (b_2 << 4 );
1398+ b_hi [3 ] = vreinterpretq_s8_u8 (b_3 & 0xF0 );
1399+ b_lo [3 ] = vreinterpretq_s8_u8 (b_3 << 4 );
1400+ // (v20, v7) = (b_lo[0], b_hi[0])
1401+ // (v17, v3) = (b_lo[1], b_hi[1])
1402+ // (v22, v13) = (b_lo[2], b_hi[2])
1403+ // (v9, v28) = (b_lo[3], b_hi[3])
1404+
12501405 float32x4_t a_d = vcvt_f32_f16 (vld1_f16 ((const float16_t * )a_ptr [l ].d ));
1406+ // v12 = a_d
12511407 float32x4_t b_d = vcvt_f32_f16 (vld1_f16 ((const float16_t * )b_ptr [l ].d ));
1408+ // v21 = b_d
12521409
12531410 int32x4_t sumi_0 = vdupq_n_s32 (0 );
12541411 int32x4_t sumi_1 = vdupq_n_s32 (0 );
12551412 int32x4_t sumi_2 = vdupq_n_s32 (0 );
12561413 int32x4_t sumi_3 = vdupq_n_s32 (0 );
1414+ // (v4, v1, v0, v30) = (sumi_0, sumi_1, sumi_2, sumi_3)
12571415
12581416 for (int k = 0 ; k < 4 ; k ++ ) {
1259- int8x16_t a_0 = vld1q_s8 (a_ptr [l ].qs + 16 * k + 0 );
1260- int8x16_t a_1 = vld1q_s8 (a_ptr [l ].qs + 16 * k + 64 );
1261-
1262- uint8x16_t b = vld1q_u8 (b_ptr [l ].qs + 16 * k );
1263- int8x16_t b_hi = vreinterpretq_s8_u8 (b & 0xF0 );
1264- int8x16_t b_lo = vreinterpretq_s8_u8 (b << 4 );
1265-
1266- sumi_0 = vdotq_laneq_s32 (sumi_0 , b_lo , a_0 , 0 );
1267- sumi_1 = vdotq_laneq_s32 (sumi_1 , b_lo , a_0 , 1 );
1268- sumi_2 = vdotq_laneq_s32 (sumi_2 , b_lo , a_0 , 2 );
1269- sumi_3 = vdotq_laneq_s32 (sumi_3 , b_lo , a_0 , 3 );
1270- sumi_0 = vdotq_laneq_s32 (sumi_0 , b_hi , a_1 , 0 );
1271- sumi_1 = vdotq_laneq_s32 (sumi_1 , b_hi , a_1 , 1 );
1272- sumi_2 = vdotq_laneq_s32 (sumi_2 , b_hi , a_1 , 2 );
1273- sumi_3 = vdotq_laneq_s32 (sumi_3 , b_hi , a_1 , 3 );
1417+ sumi_0 = vdotq_laneq_s32 (sumi_0 , b_lo [k ], a_0 [k ], 0 );
1418+ sumi_1 = vdotq_laneq_s32 (sumi_1 , b_lo [k ], a_0 [k ], 1 );
1419+ sumi_2 = vdotq_laneq_s32 (sumi_2 , b_lo [k ], a_0 [k ], 2 );
1420+ sumi_3 = vdotq_laneq_s32 (sumi_3 , b_lo [k ], a_0 [k ], 3 );
1421+ sumi_0 = vdotq_laneq_s32 (sumi_0 , b_hi [k ], a_1 [k ], 0 );
1422+ sumi_1 = vdotq_laneq_s32 (sumi_1 , b_hi [k ], a_1 [k ], 1 );
1423+ sumi_2 = vdotq_laneq_s32 (sumi_2 , b_hi [k ], a_1 [k ], 2 );
1424+ sumi_3 = vdotq_laneq_s32 (sumi_3 , b_hi [k ], a_1 [k ], 3 );
12741425 }
12751426
12761427 sumf [0 ] = vmlaq_f32 (sumf [0 ], vmulq_laneq_f32 (b_d , a_d , 0 ), vcvtq_n_f32_s32 (sumi_0 , 4 ));
@@ -1279,6 +1430,7 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
12791430 sumf [3 ] = vmlaq_f32 (sumf [3 ], vmulq_laneq_f32 (b_d , a_d , 3 ), vcvtq_n_f32_s32 (sumi_3 , 4 ));
12801431 }
12811432
1433+ // NOTE: the asm version has additional code to handle the case where `nr` is not a multiple of 4
12821434 for (int m = 0 ; m < 4 ; m ++ ) {
12831435 vst1q_f32 (s + (y * 4 + m ) * bs + x * 4 , sumf [m ]);
12841436 }
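For reference (an illustrative sketch, not from the commit): in these kernels `vdotq_laneq_s32(acc, b, a, lane)` adds, into each 32-bit lane i of `acc`, the dot product of bytes b[4i..4i+3] with the single 4-byte group a[4*lane..4*lane+3], and `vmulq_laneq_f32(b_d, a_d, lane)` likewise broadcasts one scale from `a_d`. That lane selection is what lets one set of loaded column data serve all four rows of the 4x4 tile. A scalar model of the dot step:

    #include <stdint.h>

    // scalar model of vdotq_laneq_s32(acc, b, a, lane):
    // acc[i] += dot(b[4*i .. 4*i+3], a[4*lane .. 4*lane+3]) for i = 0..3
    static void sdot_laneq_ref(int32_t acc[4], const int8_t b[16], const int8_t a[16], int lane) {
        for (int i = 0; i < 4; i++) {
            int32_t d = 0;
            for (int j = 0; j < 4; j++) {
                d += (int32_t) b[4*i + j] * (int32_t) a[4*lane + j];
            }
            acc[i] += d;
        }
    }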
@@ -3230,7 +3382,7 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
32303382 for (int m = 0 ; m < 4 ; m ++ ) {
32313383 sumf [m ] = vdupq_n_f32 (0 );
32323384 }
3233-
3385+
32343386 for (int l = 0 ; l < nb ; l ++ ) {
32353387 float32x4_t a_d = vcvt_f32_f16 (vld1_f16 ((const float16_t * )a_ptr [l ].d ));
32363388 float32x4_t b_d = vcvt_f32_f16 (vld1_f16 ((const float16_t * )b_ptr [l ].d ));
@@ -3244,7 +3396,7 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
32443396 int8x16_t a_0 = vld1q_s8 (a_ptr [l ].qs + 16 * k + 0 );
32453397 int8x16_t a_1 = vld1q_s8 (a_ptr [l ].qs + 16 * k + 64 );
32463398
3247- uint8x16_t b = vld1q_u8 (b_ptr [l ].qs + 16 * k );
3399+ uint8x16_t b = vld1q_u8 (b_ptr [l ].qs + 16 * k );
32483400 int8x16_t b_hi = vqtbl1q_s8 (kvalues , b >> 4 );
32493401 int8x16_t b_lo = vqtbl1q_s8 (kvalues , b & 0xF );
32503402
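One more aside (illustrative, not from the commit): in the iq4_nl kernels the 4-bit values are indices into the 16-entry `kvalues` codebook rather than direct quants, so `vqtbl1q_s8(kvalues, b >> 4)` and `vqtbl1q_s8(kvalues, b & 0xF)` perform a per-byte table lookup. A scalar model of that lookup:

    #include <stdint.h>

    // scalar model of vqtbl1q_s8(table, idx): out[i] = table[idx[i]], or 0 if idx[i] >= 16
    static void tbl1_ref(int8_t out[16], const int8_t table[16], const uint8_t idx[16]) {
        for (int i = 0; i < 16; i++) {
            out[i] = (idx[i] < 16) ? table[idx[i]] : 0;
        }
    }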