Skip to content

Commit b49e3b1

Browse files
POWER : Implement MlasGemmQuantKernel using VSX builtins for M = 1 (microsoft#25490)
POWER : Added a VSX-based implementation of MlasGemmQuantKernel optimized for the case when M = 1. Verified correctness using ONNX Runtime's built-in tests and onnxruntime_mlas_tests; no regressions observed. Evaluated performance using a Granite 8-bit quantized model and observed approximately 3-5% improvement in token generation speed. ### Description When M = 1, the multiplication is performed using the VSX vector builtin vec_msum. ### Motivation and Context To improve token generation performance for models with a batch size of 1.
1 parent 7493b8b commit b49e3b1

File tree

1 file changed

+229
-39
lines changed

1 file changed

+229
-39
lines changed

onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp

Lines changed: 229 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -437,49 +437,51 @@ MlasGemmQuantCopyPackA8x8(
437437
Vtype a2 = vmask;
438438
Vtype a3 = vmask;
439439
Vtype a1 = *reinterpret_cast<const Vtype *>(&a[0]);
440-
if (CountM == 3) {
441-
a3 = *reinterpret_cast<const Vtype *>(&a[lda * 2]);
442-
}
443-
if (CountM >= 2) {
440+
if (CountM == 1) {
441+
vec_t va1 = AIsSigned ? reinterpret_cast<vec_t>(a1) : reinterpret_cast<vec_t>(vec_sub(a1, vmask));
442+
*reinterpret_cast<vec_t *>(&D[0]) = (vec_t)va1;
443+
vsum = vec_sum4s(va1, vsum);
444+
} else {
444445
a2 = *reinterpret_cast<const Vtype *>(&a[lda]);
446+
if (CountM == 3) {
447+
a3 = *reinterpret_cast<const Vtype *>(&a[lda * 2]);
448+
}
449+
Vtype vx =
450+
reinterpret_cast<Vtype>(vec_mergee(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2)));
451+
Vtype vx1 =
452+
reinterpret_cast<Vtype>(vec_mergee(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4)));
453+
Vtype vx2 =
454+
reinterpret_cast<Vtype>(vec_mergeo(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2)));
455+
Vtype vx3 =
456+
reinterpret_cast<Vtype>(vec_mergeo(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4)));
457+
Vtype vx4 = vec_xxpermdi(vx, vx1, 0);
458+
Vtype vx5 = vec_xxpermdi(vx2, vx3, 0);
459+
Vtype vx6 = vec_xxpermdi(vx, vx1, 3);
460+
Vtype vx7 = vec_xxpermdi(vx2, vx3, 3);
461+
vec_t vx0 = AIsSigned ? reinterpret_cast<vec_t>(vx4) : reinterpret_cast<vec_t>(vec_sub(vx4, vmask));
462+
*reinterpret_cast<vec_t *>(&D[0]) = vx0;
463+
vsum = vec_sum4s(vx0, vsum);
464+
vx0 = AIsSigned ? reinterpret_cast<vec_t>(vx5) : reinterpret_cast<vec_t>(vec_sub(vx5, vmask));
465+
*reinterpret_cast<vec_t *>(&D[16]) = vx0;
466+
vsum = vec_sum4s(vx0, vsum);
467+
vx0 = AIsSigned ? reinterpret_cast<vec_t>(vx6) : reinterpret_cast<vec_t>(vec_sub(vx6, vmask));
468+
*reinterpret_cast<vec_t *>(&D[32]) = vx0;
469+
vsum = vec_sum4s(vx0, vsum);
470+
vx0 = AIsSigned ? reinterpret_cast<vec_t>(vx7) : reinterpret_cast<vec_t>(vec_sub(vx7, vmask));
471+
*reinterpret_cast<vec_t *>(&D[48]) = vx0;
472+
vsum = vec_sum4s(vx0, vsum);
473+
}
474+
if (CountM == 1) {
475+
D += 16;
476+
} else {
477+
D += 16 * 4;
445478
}
446-
Vtype vx =
447-
reinterpret_cast<Vtype>(vec_mergee(reinterpret_cast<__vector int>(a1),
448-
reinterpret_cast<__vector int>(a2)));
449-
Vtype vx1 =
450-
reinterpret_cast<Vtype>(vec_mergee(reinterpret_cast<__vector int>(a3),
451-
reinterpret_cast<__vector int>(a4)));
452-
Vtype vx2 =
453-
reinterpret_cast<Vtype>(vec_mergeo(reinterpret_cast<__vector int>(a1),
454-
reinterpret_cast<__vector int>(a2)));
455-
Vtype vx3 =
456-
reinterpret_cast<Vtype>(vec_mergeo(reinterpret_cast<__vector int>(a3),
457-
reinterpret_cast<__vector int>(a4)));
458-
Vtype vx4 = vec_xxpermdi(vx, vx1, 0);
459-
Vtype vx5 = vec_xxpermdi(vx2, vx3, 0);
460-
Vtype vx6 = vec_xxpermdi(vx, vx1, 3);
461-
Vtype vx7 = vec_xxpermdi(vx2, vx3, 3);
462-
vec_t vx0 =
463-
AIsSigned ? reinterpret_cast<vec_t>(vx4) :
464-
reinterpret_cast<vec_t>(vec_sub(vx4, vmask));
465-
*reinterpret_cast<vec_t *>(&D[0]) = vx0;
466-
vsum = vec_sum4s(vx0, vsum);
467-
vx0 = AIsSigned ? reinterpret_cast<vec_t>(vx5) :
468-
reinterpret_cast<vec_t>(vec_sub(vx5, vmask));
469-
*reinterpret_cast<vec_t *>(&D[16]) = vx0;
470-
vsum = vec_sum4s(vx0, vsum);
471-
vx0 = AIsSigned ? reinterpret_cast<vec_t>(vx6) :
472-
reinterpret_cast<vec_t>(vec_sub(vx6, vmask));
473-
*reinterpret_cast<vec_t *>(&D[32]) = vx0;
474-
vsum = vec_sum4s(vx0, vsum);
475-
vx0 = AIsSigned ? reinterpret_cast<vec_t>(vx7) :
476-
reinterpret_cast<vec_t>(vec_sub(vx7, vmask));
477-
*reinterpret_cast<vec_t *>(&D[48]) = vx0;
478-
vsum = vec_sum4s(vx0, vsum);
479-
D += 16 * 4;
480479
a += 16;
481480
y -= 16;
482481
}
482+
if (CountM == 1) {
483+
vsum[0] += (vsum[1] + vsum[2] + vsum[3]);
484+
}
483485
while (y >= 4)
484486
{
485487
Vtype vb = vmask;
@@ -496,7 +498,11 @@ MlasGemmQuantCopyPackA8x8(
496498
reinterpret_cast<vec_t>(vec_sub(reinterpret_cast<Vtype>(vx1), vmask));
497499
*reinterpret_cast<vec_t *>(&D[0]) = vx;
498500
vsum = vec_sum4s(vx, vsum);
499-
D += 16;
501+
if (CountM == 1) {
502+
D += 4;
503+
}
504+
else
505+
D += 16;
500506
a += 4;
501507
y -= 4;
502508
}
@@ -1059,6 +1065,186 @@ MlasQgemmComputeMMA(
10591065
}
10601066
}
10611067
};
1068+
1069+
MLAS_FORCEINLINE
void
MlasGemmQuantKernel_M1(
    const MLAS_GEMM_QUANT_KERNEL_POWER10::PackedAType *A,
    const MLAS_GEMM_QUANT_KERNEL_POWER10::PackedBType *B,
    int32_t *C,
    size_t PackedCountK,
    size_t CountN,
    size_t ldc,
    const int32_t *RowSumBuffer,
    const int32_t *ColumnSumBuffer,
    const int32_t *ZeroPointB,
    bool ZeroMode
    )
