@@ -60,7 +60,8 @@ static bool isTransposedMatrix(vector::ContractionOp contractOp,
   AffineMap mapB = contractMaps[1];

   bool isF32 = elementType.isF32();
-  bool isF16_BF16 = (elementType.isF16() || elementType.isBF16());
+  bool isPackedType = (elementType.isF16() || elementType.isBF16() ||
+                       elementType.isSignlessInteger(8));

   auto resultsMapA = mapA.getNumResults();
   auto resultsMapB = mapB.getNumResults();
@@ -70,7 +71,7 @@ static bool isTransposedMatrix(vector::ContractionOp contractOp,
            "Result dim map for A and B should be 3");
   }

-  if (isF16_BF16) {
+  if (isPackedType) {
     assert(resultsMapA == 4 && resultsMapB == 4 &&
            "Result dim map for A and B should be 4");
   }
@@ -83,7 +84,7 @@ static bool isTransposedMatrix(vector::ContractionOp contractOp,
            "Input dim map for A and B should be 4");
   }

-  if (isF16_BF16) {
+  if (isPackedType) {
     assert(inputsMapA == 5 && inputsMapB == 5 &&
            "Input dim map for A and B should be 5");
   }
@@ -95,7 +96,7 @@ static bool isTransposedMatrix(vector::ContractionOp contractOp,
     auto affineExpr =
         dyn_cast<AffineDimExpr>(mlir::getAffineDimExpr(i, mapA.getContext()));

-    if (isF16_BF16) {
+    if (isPackedType) {
       auto vnniDim = dyn_cast<AffineDimExpr>(mapA.getResult(3));
       if (affineExpr != vnniDim && affineExpr != dimBR)
         listMxNxK.push_back(affineExpr);
@@ -129,7 +130,8 @@ static bool permutationCheck(vector::ContractionOp contractOp,
   AffineMap mapB = contractMaps[1];

   bool isF32 = elementType.isF32();
-  bool isF16_BF16 = (elementType.isF16() || elementType.isBF16());
+  bool isPackedType = (elementType.isF16() || elementType.isBF16() ||
+                       elementType.isSignlessInteger(8));

   auto inputsMapA = mapA.getNumInputs();
   SmallVector<AffineDimExpr> inputDims;
@@ -148,7 +150,7 @@ static bool permutationCheck(vector::ContractionOp contractOp,
     outputDimsA.push_back(affineExpr);
   }

-  if (isF16_BF16) {
+  if (isPackedType) {
     // We match the pattern {Batch-reduction, vnni, M, N, K} or
     // {Batch-reduction, M, N, K, vnni} -> {Batch-reduction, M, K, vnni}
     auto c1 = inputDims[0] == outputDimsA[0];
@@ -178,7 +180,7 @@ static bool permutationCheck(vector::ContractionOp contractOp,
     outputDimsB.push_back(affineExpr);
   }

-  if (isF16_BF16) {
+  if (isPackedType) {
     // We match the pattern {Batch-reduction, vnni, M, N, K} or
     // {Batch-reduction, M, N, K, vnni} -> {Batch-reduction, K, N, vnni}
     auto c4 = inputDims[0] == outputDimsB[0];
@@ -290,16 +292,20 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
     bool isF32 = elementType.isF32();
     bool isF16 = elementType.isF16();
     bool isBF16 = elementType.isBF16();
+    bool isI8 = elementType.isSignlessInteger(8);

-    if (!(isF32 || isF16 || isBF16))
-      return rewriter.notifyMatchFailure(contractOp,
-                                         "The type is not F32 or F16 or BF16");
+    bool isPackedType = isF16 || isBF16 || isI8;
+    int64_t vnniFactor = (isBF16 || isF16) ? 2 : isI8 ? 4 : 0;
+
+    if (!(isF32 || isPackedType))
+      return rewriter.notifyMatchFailure(
+          contractOp, "The type is not F32 or F16 or BF16 or I8");

     bool bf16dp = false;
     bool srf = false;
     bool fallback = false;

-    if (isBF16 || isF16) {
+    if (isPackedType) {
       auto cpuName = vnni::utils::getTargetArchName();
       if (cpuName == "SRF")
         srf = true;
@@ -311,9 +317,9 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
         fallback = true;
     }

-    if (isF16 && !(srf))
+    if ((isF16 || isI8) && !(srf))
       return rewriter.notifyMatchFailure(
-          contractOp, "F16 type is supported only for SRF kind of machines");
+          contractOp, "F16/I8 type is supported only for SRF kind of machines");

     // Check the operation type MatMul, B-MatMul, or BR-MatMul
     SmallVector<vector::IteratorType> contractIteratorTypes =
@@ -328,7 +334,7 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
       return rewriter.notifyMatchFailure(
           contractOp, "Batch matmul operation not supported yet");

-    if (isBF16 || isF16) {
+    if (isPackedType) {
       if (reductionCount == 2)
         return rewriter.notifyMatchFailure(
             contractOp, "Batch reduce matmul operation without vnni layout");
@@ -360,14 +366,11 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
     int64_t K = 0;
     int64_t vnni = 0;

-    if (isBF16 || isF16) {
+    if (isPackedType) {
       M = lhsType.getDimSize(lhsType.getRank() - 3);
       N = rhsType.getDimSize(lhsType.getRank() - 2);
       K = lhsType.getDimSize(lhsType.getRank() - 2);
       vnni = lhsType.getDimSize(lhsType.getRank() - 1);
-      if (K != (vnni / 2))
-        return rewriter.notifyMatchFailure(
-            contractOp, "K tile size should be equal to VNNI layout");

       // TODO: We need the N tile size to be divisible by 16 for avx2
       // fallback case. So that it ensures, LLVM find a pattern and lowers to
@@ -376,9 +379,17 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
         return rewriter.notifyMatchFailure(
             contractOp, "N tile size divisible by 16 are only supported");

-      if (vnni != 2)
+      if (vnni != 2 && isBF16)
+        return rewriter.notifyMatchFailure(
+            contractOp, "Only VNNI layout=2 is supported for bf16, now");
+
+      if (vnni != 4 && isI8)
         return rewriter.notifyMatchFailure(
-            contractOp, "Only VNNI layout=2 is supported, now");
+            contractOp, "Only VNNI layout=4 is supported for i8, now");
+
+      if (K != (vnni / vnniFactor))
+        return rewriter.notifyMatchFailure(
+            contractOp, "K tile size should be equal to VNNI layout");
     }

     if (isF32) {
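
For reference, a minimal standalone sketch (illustration only; the helper name and the driver are assumptions, not part of the patch) of the packing rule the new `vnniFactor` selection encodes: 16-bit element types (f16/bf16) pack two values per 32-bit lane, while i8 packs four, which is what the `vnni != 2`/`vnni != 4` and `K != (vnni / vnniFactor)` checks above rely on.

    #include <cassert>
    #include <cstdint>

    // Hypothetical helper mirroring: (isBF16 || isF16) ? 2 : isI8 ? 4 : 0
    static int64_t vnniFactorForBitWidth(unsigned elementBitWidth) {
      switch (elementBitWidth) {
      case 16:
        return 2; // f16 / bf16: two elements per 32-bit lane
      case 8:
        return 4; // i8: four elements per 32-bit lane
      default:
        return 0; // f32 and other types are not VNNI-packed
      }
    }

    int main() {
      assert(vnniFactorForBitWidth(16) == 2);
      assert(vnniFactorForBitWidth(8) == 4);
      assert(vnniFactorForBitWidth(32) == 0);
      return 0;
    }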
@@ -412,8 +423,8 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
     // matrix then broadcast A ony-by-one + FMA.
     // If N > M: perform opposite. Broadcast A matrix then load B one-by-
     // one + FMA.
-    // Following this kind of lowering, we reduce the register loads by
-    // stacking the less B loads or less A broadcasts and do the larger B
+    // Following this kind of lowering, we reduce the register loads by
+    // stacking the less B loads or less A broadcasts and do the larger B
     // loads or A broadcast in a LIFO manner. Finally, it helps in reducing
     // the probablity of register spills.
     bool mDriven = true;
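
For reference, a rough scalar model (illustration only; the function names and the printf driver are assumptions, not generated code) of the two FMA orderings the comment above describes: whichever side needs fewer registers, the B vector loads or the A broadcasts, is kept resident while the other operand is streamed one value at a time into the FMAs.

    #include <cstdio>

    // Tile sizes for the illustration; N is counted in sizeFactor-wide B loads.
    constexpr int M = 4;
    constexpr int N = 2;

    // Keep the (fewer) B vector loads resident, stream A broadcasts one by one.
    void holdBStreamA() {
      for (int j = 0; j < N; ++j)
        std::printf("load  B[%d]\n", j);
      for (int i = 0; i < M; ++i) {
        std::printf("bcast A[%d]\n", i);
        for (int j = 0; j < N; ++j)
          std::printf("  fma acc[%d][%d]\n", i, j);
      }
    }

    // Keep the (fewer) A broadcasts resident, stream B vector loads one by one.
    void holdAStreamB() {
      for (int i = 0; i < M; ++i)
        std::printf("bcast A[%d]\n", i);
      for (int j = 0; j < N; ++j) {
        std::printf("load  B[%d]\n", j);
        for (int i = 0; i < M; ++i)
          std::printf("  fma acc[%d][%d]\n", i, j);
      }
    }

    int main() {
      holdBStreamA(); // used when the B side is the smaller working set
      holdAStreamB(); // used when the A side is the smaller working set
      return 0;
    }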
@@ -491,7 +502,7 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
       }
     }

-    if (outsElementType.isF32()) {
+    if (outsElementType.isF32() || outsElementType.isSignlessInteger(32)) {
       for (int j = 0; j < N; j = j + sizeFactor) {
         for (int i = 0; i < M; i++) {
           Value indexOp_A = rewriter.create<arith::ConstantIndexOp>(
@@ -562,12 +573,22 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
         auto i1Mask_2 = rewriter.create<arith::ConstantOp>(
             kForOp.getLoc(), VectorType::get(2, rewriter.getI1Type()),
             boolAttr_2);
-        auto zeroAttr = rewriter.getFloatAttr(elementType, 0.0);
+
+        // The zeroAttr is not needed for the i8 lowering on ARL machines; it
+        // may be needed in the future when lowering for other machines.
+        FloatAttr zeroAttr;
+        if (!isI8) {
+          zeroAttr = rewriter.getFloatAttr(elementType, 0.0);
+        }

         // Destination type
         mlir::VectorType dstType =
             mlir::VectorType::get(sizeFactor, rewriter.getF32Type());

+        if (isI8)
+          dstType =
+              mlir::VectorType::get(sizeFactor, rewriter.getI32Type());
+
         llvm::SmallVector<OpFoldResult> strides = {
             rewriter.getIndexAttr(1), rewriter.getIndexAttr(1),
             rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
@@ -664,15 +685,16 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {

            // bf16 type + avx512. uKernel lowering for machines like
            // cpx (zen5) to target avx512bf16dp.
-           if (bf16dp && isBF16) {
+           if (bf16dp || isI8) {

              if (mDriven) { // M -> N
                // Load elements of B matrix and store in a DS
                for (int j = 0; j < N; j = j + sizeFactor) {
                  Value indexOp_j = rewriter.create<arith::ConstantIndexOp>(
                      reductionForOp.getLoc(), j);
                  auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
-                     kForOp.getLoc(), VectorType::get(32, elementType),
+                     kForOp.getLoc(),
+                     VectorType::get({sizeFactor * vnni}, elementType),
                      rhsClone->getResult(0),
                      ValueRange{indexOp_c0, indexOp_c0, indexOp_j,
                                 indexOp_c0});
@@ -700,15 +722,27 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
                 auto valuef32 =
                     rewriterNewKForOp.create<vector::BitCastOp>(
                         kForOp.getLoc(),
-                        VectorType::get(32,
-                                        rewriterNewKForOp.getBF16Type()),
+                        VectorType::get({sizeFactor * vnni}, elementType),
                         bcst_i32);
-                for (int j = 0; j < (N / sizeFactor); j++) {
-                  auto dp = rewriter.create<mlir::x86vector::DotBF16Op>(
-                      kForOp.getLoc(), dstType,
-                      iterArgsNewKForOp[i + (j * M)], valuef32,
-                      matf32[j]);
-                  oddFMAs.push_back(dp);
+
+                if (isBF16) {
+                  for (int j = 0; j < (N / sizeFactor); j++) {
+                    auto dp = rewriter.create<mlir::x86vector::DotBF16Op>(
+                        kForOp.getLoc(), dstType,
+                        iterArgsNewKForOp[i + (j * M)], valuef32,
+                        matf32[j]);
+                    oddFMAs.push_back(dp);
+                  }
+                }
+
+                if (isI8) {
+                  for (int j = 0; j < (N / sizeFactor); j++) {
+                    auto dp = rewriter.create<mlir::x86vector::DotInt8Op>(
+                        kForOp.getLoc(), dstType,
+                        iterArgsNewKForOp[i + (j * M)], valuef32,
+                        matf32[j]);
+                    oddFMAs.push_back(dp);
+                  }
                 }
               }

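For reference, a hedged scalar model (illustration only, not the generated IR) of what one accumulator lane of the two dot ops computes: the bf16 dot adds two bf16 products into an f32 lane, and the i8 dot adds four byte products into an i32 lane. The unsigned-by-signed convention for the i8 operands follows the AVX VNNI vpdpbusd instruction and is an assumption about how DotInt8Op is lowered.

    #include <cstdint>

    // One f32 lane of the bf16 dot: acc += a0*b0 + a1*b1 over a pair of packed
    // bf16 values (float stand-ins are used for the bf16 inputs here).
    float dotBF16Lane(float acc, const float a[2], const float b[2]) {
      return acc + a[0] * b[0] + a[1] * b[1];
    }

    // One i32 lane of the i8 dot: acc += sum of four byte products
    // (unsigned a times signed b, as vpdpbusd does; an assumption here).
    int32_t dotInt8Lane(int32_t acc, const uint8_t a[4], const int8_t b[4]) {
      for (int k = 0; k < 4; ++k)
        acc += static_cast<int32_t>(a[k]) * static_cast<int32_t>(b[k]);
      return acc;
    }

    int main() {
      const float a2[2] = {1.0f, 2.0f}, b2[2] = {3.0f, 4.0f};
      const uint8_t a4[4] = {1, 2, 3, 4};
      const int8_t b4[4] = {1, -1, 1, -1};
      // 0 + 1*3 + 2*4 == 11;  0 + 1 - 2 + 3 - 4 == -2
      return (dotBF16Lane(0.0f, a2, b2) == 11.0f && dotInt8Lane(0, a4, b4) == -2)
                 ? 0
                 : 1;
    }
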
@@ -743,8 +777,7 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
                auto valuef32 =
                    rewriterNewKForOp.create<vector::BitCastOp>(
                        kForOp.getLoc(),
-                       VectorType::get(32,
-                                       rewriterNewKForOp.getBF16Type()),
+                       VectorType::get({sizeFactor * vnni}, elementType),
                        bcst_i32);
                matf32.push_back(valuef32);
              }
@@ -753,16 +786,30 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
              Value indexOp_j = rewriter.create<arith::ConstantIndexOp>(
                  reductionForOp.getLoc(), j);
              auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
-                 kForOp.getLoc(), VectorType::get(32, elementType),
+                 kForOp.getLoc(),
+                 VectorType::get({sizeFactor * vnni}, elementType),
                  rhsClone->getResult(0),
                  ValueRange{indexOp_c0, indexOp_c0, indexOp_j,
                             indexOp_c0});
-             for (int i = 0; i < M; i++) {
-               auto dp = rewriter.create<mlir::x86vector::DotBF16Op>(
-                   kForOp.getLoc(), dstType, iterArgsNewKForOp[k],
-                   matf32[i], valueRow);
-               k++;
-               evenFMAs.push_back(dp);
+
+             if (isBF16) {
+               for (int i = 0; i < M; i++) {
+                 auto dp = rewriter.create<mlir::x86vector::DotBF16Op>(
+                     kForOp.getLoc(), dstType, iterArgsNewKForOp[k],
+                     matf32[i], valueRow);
+                 k++;
+                 evenFMAs.push_back(dp);
+               }
+             }
+
+             if (isI8) {
+               for (int i = 0; i < M; i++) {
+                 auto dp = rewriter.create<mlir::x86vector::DotInt8Op>(
+                     kForOp.getLoc(), dstType, iterArgsNewKForOp[k],
+                     matf32[i], valueRow);
+                 k++;
+                 evenFMAs.push_back(dp);
+               }
             }
           }
         }
@@ -905,7 +952,7 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
           // (b) bf16 fallback + avx2 instructions.
           // TODO: update lowering based on M & N. Now it is
           // default to M -> N
-          if (srf || (fallback && avx2 && !avx512)) {
+          if ((srf && !isI8) || (fallback && avx2 && !avx512)) {
            // Load odd elements of A Matrix and store in a DS
            for (int i = 0; i < M; i++) {
              Value oddA;
@@ -1228,7 +1275,7 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {

    // get the 2nd input source for addOp via vector transfer read
    // ps: the 1st one is C matrix
-   if (addOp && maxOp && !isF32) {
+   if (addOp && maxOp && !isF32 && !isI8) {
      vector::TransferReadOp readOp_add;
      if (auto vectBcst = addOp.getLhs().getDefiningOp<vector::BroadcastOp>()) {
        if (auto vectorRead =
@@ -1268,7 +1315,7 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
      auto acc_value = newReductionForOp.getResult(k);
      k++;

-     if (addOp && maxOp && !isF32) {
+     if (addOp && maxOp && !isF32 && !isI8) {
        Value add_row;

        if (global_readOp) {
@@ -1360,7 +1407,7 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
      }

      // We do arith.tuncf for f32 -> bf16 in SRF/ARL/SPR kind of machines
-     if ((srf || bf16dp) && !outsElementType.isF32()) {
+     if ((srf || bf16dp) && !outsElementType.isF32() && !isI8) {
        vec_final = rewriter.create<arith::TruncFOp>(
            reductionForOp.getLoc(), VectorType::get(sizeFactor, type),
            acc_value);