Commit 4d3c498

[AMD][gfx12] WMMA AMD16x16x32 support for i4 operands (#7012)
# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [ ] I have not added any `lit` tests.
  - [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)

# PR description
Previous WMMA support was hardcoded for 16x16x16 tiles, which matched older AMD GPU capabilities. Starting with gfx1200, AMD supports 16x32 input for matrix A and 32x16 input for matrix B (for i4 types). To support this, we introduce a mapping from the dot operation's configuration (i.e., shape and element type information) to the corresponding WMMA instruction. This abstraction lets the backend dynamically determine the key instruction parameters, kDim and kWidth, which is exactly what is needed to support varying K dimensions in WMMA instructions. A small sketch of the kDim selection rule follows.
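For orientation, here is a minimal standalone sketch of the kDim selection that the lowering performs; it mirrors the rule added in SharedToDotOperandWMMA.cpp in the diff below, and the helper name `pickKDim` plus the `main` harness are illustrative, not code from this PR.

```cpp
#include <cassert>

// Sketch only: mirrors `(getVersion() == 2 && kWidth == 16) ? 32 : 16`
// from the diff below. Version 2 corresponds to gfx12-style WMMA;
// kWidth is the number of K elements each thread holds per operand rep.
int pickKDim(int version, int kWidth) {
  return (version == 2 && kWidth == 16) ? 32 : 16;
}

int main() {
  assert(pickKDim(2, 16) == 32); // gfx12 i4 operands -> 16x16x32 intrinsic
  assert(pickKDim(1, 16) == 16); // gfx11-style WMMA keeps K = 16
  assert(pickKDim(2, 8) == 16);  // other operand widths keep K = 16
  return 0;
}
```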
1 parent 5c63c72 commit 4d3c498

File tree: 9 files changed, +387 / -46 lines


include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 2 additions & 3 deletions
@@ -1143,11 +1143,10 @@ Row |
   let hasCustomAssemblyFormat = 1;

   let extraClassDeclaration = extraDistributedDeclaration # [{
-    SmallVector<int64_t> getElemsPerInstrForOperands() const;
+    SmallVector<int64_t> getElemsPerInstrForOperands(int kDim, int opIdx) const;
     SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
-                                          Type elemType, int kWidth, int opIdx) const;
+                                          Type elemType, int kWidth, int kDim, int opIdx) const;
     SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
-    unsigned getKWidthForOperands() const;
     static SmallVector<unsigned> getMNKDimPerInstr();
   }];
 }

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 8 additions & 12 deletions
@@ -1845,15 +1845,19 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getCTASplitNum() const {
   return SmallVector<unsigned>(getCTALayout().getCTASplitNum());
 }

-SmallVector<int64_t> AMDWmmaEncodingAttr::getElemsPerInstrForOperands() const {
-  return {16, 16};
+SmallVector<int64_t>
+AMDWmmaEncodingAttr::getElemsPerInstrForOperands(int kDim, int opIdx) const {
+  if (opIdx == 0)
+    return {16, kDim};
+  else
+    return {kDim, 16};
 }

 SmallVector<int64_t>
 AMDWmmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> operandShape,
-                                      Type elemType, int kWidth,
+                                      Type elemType, int kWidth, int kDim,
                                       int opIdx) const {
-  auto operandTileShape = getElemsPerInstrForOperands();
+  auto operandTileShape = getElemsPerInstrForOperands(kDim, opIdx);
   assert(operandTileShape.size() == 2);
   auto warpsPerCTA = getWarpsPerCTA();
   auto rank = operandShape.size();
@@ -1881,14 +1885,6 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getMNKDimPerInstr() {
   return {16, 16, 16};
 }

-unsigned AMDWmmaEncodingAttr::getKWidthForOperands() const {
-  SmallVector<unsigned> sizePerThread(getRank(), 1);
-  auto numReplicated = getVersion() == 1 ? 2 : 1;
-  auto elemsPerInstr =
-      numReplicated * product(getElemsPerInstrForOperands()) / 32;
-  return elemsPerInstr;
-}
-
 //===----------------------------------------------------------------------===//
 // Mma encoding
 //===----------------------------------------------------------------------===//
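To make the new per-operand tile shapes concrete, here is a standalone rendering of `getElemsPerInstrForOperands` using plain integers; the free-function name and the `main` harness are illustrative, while the MLIR attribute method in the diff above is the real implementation.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Operand A (opIdx == 0) covers an M x K = 16 x kDim tile per instruction;
// operand B (opIdx == 1) covers a K x N = kDim x 16 tile.
std::array<int64_t, 2> elemsPerInstrForOperand(int kDim, int opIdx) {
  return opIdx == 0 ? std::array<int64_t, 2>{16, kDim}
                    : std::array<int64_t, 2>{kDim, 16};
}

int main() {
  // gfx12 i4 case: kDim = 32 gives 16x32 for A and 32x16 for B,
  // matching the tensor shapes in the new lit test below.
  assert((elemsPerInstrForOperand(32, 0) == std::array<int64_t, 2>{16, 32}));
  assert((elemsPerInstrForOperand(32, 1) == std::array<int64_t, 2>{32, 16}));
  return 0;
}
```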

test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir

Lines changed: 16 additions & 0 deletions
@@ -98,6 +98,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     tt.return
   }

+  // CHECK-LABEL: wmma2_dot_int8_32
+  tt.func @wmma2_dot_int8_32(%arg0: tensor<16x32xi4, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 16}>>, %arg1: tensor<32x16xi4, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma2>) {
+    // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4)>
+    // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi4>
+    // CHECK: llvm.bitcast %{{.*}} : vector<16xi4> to vector<2xi32>
+    // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4)>
+    // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi4>
+    // CHECK: llvm.bitcast %{{.*}} : vector<16xi4> to vector<2xi32>
+    // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK: wmma.i32.16x16x32.iu4{{.*}} : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+    %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x32xi4, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 16}>> * tensor<32x16xi4, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 16}>> -> tensor<16x16xi32, #mma2>
+    // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    tt.return
+  }
+
   // CHECK-LABEL: wmma1_dot_int4_32
   tt.func @wmma1_dot_int4_32(%arg0: tensor<16x16xi4, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xi4, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma1>) {
     // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4)>
@@ -136,6 +151,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     tt.return
   }

+
   // CHECK-LABEL: blocked_to_wmma1
   tt.func @blocked_to_wmma1(%arg0: tensor<128x16xi32, #blocked>) {
     // CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
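A quick arithmetic check of the operand packing the new test expects (plain arithmetic, not Triton code): with kWidth = 16 i4 elements per thread per rep, each operand rep is 64 bits, which is exactly the vector<16xi4> to vector<2xi32> bitcast in the CHECK lines above.

```cpp
#include <cassert>

int main() {
  const unsigned kWidth = 16;  // i4 elements held per thread per rep
  const unsigned elemBits = 4; // width of an i4 element
  const unsigned packedBits = kWidth * elemBits; // 64 bits per rep
  const unsigned i32Words = packedBits / 32;     // packs into vector<2xi32>
  assert(packedBits == 64 && i32Words == 2);
  return 0;
}
```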
third_party/amd/include/TritonAMDGPUTransforms/WmmaGroup.h (new file)

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_WMMAGROUP_H_
+#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_WMMAGROUP_H_
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+
+struct WmmaIntrinsic {
+  // Chooses a suitable wmma intrinsic for the given input case.
+  static FailureOr<WmmaIntrinsic> selectFor(int version, unsigned mDim,
+                                            unsigned nDim, unsigned inputKDim,
+                                            Type aElemType, Type bElemType,
+                                            Type dElemType);
+
+  WmmaIntrinsic(StringRef symbol, unsigned m, unsigned n, unsigned k,
+                unsigned kB, Type aET, Type bET, Type dET)
+      : name(symbol), mDim(m), nDim(n), kDim(k), kBase(kB), aElementType(aET),
+        bElementType(bET), dElementType(dET) {}
+  WmmaIntrinsic(const WmmaIntrinsic &other) = default;
+  WmmaIntrinsic(WmmaIntrinsic &&other) = default;
+  WmmaIntrinsic() = default;
+  WmmaIntrinsic &operator=(WmmaIntrinsic &&other) = default;
+
+  llvm::StringRef name;
+
+  // m, n, and k refer to the shapes of the two operands of a wmma intrinsic:
+  // Operand A has shape [m]x[k]; operand B has shape [k]x[n].
+
+  unsigned mDim;
+  unsigned nDim;
+  unsigned kDim;
+
+  // kBase is the number of elements each thread holds.
+  unsigned kBase;
+
+  Type aElementType;
+  Type bElementType;
+  Type dElementType;
+};
+} // namespace mlir
+
+#endif // TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_WMMAGROUP_H_
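A hedged usage sketch of the new lookup for the gfx12 i4 case follows. Only `WmmaIntrinsic::selectFor` and its fields come from this PR; the surrounding MLIR setup is ordinary boilerplate, and the expected kDim of 32 is inferred from the 16x16x32 lit test above.

```cpp
#include "TritonAMDGPUTransforms/WmmaGroup.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

void queryGfx12Int4Intrinsic() {
  MLIRContext ctx;
  Builder builder(&ctx);
  Type i4Ty = builder.getIntegerType(4);
  Type i32Ty = builder.getIntegerType(32);

  // version 2 = gfx12 WMMA; operands are a 16x32 (A) and a 32x16 (B) i4 tile.
  FailureOr<WmmaIntrinsic> intrinsic = WmmaIntrinsic::selectFor(
      /*version=*/2, /*mDim=*/16, /*nDim=*/16, /*inputKDim=*/32,
      /*aElemType=*/i4Ty, /*bElemType=*/i4Ty, /*dElemType=*/i32Ty);
  if (failed(intrinsic))
    return; // unsupported shape / element-type combination

  unsigned kDim = intrinsic->kDim; // expected to be 32 for this case
  (void)kDim;
}
```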

third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp

Lines changed: 6 additions & 3 deletions
@@ -164,12 +164,15 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,

   auto elemTy = aTensorTy.getElementType();
   int kWidth = encoding.getKWidth();
-  auto elemsPerInstr = wmmaLayout.getElemsPerInstrForOperands();
+
+  int kDim = (wmmaLayout.getVersion() == 2 && kWidth == 16) ? 32 : 16;
+  auto elemsPerInstr = wmmaLayout.getElemsPerInstrForOperands(kDim, opIdx);
   auto wmmaInstrK = elemsPerInstr[opIdx == 0 ? 1 : 0];
   auto wmmaInstrNonK = elemsPerInstr[opIdx == 0 ? 0 : 1];
   assert(wmmaInstrNonK == 16);

-  auto numReps = wmmaLayout.getRepForOperand(shape, elemTy, kWidth, opIdx);
+  auto numReps =
+      wmmaLayout.getRepForOperand(shape, elemTy, kWidth, kDim, opIdx);
   auto numRepNonK = numReps[opIdx == 0 ? 1 : 2];
   auto numRepK = numReps[opIdx == 0 ? 2 : 1];
   auto repB = numReps[0];
@@ -179,7 +182,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
   Value waveSize = tb.i32_val(iWaveSize);
   Value linearWaveId = tb.udiv(thread, waveSize);

-  unsigned numElemsPerThreadPerRep = wmmaLayout.getKWidthForOperands();
+  unsigned numElemsPerThreadPerRep = kWidth;

   Value lane = tb.urem(thread, waveSize);
   unsigned int maxNumWarps = shape[nonKDimIdx] / wmmaInstrNonK;

third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp

Lines changed: 51 additions & 13 deletions
@@ -23,6 +23,7 @@

 #include "../PatternTritonGPUOpToLLVM.h"
 #include "../TritonAMDGPUToLLVM/SchedInstructions.h"
+#include "TritonAMDGPUTransforms/WmmaGroup.h"
 #include "Utility.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
@@ -211,13 +212,27 @@ StringRef getWmmaIntrinsicName(Type aElTy, Type bElTy, Type dElTy, Type valATy,
   return intrinsics[h];
 }

+std::string addInstructionSuffix(std::string intrinsicName, unsigned kWidth,
+                                 Type aElTy, Type bElTy, Type dElTy,
+                                 bool tied) {
+  if (tied) {
+    intrinsicName += ".tied";
+  } else {
+    if (isa<FloatType>(aElTy) && aElTy.getIntOrFloatBitWidth() == 8)
+      intrinsicName += "." + getTypeStr(bElTy);
+    intrinsicName += ".v" + std::to_string(kWidth) + getTypeStr(dElTy);
+    intrinsicName += ".v" + std::to_string(kWidth) + getTypeStr(aElTy);
+  }
+
+  return intrinsicName;
+}
+
 Value generateWMMAIntrinsic(ConversionPatternRewriter &rewriter, Location loc,
                             Value valA, Value valB, Value valC, Type aElType,
-                            Type bElType, Type dElType,
+                            Type bElType, Type dElType, StringRef name,
                             std::optional<bool> tiedLower) {
   auto b = TritonLLVMOpBuilder(loc, rewriter);
-  auto name = getWmmaIntrinsicName(aElType, bElType, dElType, valA.getType(),
-                                   valC.getType(), tiedLower.has_value());
+
   LLVM::FastmathFlagsAttr defaultFlags{};
   SmallVector<Value> operands;
   if (aElType.isInteger())
@@ -240,12 +255,12 @@ Value generateWMMAIntrinsic(ConversionPatternRewriter &rewriter, Location loc,

 Value generateWMMAOp(ConversionPatternRewriter &rewriter, Location loc,
                      Value valA, Value valB, Value valC, Type aElType,
-                     Type bElType, Type dElType,
+                     Type bElType, Type dElType, StringRef intrinsicName,
                      std::optional<bool> tiedLower) {
   // Independent of wmma version because builtin functions are backward
   // compatible
   return generateWMMAIntrinsic(rewriter, loc, valA, valB, valC, aElType,
-                               bElType, dElType, tiedLower);
+                               bElType, dElType, intrinsicName, tiedLower);
 }

 // Conduct the Dot conversion.
@@ -266,16 +281,33 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor,
   auto aTensorTy = cast<RankedTensorType>(a.getType());
   auto bTensorTy = cast<RankedTensorType>(b.getType());
   auto dTensorTy = cast<RankedTensorType>(d.getType());
-  auto elemTy = aTensorTy.getElementType();
+  auto aElemTy = aTensorTy.getElementType();
+  auto bElemTy = bTensorTy.getElementType();
+  auto dElemTy = dTensorTy.getElementType();
+
+  const auto kDimOperandSize = aTensorTy.getShape().back();
+
+  std::string intrinsicName;
+  FailureOr<WmmaIntrinsic> maybeWmmaIntrinsic =
+      WmmaIntrinsic::selectFor(wmmaVer, mnkDim[0], mnkDim[1], kDimOperandSize,
+                               aElemTy, bElemTy, dElemTy);
+  if (failed(maybeWmmaIntrinsic)) {
+    return op.emitError(
+        "no matching matrix core intrinsic due to unsupported element type");
+  }
+
+  unsigned kDim = maybeWmmaIntrinsic->kDim;

   auto aEncoding = cast<DotOperandEncodingAttr>(aTensorTy.getEncoding());
   auto bEncoding = cast<DotOperandEncodingAttr>(bTensorTy.getEncoding());
   int kWidth = aEncoding.getKWidth();
+  intrinsicName = maybeWmmaIntrinsic->name;

-  auto repA =
-      wmmaLayout.getRepForOperand(aTensorTy.getShape(), elemTy, kWidth, 0);
-  auto repB =
-      wmmaLayout.getRepForOperand(bTensorTy.getShape(), elemTy, kWidth, 1);
+  auto repA = wmmaLayout.getRepForOperand(aTensorTy.getShape(), aElemTy, kWidth,
+                                          kDim, 0);
+  auto repB = wmmaLayout.getRepForOperand(bTensorTy.getShape(), bElemTy, kWidth,
+                                          kDim, 1);

   assert(repA[2] == repB[1]);

@@ -307,6 +339,9 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor,
   auto vecTy = vec_ty(dstElemTy, elemsPerVec);
   bool tied = numRepM % 2 == 0 && paddedOutputElemSize == 2;
   int tiedGroup = tied ? 2 : 1;
+
+  intrinsicName = addInstructionSuffix(intrinsicName, kWidth, aElemTy, bElemTy,
+                                       dElemTy, tied);
   for (int b = 0; b < numRepB; ++b) {
     for (int m = 0; m < numRepM / tiedGroup; ++m) {
       for (int n = 0; n < numRepN; ++n) {
@@ -334,11 +369,12 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor,
                           ha[{b, m * tiedGroup + subTied, k}], acc,
                           bTensorTy.getElementType(),
                           aTensorTy.getElementType(), dstElemTy,
-                          optTied)
+                          intrinsicName, optTied)
                     : generateWMMAOp(
                           rewriter, loc, ha[{b, m * tiedGroup + subTied, k}],
                           hb[{b, n, k}], acc, aTensorTy.getElementType(),
-                          bTensorTy.getElementType(), dstElemTy, optTied);
+                          bTensorTy.getElementType(), dstElemTy,
+                          intrinsicName, optTied);
       }
     }
     for (unsigned v = 0; v < dElemsToStorePerThread; ++v) {
@@ -360,7 +396,9 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor,
   Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy);

   const size_t mmaCount = numRepB * numRepM * numRepN * numRepK;
-  setNumGeneratedMMAs(op, mmaCount, mnkDim[0], mnkDim[1], mnkDim[2], elemTy);
+  setNumGeneratedMMAs(op, mmaCount, maybeWmmaIntrinsic->mDim,
+                      maybeWmmaIntrinsic->nDim, maybeWmmaIntrinsic->kDim,
+                      aElemTy);

   rewriter.replaceOp(op, res);
   return success();
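For clarity on how the final intrinsic symbol is assembled, here is a simplified standalone rendering of `addInstructionSuffix` with MLIR types replaced by plain strings. The type spellings and the sample base name are illustrative assumptions; the real spellings come from `getTypeStr`, which this diff does not show.

```cpp
#include <cassert>
#include <string>

// Mirrors the control flow of addInstructionSuffix above: tied accumulation
// only appends ".tied"; otherwise the suffix encodes the D and A vector types
// (plus the B element type for 8-bit float A operands).
std::string addSuffix(std::string name, unsigned kWidth, bool aIsFp8,
                      const std::string &aTy, const std::string &bTy,
                      const std::string &dTy, bool tied) {
  if (tied) {
    name += ".tied";
  } else {
    if (aIsFp8)
      name += "." + bTy;
    name += ".v" + std::to_string(kWidth) + dTy;
    name += ".v" + std::to_string(kWidth) + aTy;
  }
  return name;
}

int main() {
  // Tied case: only ".tied" is appended to the base name.
  assert(addSuffix("wmma.f16.16x16x16.f16", 16, false, "f16", "f16", "f16",
                   /*tied=*/true) == "wmma.f16.16x16x16.f16.tied");
  return 0;
}
```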