/*++

Routine Description:

    Quantized GEMM kernel specialized for a single row of A (M == 1).

    Instead of the MMA outer-product path, the single packed A row is
    multiplied against the packed B panels with the VSX vec_msum builtin:
    each 4-byte group of A is replicated across all four 32-bit lanes of a
    vector so that one vec_msum accumulates that group against four packed
    columns of B at once. Up to 16 columns of C are produced per pass over K.

Arguments:

    A - Supplies the address of the packed single row of matrix A.

    B - Supplies the address of the packed matrix B.

    C - Supplies the address of the output row of matrix C.

    PackedCountK - Supplies the number of packed K blocks (each block covers
        MLAS_GEMM_QUANT_KERNEL_POWER10::PackedK values of K).

    CountN - Supplies the number of columns to produce.

    ldc - Supplies the leading dimension of matrix C (unused for M == 1
        beyond the store helpers' signature).

    RowSumBuffer/ColumnSumBuffer/ZeroPointB - Supply the quantization
        correction terms consumed by the store helpers.

    ZeroMode - Supplies true if C should be overwritten, false if the
        products should be accumulated into C.

Return Value:

    None.

--*/
{
    typedef __vector unsigned char vec_t;
    typedef __vector signed char svec_t;

    // This kernel always emits exactly one row of C.
    const size_t Mval = 1;

    // Permute patterns that replicate each 4-byte group of the packed A row
    // into every 32-bit lane of a vector.
    const vec_t ReplicateQuad0 = {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3};
    const vec_t ReplicateQuad1 = {4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7};
    const vec_t ReplicateQuad2 = {8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11};
    const vec_t ReplicateQuad3 = {12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15};

    while (CountN > 0) {

        const int8_t *a = A;
        const uint8_t *b = B;

        // Distance, in 16-byte rows, between consecutive 4-column panels of B.
        const size_t PanelStride = PackedCountK;

        size_t k = PackedCountK * MLAS_GEMM_QUANT_KERNEL_POWER10::PackedK;

        MLAS_INT32X4 result = {0};

        // One accumulator per 4-column panel: columns 0-3, 4-7, 8-11, 12-15.
        __vector signed int Acc0 = {0, 0, 0, 0};
        __vector signed int Acc1 = {0, 0, 0, 0};
        __vector signed int Acc2 = {0, 0, 0, 0};
        __vector signed int Acc3 = {0, 0, 0, 0};

        vec_t va[4];

        // Main loop: consume 16 values of K (four 4-byte groups) per pass.
        while (k >= 16) {
            const vec_t *vecA = reinterpret_cast<const vec_t *>(a);
            const vec_t *vb = reinterpret_cast<const vec_t *>(b);
            va[0] = vec_perm(vecA[0], vecA[0], ReplicateQuad0);
            va[1] = vec_perm(vecA[0], vecA[0], ReplicateQuad1);
            va[2] = vec_perm(vecA[0], vecA[0], ReplicateQuad2);
            va[3] = vec_perm(vecA[0], vecA[0], ReplicateQuad3);
            Acc0 = vec_msum((svec_t)va[0], (vec_t)vb[0], Acc0);
            Acc0 = vec_msum((svec_t)va[1], (vec_t)vb[1], Acc0);
            Acc0 = vec_msum((svec_t)va[2], (vec_t)vb[2], Acc0);
            Acc0 = vec_msum((svec_t)va[3], (vec_t)vb[3], Acc0);
            vb = reinterpret_cast<const vec_t *>(&b[PanelStride * 16]);
            Acc1 = vec_msum((svec_t)va[0], (vec_t)vb[0], Acc1);
            Acc1 = vec_msum((svec_t)va[1], (vec_t)vb[1], Acc1);
            Acc1 = vec_msum((svec_t)va[2], (vec_t)vb[2], Acc1);
            Acc1 = vec_msum((svec_t)va[3], (vec_t)vb[3], Acc1);
            vb = reinterpret_cast<const vec_t *>(&b[PanelStride * 32]);
            Acc2 = vec_msum((svec_t)va[0], (vec_t)vb[0], Acc2);
            Acc2 = vec_msum((svec_t)va[1], (vec_t)vb[1], Acc2);
            Acc2 = vec_msum((svec_t)va[2], (vec_t)vb[2], Acc2);
            Acc2 = vec_msum((svec_t)va[3], (vec_t)vb[3], Acc2);
            vb = reinterpret_cast<const vec_t *>(&b[PanelStride * 48]);
            Acc3 = vec_msum((svec_t)va[0], (vec_t)vb[0], Acc3);
            Acc3 = vec_msum((svec_t)va[1], (vec_t)vb[1], Acc3);
            Acc3 = vec_msum((svec_t)va[2], (vec_t)vb[2], Acc3);
            Acc3 = vec_msum((svec_t)va[3], (vec_t)vb[3], Acc3);
            b += 64;
            a += 16;
            k -= 16;
        }

        // Remainder handling: k is now one of {0, 4, 8, 12} (K is packed in
        // groups of 4), handled by successively smaller unrolled tails.
        if (k >= 12) {
            const vec_t *vecA = reinterpret_cast<const vec_t *>(a);
            const vec_t *vb = reinterpret_cast<const vec_t *>(b);
            va[0] = vec_perm(vecA[0], vecA[0], ReplicateQuad0);
            va[1] = vec_perm(vecA[0], vecA[0], ReplicateQuad1);
            va[2] = vec_perm(vecA[0], vecA[0], ReplicateQuad2);
            Acc0 = vec_msum((svec_t)va[0], (vec_t)vb[0], Acc0);
            Acc0 = vec_msum((svec_t)va[1], (vec_t)vb[1], Acc0);
            Acc0 = vec_msum((svec_t)va[2], (vec_t)vb[2], Acc0);
            vb = reinterpret_cast<const vec_t *>(&b[PanelStride * 16]);
            Acc1 = vec_msum((svec_t)va[0], (vec_t)vb[0], Acc1);
            Acc1 = vec_msum((svec_t)va[1], (vec_t)vb[1], Acc1);
            Acc1 = vec_msum((svec_t)va[2], (vec_t)vb[2], Acc1);
            vb = reinterpret_cast<const vec_t *>(&b[PanelStride * 32]);
            Acc2 = vec_msum((svec_t)va[0], (vec_t)vb[0], Acc2);
            Acc2 = vec_msum((svec_t)va[1], (vec_t)vb[1], Acc2);
            Acc2 = vec_msum((svec_t)va[2], (vec_t)vb[2], Acc2);
            vb = reinterpret_cast<const vec_t *>(&b[PanelStride * 48]);
            Acc3 = vec_msum((svec_t)va[0], (vec_t)vb[0], Acc3);
            Acc3 = vec_msum((svec_t)va[1], (vec_t)vb[1], Acc3);
            Acc3 = vec_msum((svec_t)va[2], (vec_t)vb[2], Acc3);
            a += 12;
            b += 48;
            k -= 12;
        }
        if (k >= 8) {
            const vec_t *vecA = reinterpret_cast<const vec_t *>(a);
            const vec_t *vb = reinterpret_cast<const vec_t *>(b);
            va[0] = vec_perm(vecA[0], vecA[0], ReplicateQuad0);
            va[1] = vec_perm(vecA[0], vecA[0], ReplicateQuad1);
            Acc0 = vec_msum((svec_t)va[0], (vec_t)vb[0], Acc0);
            Acc0 = vec_msum((svec_t)va[1], (vec_t)vb[1], Acc0);
            vb = reinterpret_cast<const vec_t *>(&b[PanelStride * 16]);
            Acc1 = vec_msum((svec_t)va[0], (vec_t)vb[0], Acc1);
            Acc1 = vec_msum((svec_t)va[1], (vec_t)vb[1], Acc1);
            vb = reinterpret_cast<const vec_t *>(&b[PanelStride * 32]);
            Acc2 = vec_msum((svec_t)va[0], (vec_t)vb[0], Acc2);
            Acc2 = vec_msum((svec_t)va[1], (vec_t)vb[1], Acc2);
            vb = reinterpret_cast<const vec_t *>(&b[PanelStride * 48]);
            Acc3 = vec_msum((svec_t)va[0], (vec_t)vb[0], Acc3);
            Acc3 = vec_msum((svec_t)va[1], (vec_t)vb[1], Acc3);
            a += 8;
            b += 32;
            k -= 8;
        }
        if (k >= 4) {
            const vec_t *vecA = reinterpret_cast<const vec_t *>(a);
            const vec_t *vb = reinterpret_cast<const vec_t *>(b);
            va[0] = vec_perm(vecA[0], vecA[0], ReplicateQuad0);
            Acc0 = vec_msum((svec_t)va[0], (vec_t)vb[0], Acc0);
            vb = reinterpret_cast<const vec_t *>(&b[PanelStride * 16]);
            Acc1 = vec_msum((svec_t)va[0], (vec_t)vb[0], Acc1);
            vb = reinterpret_cast<const vec_t *>(&b[PanelStride * 32]);
            Acc2 = vec_msum((svec_t)va[0], (vec_t)vb[0], Acc2);
            vb = reinterpret_cast<const vec_t *>(&b[PanelStride * 48]);
            Acc3 = vec_msum((svec_t)va[0], (vec_t)vb[0], Acc3);
            a += 4;
            b += 16;
            k -= 4;
        }

        if (CountN >= 16) {

            // Full 16-column tile: store all four accumulators.
            MlasQgemmStoreVectorMMA<0>(&Acc0, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0);
            MlasQgemmStoreVectorMMA<4>(&Acc1, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4);
            MlasQgemmStoreVectorMMA<8>(&Acc2, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 8);
            MlasQgemmStoreVectorMMA<12>(&Acc3, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 12);
            INC_BUFFER(16);
            CountN -= 16;
            B += 16 * 4 * PackedCountK;
            C += 16;

        } else {

            // Partial tile: store whole 4-column panels first, keeping the
            // next unwritten accumulator in 'result' for the scalar tail.
            if (CountN >= 12) {
                MlasQgemmStoreVectorMMA<0>(&Acc0, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0);
                MlasQgemmStoreVectorMMA<4>(&Acc1, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4);
                MlasQgemmStoreVectorMMA<8>(&Acc2, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 8);
                INC_BUFFER(12);
                if (CountN - 12 > 0) {
                    result = Acc3;
                }
                CountN -= 12;
                C += 12;
            } else if (CountN >= 8) {
                MlasQgemmStoreVectorMMA<0>(&Acc0, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0);
                MlasQgemmStoreVectorMMA<4>(&Acc1, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4);
                INC_BUFFER(8);
                if (CountN - 8 > 0) {
                    result = Acc2;
                }
                CountN -= 8;
                C += 8;
            } else if (CountN >= 4) {
                MlasQgemmStoreVectorMMA<0>(&Acc0, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0);
                INC_BUFFER(4);
                if (CountN - 4 > 0) {
                    result = Acc1;
                }
                CountN -= 4;
                C += 4;
            } else {
                result = Acc0;
            }
            CountN &= 3;

            // Output the remaining 1-3 columns one element at a time.
            if (CountN > 0) {
                MlasQgemmStoreScalarMMA<0>(&result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB);
                INC_BUFFER(1);
            }
            if (CountN >= 2) {
                MlasQgemmStoreScalarMMA<1>(&result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB);
                INC_BUFFER(1);
            }
            if (CountN >= 3) {
                MlasQgemmStoreScalarMMA<2>(&result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB);
                INC_BUFFER(1);
            }
            CountN = 0;
        }
    }
}
1247+
10621248
template<>
10631249
size_t
10641250
MlasGemmQuantKernel<MLAS_GEMM_QUANT_KERNEL_POWER10>(
@@ -1075,6 +1261,10 @@ MlasGemmQuantKernel<MLAS_GEMM_QUANT_KERNEL_POWER10>(
10751261
bool ZeroMode
10761262
)
10771263
{
1264+
if (CountM == 1) {
1265+
MlasGemmQuantKernel_M1(A, B, C, PackedCountK, CountN, ldc, RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode);
1266+
return 1;
1267+
}
10781268
if (CountM < 8 && CountM >= 4) {
10791269
CountM = 4;
10801270
}

0 commit comments

Comments
 (0)