Commit b962e44

[AMD] Support tied wmma instructions (#4483)

- Generate intrinsics for WMMA calculations
- Generate tied instructions along the M axis when possible
- Support the transposed case
- Add lit tests

Signed-off-by: Ilya Veselov <[email protected]>

1 parent 484b9c6 commit b962e44
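For orientation, a small standalone sketch (not part of the commit) of how the tied variant changes the generated intrinsic name. The helper wmmaIntrinsicName below is hypothetical and only mirrors the naming logic added to getWmmaIntrinsicName in WMMA.cpp; the real non-tied names also carry value-type suffixes, which are omitted here.

// Hypothetical, simplified mirror of getWmmaIntrinsicName from WMMA.cpp:
// the tied form drops the operand value-type suffixes and appends ".tied".
#include <iostream>
#include <string>

std::string wmmaIntrinsicName(const std::string &dTy, const std::string &aTy,
                              bool tied) {
  std::string name = "llvm.amdgcn.wmma." + dTy + ".16x16x16." + aTy;
  if (tied)
    name += ".tied"; // e.g. llvm.amdgcn.wmma.f16.16x16x16.f16.tied
  return name;       // real non-tied names append value-type suffixes as well
}

int main() {
  std::cout << wmmaIntrinsicName("f16", "f16", /*tied=*/false) << "\n";
  std::cout << wmmaIntrinsicName("bf16", "bf16", /*tied=*/true) << "\n";
}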

File tree

- test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir
- third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp

2 files changed (+92, -46 lines)


test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir

Lines changed: 35 additions & 7 deletions
@@ -27,13 +27,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     tt.return
   }
 
-  // CHECK-LABEL: wmma1_dot
-  tt.func @wmma1_dot(%arg0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xf16, #mma1>) {
+  // CHECK-LABEL: wmma1_dot_f16
+  tt.func @wmma1_dot_f16(%arg0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xf16, #mma1>) {
     // CHECK-COUNT-32: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>
     // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)>
     // CHECK: llvm.mlir.undef : vector<16xf16>
     // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xf16>
-    // CHECK: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+    // CHECK: wmma.f16.16x16x16.f16{{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
     %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xf16, #mma1>
     // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<16xf16>
     // CHECK: llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)>
@@ -50,11 +50,39 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)>
     // CHECK: llvm.mlir.undef : vector<16xbf16>
     // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xbf16>
-    // CHECK: rocdl.wmma.bf16.16x16x16.bf16 {{.*}} : (vector<16xi16>, vector<16xi16>, vector<16xbf16>, i1) -> vector<16xbf16>
+    // CHECK: wmma.bf16.16x16x16.bf16{{.*}} : (vector<16xi16>, vector<16xi16>, vector<16xbf16>, i1) -> vector<16xbf16>
     %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xbf16, #mma1>
     tt.return
   }
 
+  // CHECK-LABEL: wmma1_dot_f16_tied
+  tt.func @wmma1_dot_f16_tied(%arg0: tensor<64x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<64x16xf16, #mma1>) {
+    // CHECK-COUNT-32: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>
+    // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>
+    // CHECK: llvm.mlir.undef : vector<16xf16>
+    // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xf16>
+    // CHECK-COUNT-2: wmma.f16.16x16x16.f16.tied{{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+    %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<64x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<64x16xf16, #mma1>
+    // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<16xf16>
+    // CHECK: llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>
+    // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>
+    tt.return
+  }
+
+  // CHECK-LABEL: wmma1_dot_bf16_tied
+  tt.func @wmma1_dot_bf16_tied(%arg0: tensor<64x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<64x16xbf16, #mma1>) {
+    // CHECK-COUNT-32: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)>
+    // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)>
+    // CHECK: llvm.mlir.undef : vector<16xbf16>
+    // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xbf16>
+    // CHECK-COUNT-2: wmma.bf16.16x16x16.bf16.tied{{.*}} : (vector<16xi16>, vector<16xi16>, vector<16xbf16>, i1) -> vector<16xbf16>
+    %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<64x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<64x16xbf16, #mma1>
+    // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<16xbf16>
+    // CHECK: llvm.mlir.undef : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)>
+    // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)>
+    tt.return
+  }
+
   // CHECK-LABEL: wmma1_dot_int8_32
   tt.func @wmma1_dot_int8_32(%arg0: tensor<16x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma1>) {
     // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)>
@@ -64,7 +92,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi8>
     // CHECK: llvm.bitcast %{{.*}} : vector<16xi8> to vector<4xi32>
     // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
-    // CHECK: rocdl.wmma.i32.16x16x16.iu8 {{.*}} : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
+    // CHECK: wmma.i32.16x16x16.iu8{{.*}} : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
     %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1>
     // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
     tt.return
@@ -79,7 +107,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi4>
     // CHECK: llvm.bitcast %{{.*}} : vector<16xi4> to vector<2xi32>
     // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
-    // CHECK: rocdl.wmma.i32.16x16x16.iu4 {{.*}} : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+    // CHECK: wmma.i32.16x16x16.iu4{{.*}} : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
     %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi4, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi4, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1>
     // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
     tt.return
@@ -196,7 +224,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr
     // CHECK-COUNT-32: llvm.insertelement
     // CHECK-COUNT-8: llvm.extractvalue %arg2
     // CHECK-COUNT-8: llvm.insertelement
-    // CHECK-COUNT-2: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+    // CHECK-COUNT-2: wmma.f16.16x16x16.f16{{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
     %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<2x16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<2x32x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<2x16x16xf16, #mma1>
     // CHECK-COUNT-8: llvm.extractelement
     // CHECK-COUNT-8: llvm.insertvalue
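The new tied tests use a 64x16 A operand and a 64x16 accumulator so that each warp covers more than one repetition along M. As a rough standalone sketch (mine, not from this commit), the gating condition added in WMMA.cpp below reduces to the following; usesTiedWmma is a hypothetical helper restating "bool tied = numRepM % 2 == 0 && paddedOutputElemSize == 2" from convertDot.

// Hypothetical helper: tied WMMA is only emitted when M-repetitions can be
// paired and the f16/bf16 result occupies half of a padded 32-bit lane.
#include <iostream>

bool usesTiedWmma(int numRepM, int paddedOutputElemSize) {
  return numRepM % 2 == 0 && paddedOutputElemSize == 2;
}

int main() {
  std::cout << usesTiedWmma(2, 2) << "\n"; // paired M-repetitions, padded f16: tied
  std::cout << usesTiedWmma(1, 2) << "\n"; // single M-repetition: regular wmma
}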

third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp

Lines changed: 57 additions & 39 deletions
@@ -183,33 +183,37 @@ std::string getTypeStr(Type ty) {
 }
 
 StringRef getWmmaIntrinsicName(Type aElTy, Type bElTy, Type dElTy, Type valATy,
-                               Type valCTy) {
+                               Type valCTy, bool tied) {
   static llvm::SmallDenseMap<llvm::hash_code, std::string> intrinsics;
   using MapInfo = llvm::DenseMapInfo<Type>;
   llvm::hash_code h = llvm::hash_combine(
       MapInfo::getHashValue(aElTy), MapInfo::getHashValue(bElTy),
       MapInfo::getHashValue(dElTy), MapInfo::getHashValue(valATy),
-      MapInfo::getHashValue(valCTy));
+      MapInfo::getHashValue(valCTy), llvm::hash_value(tied));
   if (!intrinsics.contains(h)) {
     std::string name = "llvm.amdgcn.wmma.";
     name += getTypeStr(dElTy);
     name += ".16x16x16."; // TODO support 16x16x32 for i4 operands
     name += getTypeStr(aElTy);
-    if (isa<FloatType>(aElTy) && aElTy.getIntOrFloatBitWidth() == 8)
-      name += '.' + getTypeStr(bElTy);
-    name += '.' + getTypeStr(valCTy) + "." + getTypeStr(valATy);
+    if (tied) {
+      name += ".tied";
+    } else {
+      if (isa<FloatType>(aElTy) && aElTy.getIntOrFloatBitWidth() == 8)
+        name += '.' + getTypeStr(bElTy);
+      name += '.' + getTypeStr(valCTy) + "." + getTypeStr(valATy);
+    }
     intrinsics[h] = name;
   }
   return intrinsics[h];
 }
 
 Value generateWMMAIntrinsic(ConversionPatternRewriter &rewriter, Location loc,
-                            WMMAInstrType wmmaType, Value valA, Value valB,
-                            Value valC, Type aElType, Type bElType,
-                            Type dElType) {
+                            Value valA, Value valB, Value valC, Type aElType,
+                            Type bElType, Type dElType,
+                            std::optional<bool> tiedLower) {
   auto b = TritonLLVMOpBuilder(loc, rewriter);
   auto name = getWmmaIntrinsicName(aElType, bElType, dElType, valA.getType(),
-                                   valC.getType());
+                                   valC.getType(), tiedLower.has_value());
   LLVM::FastmathFlagsAttr defaultFlags{};
   SmallVector<Value> operands;
   if (aElType.isInteger())
@@ -221,25 +225,23 @@ Value generateWMMAIntrinsic(ConversionPatternRewriter &rewriter, Location loc,
   operands.push_back(valC);
   // Flag for using low bits in registers. Result could be already packed to
   // int32. Set low bits by default for now.
-  if (32 / dElType.getIntOrFloatBitWidth() > 1 || dElType.isInteger(32)) {
-    operands.push_back(b.int_val(1, false));
+  if (tiedLower.has_value() || 32 / dElType.getIntOrFloatBitWidth() > 1 ||
+      dElType.isInteger(32)) {
+    operands.push_back(b.int_val(1, tiedLower.value_or(false)));
   }
   auto wmmaIntrinsic = LLVM::createLLVMIntrinsicCallOp(
       rewriter, loc, name, valC.getType(), operands);
   return wmmaIntrinsic.getResult(0);
 }
 
 Value generateWMMAOp(ConversionPatternRewriter &rewriter, Location loc,
-                     WMMAInstrType wmmaType, Value valA, Value valB, Value valC,
-                     Type aElType, Type bElType, Type dElType, int version) {
-  if (version == 1) {
-    return generateROCDLOp(rewriter, loc, wmmaType, valA, valB, valC, aElType,
-                           bElType);
-  } else {
-    assert(version == 2);
-    return generateWMMAIntrinsic(rewriter, loc, wmmaType, valA, valB, valC,
-                                 aElType, bElType, dElType);
-  }
+                     Value valA, Value valB, Value valC, Type aElType,
+                     Type bElType, Type dElType,
+                     std::optional<bool> tiedLower) {
+  // Independent of wmma version because builtin functions are backward
+  // compatible
+  return generateWMMAIntrinsic(rewriter, loc, valA, valB, valC, aElType,
+                               bElType, dElType, tiedLower);
 }
 
 // Conduct the Dot conversion.
@@ -251,7 +253,6 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor,
   int wmmaVer = wmmaLayout.getVersion();
   auto warpsPerCTA = wmmaLayout.getWarpsPerCTA();
   auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerInstr();
-  auto wmmaInstrType = getWMMAInstrTypeFromDot(op);
 
   auto loc = op.getLoc();
   auto tb = TritonLLVMOpBuilder(loc, rewriter);
@@ -300,33 +301,50 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor,
   auto elemsPerVec = mnkDim[0] * mnkDim[1] * paddedOutputElemSize / warpSize;
   auto dElemsToStorePerThread = mnkDim[0] * mnkDim[1] / warpSize;
   auto vecTy = vec_ty(dstElemTy, elemsPerVec);
+  bool tied = numRepM % 2 == 0 && paddedOutputElemSize == 2;
+  int tiedGroup = tied ? 2 : 1;
   for (int b = 0; b < numRepB; ++b) {
-    for (int m = 0; m < numRepM; ++m) {
+    for (int m = 0; m < numRepM / tiedGroup; ++m) {
       for (int n = 0; n < numRepN; ++n) {
         auto batchOffIdx = b * numRepM * numRepN * dElemsToStorePerThread;
-        auto mRepOffId = m * numRepN * dElemsToStorePerThread;
         auto nRepOffId = n * dElemsToStorePerThread;
-        auto fcThreadOffIdx = batchOffIdx + mRepOffId + nRepOffId;
+        auto nBatchOffSum = nRepOffId + batchOffIdx;
 
         Value acc = tb.undef(vecTy);
         for (unsigned v = 0; v < dElemsToStorePerThread; ++v) {
-          acc = tb.insert_element(vecTy, acc, fc[fcThreadOffIdx + v],
-                                  tb.i32_val(v * paddedOutputElemSize));
+          for (int subTied = 0; subTied < tiedGroup; ++subTied) {
+            auto mRepOffId =
+                (m * tiedGroup + subTied) * numRepN * dElemsToStorePerThread;
+            auto fcThreadOffIdx = nBatchOffSum + mRepOffId;
+            acc = tb.insert_element(
+                vecTy, acc, fc[fcThreadOffIdx + v],
+                tb.i32_val(v * paddedOutputElemSize + subTied));
+          }
        }
-        for (size_t k = 0; k < numRepK; k++) {
-          acc = wmmaLayout.getIsTransposed()
-                    ? generateWMMAOp(
-                          rewriter, loc, wmmaInstrType, hb[{b, n, k}],
-                          ha[{b, m, k}], acc, bTensorTy.getElementType(),
-                          aTensorTy.getElementType(), dstElemTy, wmmaVer)
-                    : generateWMMAOp(
-                          rewriter, loc, wmmaInstrType, ha[{b, m, k}],
-                          hb[{b, n, k}], acc, aTensorTy.getElementType(),
-                          bTensorTy.getElementType(), dstElemTy, wmmaVer);
+        for (size_t k = 0; k < numRepK; ++k) {
+          for (int subTied = 0; subTied < tiedGroup; ++subTied) {
+            auto optTied =
+                tied ? std::optional<bool>(subTied != 0) : std::nullopt;
+            acc = wmmaLayout.getIsTransposed()
+                      ? generateWMMAOp(rewriter, loc, hb[{b, n, k}],
+                                       ha[{b, m * tiedGroup + subTied, k}], acc,
+                                       bTensorTy.getElementType(),
+                                       aTensorTy.getElementType(), dstElemTy,
+                                       optTied)
+                      : generateWMMAOp(
+                            rewriter, loc, ha[{b, m * tiedGroup + subTied, k}],
+                            hb[{b, n, k}], acc, aTensorTy.getElementType(),
+                            bTensorTy.getElementType(), dstElemTy, optTied);
+          }
         }
         for (unsigned v = 0; v < dElemsToStorePerThread; ++v) {
-          fc[fcThreadOffIdx + v] = tb.extract_element(
-              dstElemTy, acc, tb.i32_val(v * paddedOutputElemSize));
+          for (int subTied = 0; subTied < tiedGroup; ++subTied) {
+            auto mRepOffId =
+                (m * tiedGroup + subTied) * numRepN * dElemsToStorePerThread;
+            auto fcThreadOffIdx = nBatchOffSum + mRepOffId;
+            fc[fcThreadOffIdx + v] = tb.extract_element(
+                dstElemTy, acc, tb.i32_val(v * paddedOutputElemSize + subTied));
+          }
         }
       }
     }
  }
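To illustrate the accumulator packing the tied path relies on, here is a small standalone sketch (mine, assuming paddedOutputElemSize == 2 for f16/bf16 results) of the lane mapping used by the insert_element/extract_element calls above: the two M-repetitions of a tied pair share one 16-element accumulator vector, and subTied selects the low or high half of each packed 32-bit register.

// Standalone illustration of the index math in convertDot's tied path:
// result v of M-repetition (m * tiedGroup + subTied) lands in accumulator
// lane v * paddedOutputElemSize + subTied.
#include <cstdio>

int main() {
  const int paddedOutputElemSize = 2; // f16/bf16 results padded to 32-bit lanes
  const int dElemsToStorePerThread = 8;
  const int tiedGroup = 2;
  for (int subTied = 0; subTied < tiedGroup; ++subTied)
    for (int v = 0; v < dElemsToStorePerThread; ++v)
      std::printf("subTied=%d v=%d -> acc lane %d\n", subTied, v,
                  v * paddedOutputElemSize + subTied);
  return 0;
}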

0 commit comments
