Commit a9e087d

[μKernels]: lowering support for i8 type on ARL (#1073)
This patch adds lowering support for the `i8` type via `Int8DotOp` (`VPDPBSSD`) on `ARL`-class machines.
1 parent ea2e7f0 commit a9e087d
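
For context, `VPDPBSSD` (surfaced here as the `x86vector` dialect's `Int8DotOp`) multiplies groups of four signed bytes from each source and accumulates the 32-bit sums into the destination lanes, which is why the accumulator tile for `i8` switches to `i32` below. A minimal scalar sketch of the per-lane semantics (illustrative only; `lanes` and the buffer layout are assumptions, not the pass's actual tile sizes):

```cpp
#include <cstdint>

// Scalar model of one VPDPBSSD-style update: per 32-bit lane, four signed
// i8 products are summed and added into the i32 accumulator.
void dotInt8(int32_t *acc, const int8_t *a, const int8_t *b, int lanes) {
  for (int i = 0; i < lanes; ++i) {
    int32_t sum = 0;
    for (int j = 0; j < 4; ++j)
      sum += static_cast<int32_t>(a[4 * i + j]) *
             static_cast<int32_t>(b[4 * i + j]);
    acc[i] += sum;
  }
}
```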

10 files changed (+277, -47 lines changed)

lib/TPP/Transforms/VectorContractToMicroKernels.cpp

Lines changed: 94 additions & 47 deletions
@@ -60,7 +60,8 @@ static bool isTransposedMatrix(vector::ContractionOp contractOp,
   AffineMap mapB = contractMaps[1];
 
   bool isF32 = elementType.isF32();
-  bool isF16_BF16 = (elementType.isF16() || elementType.isBF16());
+  bool isPackedType = (elementType.isF16() || elementType.isBF16() ||
+                       elementType.isSignlessInteger(8));
 
   auto resultsMapA = mapA.getNumResults();
   auto resultsMapB = mapB.getNumResults();
@@ -70,7 +71,7 @@ static bool isTransposedMatrix(vector::ContractionOp contractOp,
            "Result dim map for A and B should be 3");
   }
 
-  if (isF16_BF16) {
+  if (isPackedType) {
     assert(resultsMapA == 4 && resultsMapB == 4 &&
            "Result dim map for A and B should be 4");
   }
@@ -83,7 +84,7 @@ static bool isTransposedMatrix(vector::ContractionOp contractOp,
            "Input dim map for A and B should be 4");
   }
 
-  if (isF16_BF16) {
+  if (isPackedType) {
     assert(inputsMapA == 5 && inputsMapB == 5 &&
            "Input dim map for A and B should be 5");
   }
@@ -95,7 +96,7 @@ static bool isTransposedMatrix(vector::ContractionOp contractOp,
     auto affineExpr =
         dyn_cast<AffineDimExpr>(mlir::getAffineDimExpr(i, mapA.getContext()));
 
-    if (isF16_BF16) {
+    if (isPackedType) {
       auto vnniDim = dyn_cast<AffineDimExpr>(mapA.getResult(3));
       if (affineExpr != vnniDim && affineExpr != dimBR)
         listMxNxK.push_back(affineExpr);
@@ -129,7 +130,8 @@ static bool permutationCheck(vector::ContractionOp contractOp,
   AffineMap mapB = contractMaps[1];
 
   bool isF32 = elementType.isF32();
-  bool isF16_BF16 = (elementType.isF16() || elementType.isBF16());
+  bool isPackedType = (elementType.isF16() || elementType.isBF16() ||
+                       elementType.isSignlessInteger(8));
 
   auto inputsMapA = mapA.getNumInputs();
   SmallVector<AffineDimExpr> inputDims;
@@ -148,7 +150,7 @@ static bool permutationCheck(vector::ContractionOp contractOp,
     outputDimsA.push_back(affineExpr);
   }
 
-  if (isF16_BF16) {
+  if (isPackedType) {
     // We match the pattern {Batch-reduction, vnni, M, N, K} or
     // {Batch-reduction, M, N, K, vnni} -> {Batch-reduction, M, K, vnni}
     auto c1 = inputDims[0] == outputDimsA[0];
@@ -178,7 +180,7 @@ static bool permutationCheck(vector::ContractionOp contractOp,
     outputDimsB.push_back(affineExpr);
   }
 
-  if (isF16_BF16) {
+  if (isPackedType) {
     // We match the pattern {Batch-reduction, vnni, M, N, K} or
     // {Batch-reduction, M, N, K, vnni} -> {Batch-reduction, K, N, vnni}
     auto c4 = inputDims[0] == outputDimsB[0];
@@ -290,16 +292,20 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
     bool isF32 = elementType.isF32();
     bool isF16 = elementType.isF16();
     bool isBF16 = elementType.isBF16();
+    bool isI8 = elementType.isSignlessInteger(8);
 
-    if (!(isF32 || isF16 || isBF16))
-      return rewriter.notifyMatchFailure(contractOp,
-                                         "The type is not F32 or F16 or BF16");
+    bool isPackedType = isF16 || isBF16 || isI8;
+    int64_t vnniFactor = (isBF16 || isF16) ? 2 : isI8 ? 4 : 0;
+
+    if (!(isF32 || isPackedType))
+      return rewriter.notifyMatchFailure(
+          contractOp, "The type is not F32 or F16 or BF16 or I8");
 
     bool bf16dp = false;
     bool srf = false;
     bool fallback = false;
 
-    if (isBF16 || isF16) {
+    if (isPackedType) {
      auto cpuName = vnni::utils::getTargetArchName();
      if (cpuName == "SRF")
        srf = true;
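
The `vnniFactor` introduced in this hunk records how many reduction elements are packed into the innermost (VNNI) dimension: 2 for f16/bf16 and 4 for i8. A rough sketch of that relayout for a B operand, assuming the usual `[K][N] -> [K/vnniFactor][N][vnniFactor]` packing (names and the divisibility assumption are illustrative, not the pass's own packing code):

```cpp
#include <cstdint>
#include <vector>

// Pack a row-major KxN i8 matrix into VNNI-4 layout: K is split into groups
// of 4 that become the innermost dimension, so one VPDPBSSD consumes a full
// group per 32-bit accumulator lane. Assumes K is a multiple of 4.
std::vector<int8_t> packVnni4(const std::vector<int8_t> &b, int K, int N) {
  const int vnniFactor = 4; // would be 2 for bf16/f16
  std::vector<int8_t> packed(static_cast<size_t>(K) * N);
  for (int k = 0; k < K; ++k)
    for (int n = 0; n < N; ++n)
      packed[((k / vnniFactor) * N + n) * vnniFactor + (k % vnniFactor)] =
          b[k * N + n];
  return packed;
}
```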
@@ -311,9 +317,9 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
        fallback = true;
     }
 
-    if (isF16 && !(srf))
+    if ((isF16 || isI8) && !(srf))
       return rewriter.notifyMatchFailure(
-          contractOp, "F16 type is supported only for SRF kind of machines");
+          contractOp, "F16/I8 type is supported only for SRF kind of machines");
 
     // Check the operation type MatMul, B-MatMul, or BR-MatMul
     SmallVector<vector::IteratorType> contractIteratorTypes =
@@ -328,7 +334,7 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
       return rewriter.notifyMatchFailure(
           contractOp, "Batch matmul operation not supported yet");
 
-    if (isBF16 || isF16) {
+    if (isPackedType) {
      if (reductionCount == 2)
        return rewriter.notifyMatchFailure(
            contractOp, "Batch reduce matmul operation without vnni layout");
@@ -360,14 +366,11 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
     int64_t K = 0;
     int64_t vnni = 0;
 
-    if (isBF16 || isF16) {
+    if (isPackedType) {
       M = lhsType.getDimSize(lhsType.getRank() - 3);
       N = rhsType.getDimSize(lhsType.getRank() - 2);
       K = lhsType.getDimSize(lhsType.getRank() - 2);
       vnni = lhsType.getDimSize(lhsType.getRank() - 1);
-      if (K != (vnni / 2))
-        return rewriter.notifyMatchFailure(
-            contractOp, "K tile size should be equal to VNNI layout");
 
       // TODO: We need the N tile size to be divisible by 16 for avx2
       // fallback case. So that it ensures, LLVM find a pattern and lowers to
@@ -376,9 +379,17 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
        return rewriter.notifyMatchFailure(
            contractOp, "N tile size divisible by 16 are only supported");
 
-      if (vnni != 2)
+      if (vnni != 2 && isBF16)
+        return rewriter.notifyMatchFailure(
+            contractOp, "Only VNNI layout=2 is supported for bf16, now");
+
+      if (vnni != 4 && isI8)
         return rewriter.notifyMatchFailure(
-            contractOp, "Only VNNI layout=2 is supported, now");
+            contractOp, "Only VNNI layout=4 is supported for i8, now");
+
+      if (K != (vnni / vnniFactor))
+        return rewriter.notifyMatchFailure(
+            contractOp, "K tile size should be equal to VNNI layout");
     }
 
     if (isF32) {
@@ -412,8 +423,8 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
     // matrix then broadcast A ony-by-one + FMA.
     // If N > M: perform opposite. Broadcast A matrix then load B one-by-
     // one + FMA.
-    // Following this kind of lowering, we reduce the register loads by 
-    // stacking the less B loads or less A broadcasts and do the larger B 
+    // Following this kind of lowering, we reduce the register loads by
+    // stacking the less B loads or less A broadcasts and do the larger B
     // loads or A broadcast in a LIFO manner. Finally, it helps in reducing
     // the probablity of register spills.
     bool mDriven = true;
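
The comment above captures the register-blocking heuristic: the operand that needs fewer vector registers is hoisted and kept live, the other is streamed and combined one step at a time, and the live values are consumed in LIFO order to limit spills. A scalar sketch of the M-driven order (illustrative names; the real pass emits vector loads, broadcasts, and dot/FMA ops, and `acc` is indexed as `i + j * M` exactly like the iteration arguments below):

```cpp
#include <vector>

// Scalar model of the "M-driven" step for one k: the N/sizeFactor B tiles
// are loaded once and kept live, then every A element is broadcast and
// combined with each live B tile.
void mDrivenStep(std::vector<float> &acc, const std::vector<float> &aCol,
                 const std::vector<float> &bRow, int M, int N,
                 int sizeFactor) {
  int nTiles = N / sizeFactor;
  for (int i = 0; i < M; ++i) {      // many A broadcasts, streamed
    float aBcast = aCol[i];
    for (int j = 0; j < nTiles; ++j) // few B tiles, reused from registers
      for (int l = 0; l < sizeFactor; ++l)
        acc[(i + j * M) * sizeFactor + l] +=
            aBcast * bRow[j * sizeFactor + l];
  }
}
```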
@@ -491,7 +502,7 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
       }
     }
 
-    if (outsElementType.isF32()) {
+    if (outsElementType.isF32() || outsElementType.isSignlessInteger(32)) {
       for (int j = 0; j < N; j = j + sizeFactor) {
         for (int i = 0; i < M; i++) {
           Value indexOp_A = rewriter.create<arith::ConstantIndexOp>(
@@ -562,12 +573,22 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
           auto i1Mask_2 = rewriter.create<arith::ConstantOp>(
               kForOp.getLoc(), VectorType::get(2, rewriter.getI1Type()),
               boolAttr_2);
-          auto zeroAttr = rewriter.getFloatAttr(elementType, 0.0);
+
+          // ZeroAttr is not needed for i8 type lowering on the ARL machine;
+          // it may be needed in future for lowering on other machines.
+          FloatAttr zeroAttr;
+          if (!isI8) {
+            zeroAttr = rewriter.getFloatAttr(elementType, 0.0);
+          }
 
           // Destination type
           mlir::VectorType dstType =
               mlir::VectorType::get(sizeFactor, rewriter.getF32Type());
 
+          if (isI8)
+            dstType =
+                mlir::VectorType::get(sizeFactor, rewriter.getI32Type());
+
           llvm::SmallVector<OpFoldResult> strides = {
               rewriter.getIndexAttr(1), rewriter.getIndexAttr(1),
               rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
@@ -664,15 +685,16 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
 
           // bf16 type + avx512. uKernel lowering for machines like
           // cpx (zen5) to target avx512bf16dp.
-          if (bf16dp && isBF16) {
+          if (bf16dp || isI8) {
 
            if (mDriven) { // M -> N
              // Load elements of B matrix and store in a DS
              for (int j = 0; j < N; j = j + sizeFactor) {
                Value indexOp_j = rewriter.create<arith::ConstantIndexOp>(
                    reductionForOp.getLoc(), j);
                auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
-                   kForOp.getLoc(), VectorType::get(32, elementType),
+                   kForOp.getLoc(),
+                   VectorType::get({sizeFactor * vnni}, elementType),
                    rhsClone->getResult(0),
                    ValueRange{indexOp_c0, indexOp_c0, indexOp_j,
                               indexOp_c0});
@@ -700,15 +722,27 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
                auto valuef32 =
                    rewriterNewKForOp.create<vector::BitCastOp>(
                        kForOp.getLoc(),
-                       VectorType::get(32,
-                                       rewriterNewKForOp.getBF16Type()),
+                       VectorType::get({sizeFactor * vnni}, elementType),
                        bcst_i32);
-               for (int j = 0; j < (N / sizeFactor); j++) {
-                 auto dp = rewriter.create<mlir::x86vector::DotBF16Op>(
-                     kForOp.getLoc(), dstType,
-                     iterArgsNewKForOp[i + (j * M)], valuef32,
-                     matf32[j]);
-                 oddFMAs.push_back(dp);
+
+               if (isBF16) {
+                 for (int j = 0; j < (N / sizeFactor); j++) {
+                   auto dp = rewriter.create<mlir::x86vector::DotBF16Op>(
+                       kForOp.getLoc(), dstType,
+                       iterArgsNewKForOp[i + (j * M)], valuef32,
+                       matf32[j]);
+                   oddFMAs.push_back(dp);
+                 }
+               }
+
+               if (isI8) {
+                 for (int j = 0; j < (N / sizeFactor); j++) {
+                   auto dp = rewriter.create<mlir::x86vector::DotInt8Op>(
+                       kForOp.getLoc(), dstType,
+                       iterArgsNewKForOp[i + (j * M)], valuef32,
+                       matf32[j]);
+                   oddFMAs.push_back(dp);
+                 }
                }
              }
 
@@ -743,8 +777,7 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
                auto valuef32 =
                    rewriterNewKForOp.create<vector::BitCastOp>(
                        kForOp.getLoc(),
-                       VectorType::get(32,
-                                       rewriterNewKForOp.getBF16Type()),
+                       VectorType::get({sizeFactor * vnni}, elementType),
                        bcst_i32);
                matf32.push_back(valuef32);
              }
@@ -753,16 +786,30 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
                Value indexOp_j = rewriter.create<arith::ConstantIndexOp>(
                    reductionForOp.getLoc(), j);
                auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
-                   kForOp.getLoc(), VectorType::get(32, elementType),
+                   kForOp.getLoc(),
+                   VectorType::get({sizeFactor * vnni}, elementType),
                    rhsClone->getResult(0),
                    ValueRange{indexOp_c0, indexOp_c0, indexOp_j,
                               indexOp_c0});
-               for (int i = 0; i < M; i++) {
-                 auto dp = rewriter.create<mlir::x86vector::DotBF16Op>(
-                     kForOp.getLoc(), dstType, iterArgsNewKForOp[k],
-                     matf32[i], valueRow);
-                 k++;
-                 evenFMAs.push_back(dp);
+
+               if (isBF16) {
+                 for (int i = 0; i < M; i++) {
+                   auto dp = rewriter.create<mlir::x86vector::DotBF16Op>(
+                       kForOp.getLoc(), dstType, iterArgsNewKForOp[k],
+                       matf32[i], valueRow);
+                   k++;
+                   evenFMAs.push_back(dp);
+                 }
+               }
+
+               if (isI8) {
+                 for (int i = 0; i < M; i++) {
+                   auto dp = rewriter.create<mlir::x86vector::DotInt8Op>(
+                       kForOp.getLoc(), dstType, iterArgsNewKForOp[k],
+                       matf32[i], valueRow);
+                   k++;
+                   evenFMAs.push_back(dp);
+                 }
                }
              }
            }
@@ -905,7 +952,7 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
          // (b) bf16 fallback + avx2 instructions.
          // TODO: update lowering based on M & N. Now it is
          // default to M -> N
-         if (srf || (fallback && avx2 && !avx512)) {
+         if ((srf && !isI8) || (fallback && avx2 && !avx512)) {
            // Load odd elements of A Matrix and store in a DS
            for (int i = 0; i < M; i++) {
              Value oddA;
@@ -1228,7 +1275,7 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
 
     // get the 2nd input source for addOp via vector transfer read
     // ps: the 1st one is C matrix
-    if (addOp && maxOp && !isF32) {
+    if (addOp && maxOp && !isF32 && !isI8) {
      vector::TransferReadOp readOp_add;
      if (auto vectBcst = addOp.getLhs().getDefiningOp<vector::BroadcastOp>()) {
        if (auto vectorRead =
@@ -1268,7 +1315,7 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
      auto acc_value = newReductionForOp.getResult(k);
      k++;
 
-     if (addOp && maxOp && !isF32) {
+     if (addOp && maxOp && !isF32 && !isI8) {
        Value add_row;
 
        if (global_readOp) {
@@ -1360,7 +1407,7 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
      }
 
      // We do arith.tuncf for f32 -> bf16 in SRF/ARL/SPR kind of machines
-     if ((srf || bf16dp) && !outsElementType.isF32()) {
+     if ((srf || bf16dp) && !outsElementType.isF32() && !isI8) {
        vec_final = rewriter.create<arith::TruncFOp>(
            reductionForOp.getLoc(), VectorType::get(sizeFactor, type),
            acc_value);

test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ set(TPP_OPT_TEST_DEPENDS
   tpp-run
   tpp-sched
   fpcmp
+  check-cpuid
   )
 
 add_lit_testsuite(check-tpp "Running the regression tests"

test/I8/Integration/lit.local.cfg

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+import os
+import subprocess
+
+exec = getattr(config, "cpuid_checker", None)
+
+def is_vpdpbssd_supported():
+    if not exec or not os.path.exists(exec):
+        return False
+    try:
+        result = subprocess.run([exec], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+        return result.returncode == 1
+    except Exception as e:
+        return False
+
+def is_arch(target):
+    # Arch detection not working on Windows
+    if sys.platform in ['win32']:
+        return False
+
+    try:
+        cmd = subprocess.Popen(
+            ['uname', '-m'], stdout=subprocess.PIPE)
+    except OSError:
+        return False
+
+    out = cmd.stdout.read().decode('ascii')
+    cmd.wait()
+
+    return target in out
+
+
+# Should skip the machine that has no vpdpbssd instruction support
+if not is_vpdpbssd_supported():
+    config.unsupported = True
+
+# Enable only on x86
+# Other targets may use different VNNI blocking scheme that is not compatible with
+# prepacked shapes in some of the tests
+if not is_arch('x86'):
+    config.unsupported = True
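
For reference, the gate above expects the binary configured as `config.cpuid_checker` to exit with code 1 when `vpdpbssd` is available. A minimal sketch of such a checker, assuming AVX-VNNI-INT8 is reported in CPUID leaf 7, sub-leaf 1, EDX bit 4 (the bit position and the overall approach are assumptions about what `check-cpuid` does, not its actual source):

```cpp
#include <cpuid.h>

// Exit code 1 => vpdpbssd (AVX-VNNI-INT8) assumed supported, 0 => not,
// matching the returncode convention used by lit.local.cfg above.
int main() {
  unsigned eax = 0, ebx = 0, ecx = 0, edx = 0;
  if (!__get_cpuid_count(7, 1, &eax, &ebx, &ecx, &edx))
    return 0;            // sub-leaf unavailable: treat as unsupported
  return (edx >> 4) & 1; // assumed AVX-VNNI-INT8 feature bit
}
```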
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+// RUN: tpp-run -e gemm_i8 --entry-point-result=void -print --splat-to-random --init-type normal -seed 123 %s > %t.1
+// RUN: tpp-run -e gemm_i8 --entry-point-result=void --vector-to-kernels --registerBlocking=3,32,4 -print --splat-to-random --init-type normal -seed 123 %s > %t.2
+// RUN: fpcmp -r 0.001 %t.1 %t.2
+
+func.func @gemm_i8(%arg0: memref<2x24x8x4xi8>, %arg1: memref<2x8x128x4xi8>, %arg2: memref<24x128xi32>) -> memref<24x128xi32> {
+  linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>, affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<2x24x8x4xi8>, memref<2x8x128x4xi8>) outs(%arg2 : memref<24x128xi32>) {
+  ^bb0(%in: i8, %in_1: i8, %out: i32):
+    %0 = arith.extsi %in : i8 to i32
+    %1 = arith.extsi %in_1 : i8 to i32
+    %2 = arith.muli %0, %1 : i32
+    %3 = arith.addi %out, %2 : i32
+    linalg.yield %3 : i32
+  }
+  return %arg2 : memref<24x128xi32>
+}
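
As a cross-check, the `linalg.generic` above is a batch-reduce GEMM over VNNI-4 packed i8 operands with i32 accumulation. A plain C++ restatement of the same index arithmetic (dimension roles follow the affine maps: batch, M, K, N, and the vnni group; this is a reading aid, not something the test runs):

```cpp
#include <cstdint>

// Reference semantics of @gemm_i8: C[24][128] (i32) accumulates products of
// VNNI-4 packed A[2][24][8][4] and B[2][8][128][4] (i8) over the reduction
// dims (batch b, packed K k, vnni v), matching the indexing maps above.
void gemmI8Ref(const int8_t A[2][24][8][4], const int8_t B[2][8][128][4],
               int32_t C[24][128]) {
  for (int b = 0; b < 2; ++b)
    for (int m = 0; m < 24; ++m)
      for (int n = 0; n < 128; ++n)
        for (int k = 0; k < 8; ++k)
          for (int v = 0; v < 4; ++v)
            C[m][n] += static_cast<int32_t>(A[b][m][k][v]) *
                       static_cast<int32_t>(B[b][k][n][v]);
}
```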
