@@ -437,49 +437,51 @@ MlasGemmQuantCopyPackA8x8(
         Vtype a2 = vmask;
         Vtype a3 = vmask;
         Vtype a1 = *reinterpret_cast<const Vtype *>(&a[0]);
-        if (CountM == 3) {
-            a3 = *reinterpret_cast<const Vtype *>(&a[lda * 2]);
-        }
-        if (CountM >= 2) {
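+        // CountM == 1: a single row needs no interleaving; store it directly
+        // and accumulate its byte sums.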
+        if (CountM == 1) {
+            vec_t va1 = AIsSigned ? reinterpret_cast<vec_t>(a1) : reinterpret_cast<vec_t>(vec_sub(a1, vmask));
+            *reinterpret_cast<vec_t *>(&D[0]) = (vec_t)va1;
+            vsum = vec_sum4s(va1, vsum);
+        } else {
             a2 = *reinterpret_cast<const Vtype *>(&a[lda]);
+            if (CountM == 3) {
+                a3 = *reinterpret_cast<const Vtype *>(&a[lda * 2]);
+            }
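+            // Interleave the 4x4 byte blocks of rows a1..a4 (absent rows hold
+            // vmask) so each 16-byte store carries one 4-byte group per row.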
+            Vtype vx =
+                reinterpret_cast<Vtype>(vec_mergee(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2)));
+            Vtype vx1 =
+                reinterpret_cast<Vtype>(vec_mergee(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4)));
+            Vtype vx2 =
+                reinterpret_cast<Vtype>(vec_mergeo(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2)));
+            Vtype vx3 =
+                reinterpret_cast<Vtype>(vec_mergeo(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4)));
+            Vtype vx4 = vec_xxpermdi(vx, vx1, 0);
+            Vtype vx5 = vec_xxpermdi(vx2, vx3, 0);
+            Vtype vx6 = vec_xxpermdi(vx, vx1, 3);
+            Vtype vx7 = vec_xxpermdi(vx2, vx3, 3);
+            vec_t vx0 = AIsSigned ? reinterpret_cast<vec_t>(vx4) : reinterpret_cast<vec_t>(vec_sub(vx4, vmask));
+            *reinterpret_cast<vec_t *>(&D[0]) = vx0;
+            vsum = vec_sum4s(vx0, vsum);
+            vx0 = AIsSigned ? reinterpret_cast<vec_t>(vx5) : reinterpret_cast<vec_t>(vec_sub(vx5, vmask));
+            *reinterpret_cast<vec_t *>(&D[16]) = vx0;
+            vsum = vec_sum4s(vx0, vsum);
+            vx0 = AIsSigned ? reinterpret_cast<vec_t>(vx6) : reinterpret_cast<vec_t>(vec_sub(vx6, vmask));
+            *reinterpret_cast<vec_t *>(&D[32]) = vx0;
+            vsum = vec_sum4s(vx0, vsum);
+            vx0 = AIsSigned ? reinterpret_cast<vec_t>(vx7) : reinterpret_cast<vec_t>(vec_sub(vx7, vmask));
+            *reinterpret_cast<vec_t *>(&D[48]) = vx0;
+            vsum = vec_sum4s(vx0, vsum);
+        }
+        if (CountM == 1) {
+            D += 16;
+        } else {
+            D += 16 * 4;
         }
-        Vtype vx =
-            reinterpret_cast<Vtype>(vec_mergee(reinterpret_cast<__vector int>(a1),
-                                               reinterpret_cast<__vector int>(a2)));
-        Vtype vx1 =
-            reinterpret_cast<Vtype>(vec_mergee(reinterpret_cast<__vector int>(a3),
-                                               reinterpret_cast<__vector int>(a4)));
-        Vtype vx2 =
-            reinterpret_cast<Vtype>(vec_mergeo(reinterpret_cast<__vector int>(a1),
-                                               reinterpret_cast<__vector int>(a2)));
-        Vtype vx3 =
-            reinterpret_cast<Vtype>(vec_mergeo(reinterpret_cast<__vector int>(a3),
-                                               reinterpret_cast<__vector int>(a4)));
-        Vtype vx4 = vec_xxpermdi(vx, vx1, 0);
-        Vtype vx5 = vec_xxpermdi(vx2, vx3, 0);
-        Vtype vx6 = vec_xxpermdi(vx, vx1, 3);
-        Vtype vx7 = vec_xxpermdi(vx2, vx3, 3);
-        vec_t vx0 =
-            AIsSigned ? reinterpret_cast<vec_t>(vx4) :
-                        reinterpret_cast<vec_t>(vec_sub(vx4, vmask));
-        *reinterpret_cast<vec_t *>(&D[0]) = vx0;
-        vsum = vec_sum4s(vx0, vsum);
-        vx0 = AIsSigned ? reinterpret_cast<vec_t>(vx5) :
-                          reinterpret_cast<vec_t>(vec_sub(vx5, vmask));
-        *reinterpret_cast<vec_t *>(&D[16]) = vx0;
-        vsum = vec_sum4s(vx0, vsum);
-        vx0 = AIsSigned ? reinterpret_cast<vec_t>(vx6) :
-                          reinterpret_cast<vec_t>(vec_sub(vx6, vmask));
-        *reinterpret_cast<vec_t *>(&D[32]) = vx0;
-        vsum = vec_sum4s(vx0, vsum);
-        vx0 = AIsSigned ? reinterpret_cast<vec_t>(vx7) :
-                          reinterpret_cast<vec_t>(vec_sub(vx7, vmask));
-        *reinterpret_cast<vec_t *>(&D[48]) = vx0;
-        vsum = vec_sum4s(vx0, vsum);
-        D += 16 * 4;
         a += 16;
         y -= 16;
     }
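+    // With one row, vec_sum4s spread that row's sum across four lanes that all
+    // belong to row 0; fold them into lane 0.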
+    if (CountM == 1) {
+        vsum[0] += (vsum[1] + vsum[2] + vsum[3]);
+    }
     while (y >= 4)
     {
         Vtype vb = vmask;
@@ -496,7 +498,11 @@ MlasGemmQuantCopyPackA8x8(
             reinterpret_cast<vec_t>(vec_sub(reinterpret_cast<Vtype>(vx1), vmask));
         *reinterpret_cast<vec_t *>(&D[0]) = vx;
         vsum = vec_sum4s(vx, vsum);
-        D += 16;
+        if (CountM == 1) {
+            D += 4;
+        } else {
+            D += 16;
+        }
         a += 4;
         y -= 4;
     }
@@ -1059,6 +1065,186 @@ MlasQgemmComputeMMA(
         }
     }
 };
+
+MLAS_FORCEINLINE
+void
+MlasGemmQuantKernel_M1(
+    const MLAS_GEMM_QUANT_KERNEL_POWER10::PackedAType *A,
+    const MLAS_GEMM_QUANT_KERNEL_POWER10::PackedBType *B,
+    int32_t *C,
+    size_t PackedCountK,
+    size_t CountN,
+    size_t ldc,
+    const int32_t *RowSumBuffer,
+    const int32_t *ColumnSumBuffer,
+    const int32_t *ZeroPointB,
+    bool ZeroMode
+)
+{
+    size_t Mval = 1;
+    while (CountN > 0) {
+        const int8_t *a = A;
+        typedef __vector unsigned char vec_t;
+        typedef __vector signed char svec_t;
+        const uint8_t *b = B;
+        MLAS_INT32X4 result = {0};
+        __vector signed int VecC = {0, 0, 0, 0};
+        __vector signed int VecC2 = {0, 0, 0, 0};
+        __vector signed int VecC3 = {0, 0, 0, 0};
+        __vector signed int VecC4 = {0, 0, 0, 0};
+        size_t k = PackedCountK * MLAS_GEMM_QUANT_KERNEL_POWER10::PackedK;
+        size_t k1 = PackedCountK;
+        __vector unsigned char va[4];
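+        // Shuffle patterns that broadcast each 4-byte group of the packed A row
+        // across all four words, so one vec_msum pairs it with four columns of B.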
+        __vector unsigned char pat = {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3};
+        __vector unsigned char pat2 = {4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7};
+        __vector unsigned char pat3 = {8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11};
+        __vector unsigned char pat4 = {12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15};
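+        // Main loop: consume 16 K values per iteration, accumulating four
+        // 4-column blocks of B (16 columns total) into VecC..VecC4.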
+        while (k >= 16) {
+            const vec_t *vecA = reinterpret_cast<const vec_t *>(a);
+            const vec_t *vb = reinterpret_cast<const vec_t *>(b);
+            va[0] = vec_perm(vecA[0], vecA[0], pat);
+            va[1] = vec_perm(vecA[0], vecA[0], pat2);
+            va[2] = vec_perm(vecA[0], vecA[0], pat3);
+            va[3] = vec_perm(vecA[0], vecA[0], pat4);
+            VecC = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC);
+            VecC = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC);
+            VecC = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC);
+            VecC = vec_msum((svec_t)va[3], (vec_t)vb[3], VecC);
+            vb = reinterpret_cast<const vec_t *>(&b[k1 * 16]);
+            VecC2 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC2);
+            VecC2 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC2);
+            VecC2 = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC2);
+            VecC2 = vec_msum((svec_t)va[3], (vec_t)vb[3], VecC2);
+            vb = reinterpret_cast<const vec_t *>(&b[k1 * 32]);
+            VecC3 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC3);
+            VecC3 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC3);
+            VecC3 = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC3);
+            VecC3 = vec_msum((svec_t)va[3], (vec_t)vb[3], VecC3);
+            vb = reinterpret_cast<const vec_t *>(&b[k1 * 48]);
+            VecC4 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC4);
+            VecC4 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC4);
+            VecC4 = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC4);
+            VecC4 = vec_msum((svec_t)va[3], (vec_t)vb[3], VecC4);
+            b += 64;
+            a += 16;
+            k -= 16;
+        }
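+        // K tail: the same pattern for 12, 8, or 4 leftover values.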
+        if (k >= 12) {
+            const vec_t *vecA = reinterpret_cast<const vec_t *>(a);
+            const vec_t *vb = reinterpret_cast<const vec_t *>(b);
+            va[0] = vec_perm(vecA[0], vecA[0], pat);
+            va[1] = vec_perm(vecA[0], vecA[0], pat2);
+            va[2] = vec_perm(vecA[0], vecA[0], pat3);
+            VecC = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC);
+            VecC = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC);
+            VecC = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC);
+            vb = reinterpret_cast<const vec_t *>(&b[k1 * 16]);
+            VecC2 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC2);
+            VecC2 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC2);
+            VecC2 = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC2);
+            vb = reinterpret_cast<const vec_t *>(&b[k1 * 32]);
+            VecC3 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC3);
+            VecC3 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC3);
+            VecC3 = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC3);
+            vb = reinterpret_cast<const vec_t *>(&b[k1 * 48]);
+            VecC4 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC4);
+            VecC4 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC4);
+            VecC4 = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC4);
+            a += 12;
+            b += 48;
+            k -= 12;
+        }
+        if (k >= 8) {
+            const vec_t *vecA = reinterpret_cast<const vec_t *>(a);
+            const vec_t *vb = reinterpret_cast<const vec_t *>(b);
+            va[0] = vec_perm(vecA[0], vecA[0], pat);
+            va[1] = vec_perm(vecA[0], vecA[0], pat2);
+            VecC = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC);
+            VecC = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC);
+            vb = reinterpret_cast<const vec_t *>(&b[k1 * 16]);
+            VecC2 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC2);
+            VecC2 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC2);
+            vb = reinterpret_cast<const vec_t *>(&b[k1 * 32]);
+            VecC3 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC3);
+            VecC3 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC3);
+            vb = reinterpret_cast<const vec_t *>(&b[k1 * 48]);
+            VecC4 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC4);
+            VecC4 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC4);
+            a += 8;
+            b += 32;
+            k -= 8;
+        }
+        if (k >= 4) {
+            const vec_t *vecA = reinterpret_cast<const vec_t *>(a);
+            const vec_t *vb = reinterpret_cast<const vec_t *>(b);
+            va[0] = vec_perm(vecA[0], vecA[0], pat);
+            VecC = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC);
+            vb = reinterpret_cast<const vec_t *>(&b[k1 * 16]);
+            VecC2 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC2);
+            vb = reinterpret_cast<const vec_t *>(&b[k1 * 32]);
+            VecC3 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC3);
+            vb = reinterpret_cast<const vec_t *>(&b[k1 * 48]);
+            VecC4 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC4);
+            a += 4;
+            b += 16;
+            k -= 4;
+        }
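+        // Store the accumulated columns, folding in row/column sums and zero points.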
+        if (CountN >= 16) {
+            MlasQgemmStoreVectorMMA<0>(&VecC, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0);
+            MlasQgemmStoreVectorMMA<4>(&VecC2, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4);
+            MlasQgemmStoreVectorMMA<8>(&VecC3, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 8);
+            MlasQgemmStoreVectorMMA<12>(&VecC4, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 12);
+            INC_BUFFER(16);
+            CountN -= 16;
+            B += 16 * 4 * PackedCountK;
+            C += 16;
+        } else {
+            if (CountN >= 12) {
+                MlasQgemmStoreVectorMMA<0>(&VecC, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0);
+                MlasQgemmStoreVectorMMA<4>(&VecC2, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4);
+                MlasQgemmStoreVectorMMA<8>(&VecC3, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 8);
+                INC_BUFFER(12);
+                if (CountN - 12 > 0)
+                    result = VecC4;
+                CountN -= 12;
+                C += 12;
+            } else if (CountN >= 8) {
+                MlasQgemmStoreVectorMMA<0>(&VecC, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0);
+                MlasQgemmStoreVectorMMA<4>(&VecC2, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4);
+                INC_BUFFER(8);
+                if (CountN - 8 > 0)
+                    result = VecC3;
+                CountN -= 8;
+                C += 8;
+            } else if (CountN >= 4) {
+                MlasQgemmStoreVectorMMA<0>(&VecC, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0);
+                INC_BUFFER(4);
+                if (CountN - 4 > 0)
+                    result = VecC2;
+                CountN -= 4;
+                C += 4;
+            } else
+                result = VecC;
+            CountN &= 3;
+
+            // Output the remaining partial output block.
+            if (CountN > 0) {
+                MlasQgemmStoreScalarMMA<0>(&result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB);
+                INC_BUFFER(1);
+            }
+            if (CountN >= 2) {
+                MlasQgemmStoreScalarMMA<1>(&result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB);
+                INC_BUFFER(1);
+            }
+            if (CountN >= 3) {
+                MlasQgemmStoreScalarMMA<2>(&result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB);
+                INC_BUFFER(1);
+            }
+            CountN = 0;
+        }
+    }
+}
+
 template <>
 size_t
 MlasGemmQuantKernel<MLAS_GEMM_QUANT_KERNEL_POWER10>(
@@ -1075,6 +1261,10 @@ MlasGemmQuantKernel<MLAS_GEMM_QUANT_KERNEL_POWER10>(
     bool ZeroMode
     )
 {
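+    // Fast path: route single-row (GEMV-like) calls to the dedicated M = 1 kernel.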
+    if (CountM == 1) {
+        MlasGemmQuantKernel_M1(A, B, C, PackedCountK, CountN, ldc, RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode);
+        return 1;
+    }
     if (CountM < 8 && CountM >= 4) {
         CountM = 4;
     }