Commit c7a54d1

[ggml-aarch64] impl the same logic as the ASM version in q4_0_4_4 gemm/gemv
1 parent 32e0862

1 file changed

ggml/src/ggml-aarch64.c

Lines changed: 177 additions & 25 deletions
@@ -667,14 +667,31 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
     float * res_ptr = s;
 
     for (int x = 0; x < nc / ncols_interleaved; x++) {
+        // %x[nc] : loop control
+
         const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
 
         float32x4_t sumf = vdupq_n_f32(0);
+        // v29 = sumf
+
         for (int l = 0; l < nb; l++) {
+            // x21 : loop control
+
+            // x22 = a_ptr[l].qs
+            // %x[b_ptr] = b_ptr[l].qs
+
+            int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
+            int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
+            // (v27, v25) = (a_0, a_1)
+
             uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
             uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
             uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
             uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
+            // (v28, v24, v23, v22) = (b_0, b_1, b_2, b_3)
+
+            float16x4_t b_d_half = vld1_f16((const float16_t *)b_ptr[l].d);
+            // v20 = b_d_half
 
             int8x16_t b_0_hi = vreinterpretq_s8_u8(b_0 & 0xF0);
             int8x16_t b_0_lo = vreinterpretq_s8_u8(b_0 << 4);
@@ -684,11 +701,13 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
             int8x16_t b_2_lo = vreinterpretq_s8_u8(b_2 << 4);
             int8x16_t b_3_hi = vreinterpretq_s8_u8(b_3 & 0xF0);
             int8x16_t b_3_lo = vreinterpretq_s8_u8(b_3 << 4);
-
-            int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
-            int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
+            // (v16, v28) = (b_0_lo, b_0_hi)
+            // (v19, v24) = (b_1_lo, b_1_hi)
+            // (v18, v23) = (b_2_lo, b_2_hi)
+            // (v17, v22) = (b_3_lo, b_3_hi)
 
             int32x4_t sumi = vdupq_n_s32(0);
+            // v26 = sumi
             sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
             sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
             sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
@@ -697,15 +716,21 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
             sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
             sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
             sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
-
+
             float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
-            float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+            // v21 = a_d
+
+            float32x4_t b_d = vcvt_f32_f16(b_d_half);
+            // v16 = b_d
+
             float32x4_t d = a_d * b_d;
+            // v16 = d
 
             sumf = vmlaq_f32(sumf, d, vcvtq_n_f32_s32(sumi, 4));
         }
 
         vst1q_f32(res_ptr + x * 4, sumf);
+        // %x[res_ptr] = res_ptr + x * 4
     }
     return;
 }
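A note on the nibble arithmetic in this gemv path: each byte of b_ptr[l].qs packs two 4-bit weights (per the repacking code earlier in this file, they are already stored in signed form), so `b << 4` moves the low nibble into the top four bits while `b & 0xF0` keeps the high nibble in place. Both results are therefore 16x the signed nibble value, and the trailing vcvtq_n_f32_s32(sumi, 4) converts the accumulator with 4 fractional bits, dividing by 16 to compensate. A minimal scalar sketch of the same trick (the helper name is illustrative, not part of the file):

    #include <stdint.h>

    // Both outputs carry a factor of 16; the kernel removes it during the
    // int -> float conversion via vcvtq_n_f32_s32(sumi, 4), i.e. sumi / 16.
    static inline void unpack_q4_pair(uint8_t byte, int8_t * lo16, int8_t * hi16) {
        *lo16 = (int8_t)(byte << 4);   // 16 * signed low nibble
        *hi16 = (int8_t)(byte & 0xF0); // 16 * signed high nibble
    }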
@@ -1174,7 +1199,7 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
             sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
             sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
             sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
-
+
             float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
             float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
             float32x4_t d = a_d * b_d;
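For readers mapping the intrinsics to the ASM registers: vdotq_laneq_s32(acc, b, a, lane) adds, into each 32-bit lane i of acc, the dot product of bytes b[4i..4i+3] with the selected 4-byte lane of a. Since each 16-byte b vector here holds 4 interleaved columns, one SDOT accumulates 4 int8 products for each of the 4 output columns. A scalar model of the instruction (sketch only; the helper name is illustrative):

    #include <stdint.h>

    // Scalar model of SDOT (by element): `lane` picks a 4-byte group of a,
    // which is dotted against each 4-byte group (one column) of b.
    static void sdot_laneq_ref(int32_t acc[4], const int8_t b[16],
                               const int8_t a[16], int lane) {
        for (int i = 0; i < 4; i++) {
            for (int j = 0; j < 4; j++) {
                acc[i] += (int32_t) b[4 * i + j] * (int32_t) a[4 * lane + j];
            }
        }
    }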
@@ -1236,7 +1261,97 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
 
 #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
     if (ggml_cpu_has_neon()) {
-        for (int y = 0; y < nr / 4; y++) {
+#define UNROLL_FACTOR 4
+        int y = 0;
+        for (; y + UNROLL_FACTOR <= nr / 4; y += UNROLL_FACTOR) {
+            const block_q8_0x4 * a_ptr[UNROLL_FACTOR];
+            for (int z = 0; z < UNROLL_FACTOR; z++) {
+                a_ptr[z] = (const block_q8_0x4 *) vy + ((y + z) * nb);
+            }
+
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+
+                float32x4_t sumf[UNROLL_FACTOR][4];
+                for (int z = 0; z < UNROLL_FACTOR; z++) {
+                    for (int m = 0; m < 4; m++) {
+                        sumf[z][m] = vdupq_n_f32(0);
+                    }
+                }
+                // (v15, v19, v18, v14) = sumf[0][0, 1, 2, 3]
+                // (v11, v13, v23, v16) = sumf[1][0, 1, 2, 3]
+                // (v27, v7,  v0,  v4 ) = sumf[2][0, 1, 2, 3]
+                // (v5,  v21, v8,  v1 ) = sumf[3][0, 1, 2, 3]
+
+                for (int l = 0; l < nb; l++) {
+                    // x24 : loop control
+
+                    // x28 = b_ptr[l].qs
+                    // (x25, x23, x22, x21) = a_ptr[0, 1, 2, 3][l].qs
+
+                    int8x16_t b_hi[4], b_lo[4];
+                    for (int k = 0; k < 4; k++) {
+                        uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
+                        b_hi[k] = vreinterpretq_s8_u8(b & 0xF0);
+                        b_lo[k] = vreinterpretq_s8_u8(b << 4);
+                    }
+                    // (v12, v3)  = (b_lo[0], b_hi[0])
+                    // (v31, v22) = (b_lo[1], b_hi[1])
+                    // (v6,  v27) = (b_lo[2], b_hi[2])
+                    // (v28, v30) = (b_lo[3], b_hi[3])
+
+                    float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+                    // v17 = b_d
+
+                    // unrolled in the ASM version
+                    for (int z = 0; z < UNROLL_FACTOR; z++) {
+                        int32x4_t sumi[4];
+                        for (int m = 0; m < 4; m++) {
+                            sumi[m] = vdupq_n_s32(0);
+                        }
+                        // (v10, v29, v9,  v20) = sumi[0, 1, 2, 3] (z = 0)
+                        // (v9,  v29, v20, v2 ) = sumi[0, 1, 2, 3] (z = 1)
+                        // (v20, v10, v26, v2 ) = sumi[0, 1, 2, 3] (z = 2)
+                        // (v26, v10, v2,  v19) = sumi[0, 1, 2, 3] (z = 3)
+
+                        for (int k = 0; k < 4; k++) {
+                            int8x16_t a0 = vld1q_s8(a_ptr[z][l].qs + 16 * k + 0);
+                            sumi[0] = vdotq_laneq_s32(sumi[0], b_lo[k], a0, 0);
+                            sumi[1] = vdotq_laneq_s32(sumi[1], b_lo[k], a0, 1);
+                            sumi[2] = vdotq_laneq_s32(sumi[2], b_lo[k], a0, 2);
+                            sumi[3] = vdotq_laneq_s32(sumi[3], b_lo[k], a0, 3);
+                        }
+                        for (int k = 0; k < 4; k++) {
+                            int8x16_t a1 = vld1q_s8(a_ptr[z][l].qs + 16 * k + 64);
+                            sumi[0] = vdotq_laneq_s32(sumi[0], b_hi[k], a1, 0);
+                            sumi[1] = vdotq_laneq_s32(sumi[1], b_hi[k], a1, 1);
+                            sumi[2] = vdotq_laneq_s32(sumi[2], b_hi[k], a1, 2);
+                            sumi[3] = vdotq_laneq_s32(sumi[3], b_hi[k], a1, 3);
+                        }
+
+                        float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[z][l].d));
+                        // (v2, v26, v29, v20) = a_d (z = 0, 1, 2, 3)
+
+                        sumf[z][0] = vmlaq_f32(sumf[z][0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_n_f32_s32(sumi[0], 4));
+                        sumf[z][1] = vmlaq_f32(sumf[z][1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_n_f32_s32(sumi[1], 4));
+                        sumf[z][2] = vmlaq_f32(sumf[z][2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_n_f32_s32(sumi[2], 4));
+                        sumf[z][3] = vmlaq_f32(sumf[z][3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_n_f32_s32(sumi[3], 4));
+                    }
+
+                }
+
+                for (int z = 0; z < UNROLL_FACTOR; z++) {
+                    for (int m = 0; m < 4; m++) {
+                        vst1q_f32(s + ((y + z) * 4 + m) * bs + x * 4, sumf[z][m]);
+                    }
+                }
+            }
+        }
+#undef UNROLL_FACTOR
+
+        for (; y < nr / 4; y++) {
+            // x10 : loop control
+
             const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
             for (int x = 0; x < nc / ncols_interleaved; x++) {
                 const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
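The UNROLL_FACTOR loop exists for data reuse: the weight nibbles and the b_d scales are unpacked once per l iteration and then shared across four row blocks (16 output rows), mirroring the register tiling of the ASM kernel. The lane indexing works because of the 4-byte interleave of block_q8_0x4: the load at offset 16*k + 0 carries elements 4k..4k+3 of all four rows (one row per SDOT lane), and offset 16*k + 64 carries elements 16+4k..19+4k. A scalar sketch of that mapping (my reading of the interleaving code elsewhere in this file, stated here as an assumption):

    // Index helpers for the 4-byte-interleaved Q8_0x4 layout (illustrative):
    // byte j of qs belongs to row (j / 4) % 4 of the 4-row block, at element
    // (j / 16) * 4 + (j % 4) of that row's 32-element Q8_0 block.
    static inline int q8_0x4_row(int j)  { return (j / 4) % 4; }
    static inline int q8_0x4_elem(int j) { return (j / 16) * 4 + (j % 4); }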
@@ -1245,32 +1360,68 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
                 for (int m = 0; m < 4; m++) {
                     sumf[m] = vdupq_n_f32(0);
                 }
-
+                // (v15, v19, v18, v14) = sumf[0, 1, 2, 3]
+
                 for (int l = 0; l < nb; l++) {
+                    // x21 : loop control
+
+                    // x25 = a_ptr[l].qs
+                    // x24 = b_ptr[l].qs
+
+                    int8x16_t a_0[4], a_1[4];
+                    a_0[0] = vld1q_s8(a_ptr[l].qs + 0);
+                    a_0[1] = vld1q_s8(a_ptr[l].qs + 16);
+                    a_0[2] = vld1q_s8(a_ptr[l].qs + 32);
+                    a_0[3] = vld1q_s8(a_ptr[l].qs + 48);
+                    a_1[0] = vld1q_s8(a_ptr[l].qs + 64);
+                    a_1[1] = vld1q_s8(a_ptr[l].qs + 80);
+                    a_1[2] = vld1q_s8(a_ptr[l].qs + 96);
+                    a_1[3] = vld1q_s8(a_ptr[l].qs + 112);
+                    // (v5,  v26) = (a_0[0], a_1[0])
+                    // (v2,  v25) = (a_0[1], a_1[1])
+                    // (v31, v24) = (a_0[2], a_1[2])
+                    // (v27, v16) = (a_0[3], a_1[3])
+
+                    uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
+                    uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
+                    uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
+                    uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
+                    // (v7, v3, v13, v28) = (b_0, b_1, b_2, b_3)
+
+                    int8x16_t b_lo[4], b_hi[4];
+                    b_hi[0] = vreinterpretq_s8_u8(b_0 & 0xF0);
+                    b_lo[0] = vreinterpretq_s8_u8(b_0 << 4);
+                    b_hi[1] = vreinterpretq_s8_u8(b_1 & 0xF0);
+                    b_lo[1] = vreinterpretq_s8_u8(b_1 << 4);
+                    b_hi[2] = vreinterpretq_s8_u8(b_2 & 0xF0);
+                    b_lo[2] = vreinterpretq_s8_u8(b_2 << 4);
+                    b_hi[3] = vreinterpretq_s8_u8(b_3 & 0xF0);
+                    b_lo[3] = vreinterpretq_s8_u8(b_3 << 4);
+                    // (v20, v7)  = (b_lo[0], b_hi[0])
+                    // (v17, v3)  = (b_lo[1], b_hi[1])
+                    // (v22, v13) = (b_lo[2], b_hi[2])
+                    // (v9,  v28) = (b_lo[3], b_hi[3])
+
                     float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
+                    // v12 = a_d
                     float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+                    // v21 = b_d
 
                     int32x4_t sumi_0 = vdupq_n_s32(0);
                     int32x4_t sumi_1 = vdupq_n_s32(0);
                     int32x4_t sumi_2 = vdupq_n_s32(0);
                     int32x4_t sumi_3 = vdupq_n_s32(0);
+                    // (v4, v1, v0, v30) = (sumi_0, sumi_1, sumi_2, sumi_3)
 
                     for (int k = 0; k < 4; k++) {
-                        int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
-                        int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
-
-                        uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
-                        int8x16_t b_hi = vreinterpretq_s8_u8(b & 0xF0);
-                        int8x16_t b_lo = vreinterpretq_s8_u8(b << 4);
-
-                        sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
-                        sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
-                        sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
-                        sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
-                        sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
-                        sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
-                        sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
-                        sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
+                        sumi_0 = vdotq_laneq_s32(sumi_0, b_lo[k], a_0[k], 0);
+                        sumi_1 = vdotq_laneq_s32(sumi_1, b_lo[k], a_0[k], 1);
+                        sumi_2 = vdotq_laneq_s32(sumi_2, b_lo[k], a_0[k], 2);
+                        sumi_3 = vdotq_laneq_s32(sumi_3, b_lo[k], a_0[k], 3);
+                        sumi_0 = vdotq_laneq_s32(sumi_0, b_hi[k], a_1[k], 0);
+                        sumi_1 = vdotq_laneq_s32(sumi_1, b_hi[k], a_1[k], 1);
+                        sumi_2 = vdotq_laneq_s32(sumi_2, b_hi[k], a_1[k], 2);
+                        sumi_3 = vdotq_laneq_s32(sumi_3, b_hi[k], a_1[k], 3);
                     }
 
                     sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_n_f32_s32(sumi_0, 4));
@@ -1279,6 +1430,7 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
                     sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_n_f32_s32(sumi_3, 4));
                 }
 
+                // NOTE: the ASM version has additional code to handle the case where `nr` is not a multiple of 4
                 for (int m = 0; m < 4; m++) {
                     vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
                 }
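Each output element in this epilogue is scaled by three factors: the per-row activation scale (one lane of a_d, selected by vmulq_laneq_f32), the per-column weight scale vector b_d, and the constant 1/16 folded into vcvtq_n_f32_s32(.., 4). A scalar reference for the epilogue of one 4x4 tile (a sketch; the helper name is illustrative):

    #include <stdint.h>

    // Row scale a_d[m] times column scale b_d[c], applied to the integer
    // dot product, which is divided by 16 during the int -> float step.
    static void scale_accumulate_ref(float sumf[4][4], const float a_d[4],
                                     const float b_d[4], const int32_t sumi[4][4]) {
        for (int m = 0; m < 4; m++) {
            for (int c = 0; c < 4; c++) {
                sumf[m][c] += (a_d[m] * b_d[c]) * ((float) sumi[m][c] / 16.0f);
            }
        }
    }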
@@ -3230,7 +3382,7 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
                 for (int m = 0; m < 4; m++) {
                     sumf[m] = vdupq_n_f32(0);
                 }
-
+
                 for (int l = 0; l < nb; l++) {
                     float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
                     float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
@@ -3244,7 +3396,7 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
                         int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
                         int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
 
-                        uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
+                        uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
                         int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
                         int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
 
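Unlike q4_0, the iq4_nl hunk above dequantizes through a table lookup: vqtbl1q_s8 uses each nibble as an index into the 16-entry kvalues vector (the non-linear iq4_nl codebook) rather than treating it as a linear 4-bit value. A scalar model of one lookup (sketch; the helper name is illustrative):

    #include <stdint.h>

    // Scalar model of vqtbl1q_s8(kvalues, idx): each 4-bit index selects
    // one of 16 signed codebook entries.
    static inline int8_t iq4nl_lookup_ref(const int8_t kvalues[16], uint8_t nibble) {
        return kvalues[nibble & 0xF];
    }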