Skip to content

Commit 3c13f09

Browse files
authored
[AMD] NFC: Refactor AccelerateAMDMatmul patterns (#4985)
This commit refactors the AccelerateAMDMatmul patterns in prep for mxfp support.
1 parent 4a54311 commit 3c13f09

File tree

1 file changed

+109
-102
lines changed

1 file changed

+109
-102
lines changed

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 109 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include "mlir/Support/LogicalResult.h"
66
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
77
#include "triton/Analysis/Utility.h"
8+
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
9+
#include "triton/Dialect/Triton/IR/Dialect.h"
810
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
911
#include <memory>
1012

@@ -36,16 +38,15 @@ int getWmmaVersion(StringRef archGen) {
3638
return 0;
3739
}
3840

39-
SmallVector<unsigned, 2> warpsPerTile(tt::DotOp dotOp,
40-
const ArrayRef<int64_t> shape,
41-
int numWarps,
42-
SmallVector<int64_t, 2> shapePerWarp) {
41+
SmallVector<unsigned, 3>
42+
warpsPerTile(Operation *dotOp, ArrayRef<int64_t> shape, int numWarps,
43+
std::pair<int64_t, int64_t> shapePerWarp) {
4344
auto rank = shape.size();
4445
// Early exit for batched matmul
4546
if (rank == 3)
4647
return {(unsigned)numWarps, 1, 1};
4748

48-
auto filter = [&dotOp](Operation *op) {
49+
auto filter = [dotOp](Operation *op) {
4950
return op->getParentRegion() == dotOp->getParentRegion();
5051
};
5152
ForwardSliceOptions fwdOpt;
@@ -55,17 +56,17 @@ SmallVector<unsigned, 2> warpsPerTile(tt::DotOp dotOp,
5556
bwdOpt.filter = filter;
5657
auto slices = getSlice(dotOp, bwdOpt, fwdOpt);
5758
for (Operation *op : slices)
58-
if (isa<tt::DotOp>(op) && (op != dotOp))
59+
if (op->hasTrait<OpTrait::DotLike>() && (op != dotOp))
5960
return {(unsigned)numWarps, 1};
6061

6162
SmallVector<int64_t, 2> tensorShape = {shape[0], shape[1]};
6263
SmallVector<unsigned, 2> ret = {1, 1};
6364
do {
6465
if (ret[0] * ret[1] >= numWarps)
6566
break;
66-
if (tensorShape[0] / (shapePerWarp[0] * 2) / ret[0] >=
67-
tensorShape[1] / shapePerWarp[1] / ret[1]) {
68-
if (ret[0] < tensorShape[0] / shapePerWarp[0]) {
67+
if (tensorShape[0] / (shapePerWarp.first * 2) / ret[0] >=
68+
tensorShape[1] / shapePerWarp.second / ret[1]) {
69+
if (ret[0] < tensorShape[0] / shapePerWarp.first) {
6970
ret[0] *= 2;
7071
} else
7172
ret[1] *= 2;
@@ -74,24 +75,89 @@ SmallVector<unsigned, 2> warpsPerTile(tt::DotOp dotOp,
7475
}
7576
} while (true);
7677

77-
if (ret[1] * shapePerWarp[1] > tensorShape[1]) {
78+
if (ret[1] * shapePerWarp.second > tensorShape[1]) {
7879
return {ret[1], ret[0]};
7980
}
8081

8182
return ret;
8283
}
8384

84-
SmallVector<unsigned, 2>
85-
warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
86-
SmallVector<int64_t, 2> shapePerWarp) {
85+
SmallVector<unsigned, 3>
86+
warpsPerTileMFMA(Operation *dotOp, ArrayRef<int64_t> shape, int numWarps,
87+
std::pair<int64_t, int64_t> shapePerWarp) {
8788
return warpsPerTile(dotOp, shape, numWarps, shapePerWarp);
8889
}
8990

90-
SmallVector<unsigned, 2>
91-
warpsPerTileWMMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
92-
return warpsPerTile(dotOp, shape, numWarps,
93-
{ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr()[0],
94-
ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr()[1]});
91+
SmallVector<unsigned, 3>
92+
warpsPerTileWMMA(Operation *dotOp, ArrayRef<int64_t> shape, int numWarps) {
93+
auto mnk = ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr();
94+
return warpsPerTile(dotOp, shape, numWarps, {mnk[0], mnk[1]});
95+
}
96+
97+
// Chooses a proper MFMA instruction that can be used to compute the given dot op.
98+
// If enforcedNonKDim is not zero, it will be used to overwrite the default
99+
// logic to choose an MFMA with matching M/N dim.
100+
FailureOr<MfmaInsn> chooseMfmaInstruction(RankedTensorType cType,
101+
Type aElemType, Type bElemType,
102+
int inputKSize, int mfmaVersion,
103+
int enforcedNonKDim) {
104+
// number of matrix elements along k dim per one MFMA instruction
105+
unsigned kDim = 0;
106+
107+
auto resShape = cType.getShape();
108+
auto rank = resShape.size();
109+
auto M = resShape[rank - 2];
110+
auto N = resShape[rank - 1];
111+
112+
unsigned mDim = 0;
113+
unsigned nDim = 0;
114+
if (enforcedNonKDim != 0) {
115+
mDim = nDim = enforcedNonKDim;
116+
} else {
117+
int minSize = std::min(M, N);
118+
if (minSize >= 32) {
119+
mDim = 32;
120+
nDim = 32;
121+
}
122+
if (minSize >= 16 && minSize < 32) {
123+
mDim = 16;
124+
nDim = 16;
125+
}
126+
if (minSize < 16) {
127+
if (M < 16 && N >= 64) {
128+
mDim = 4;
129+
nDim = 64;
130+
} else if (M >= 64 && N < 16) {
131+
mDim = 64;
132+
nDim = 4;
133+
} else {
134+
assert(inputKSize >= 64 &&
135+
"k should be at least 64 to use this layout");
136+
mDim = 4;
137+
nDim = 4;
138+
}
139+
}
140+
}
141+
assert(mDim != 0 && nDim != 0);
142+
143+
auto maybeMfmaInsn =
144+
MfmaInsn::selectMfma(mDim, nDim, aElemType, bElemType, mfmaVersion);
145+
if (failed(maybeMfmaInsn))
146+
llvm::report_fatal_error("No match found in MFMA database\n");
147+
148+
kDim = maybeMfmaInsn->getKDim();
149+
assert(kDim != 0);
150+
assert(M % mDim == 0 && N % nDim == 0);
151+
assert(inputKSize % kDim == 0);
152+
return maybeMfmaInsn;
153+
}
154+
155+
FailureOr<MfmaInsn> chooseMfmaInstruction(tt::DotOp dot, int mfmaVersion,
156+
int nonKDim) {
157+
RankedTensorType aType = dot.getA().getType();
158+
return chooseMfmaInstruction(dot.getC().getType(), aType.getElementType(),
159+
dot.getB().getType().getElementType(),
160+
aType.getShape().back(), mfmaVersion, nonKDim);
95161
}
96162

97163
using OperandTypesVector = SmallVector<Type, 4>;
@@ -259,15 +325,16 @@ Value convertAndCastTensor(PatternRewriter &rewriter, Value value,
259325
return castedTensor;
260326
}
261327

262-
class BlockedToMFMA : public RewritePattern {
328+
class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
263329
int mfmaVersion;
264-
int enforcedNonKDim;
330+
int nonKDim;
265331
int kPack;
266332

267333
public:
268-
BlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim, int kPack)
269-
: RewritePattern(tt::DotOp::getOperationName(), 2, context),
270-
mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim), kPack(kPack) {}
334+
BlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim, int kPack,
335+
PatternBenefit benefit = 1)
336+
: OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion),
337+
nonKDim(nonKDim), kPack(kPack) {}
271338

272339
bool isSecondDot(tt::DotOp &dotOp) const {
273340
auto filter = [&dotOp](Operation *op) {
@@ -285,75 +352,15 @@ class BlockedToMFMA : public RewritePattern {
285352
return false;
286353
}
287354

288-
/// @brief Choose MFMA instruction parameters
289-
/// @param dot target dot operation
290-
/// @return MfmaInsn or failure
291-
FailureOr<MfmaInsn> chooseMfmaInstruction(tt::DotOp dot) const {
292-
// number of matrix elements along k dim per one MFMA instruction
293-
unsigned kDim = 0;
294-
auto opType = cast<RankedTensorType>(dot.getA().getType());
295-
auto dataTypeA = opType.getElementType();
296-
auto dataTypeB =
297-
cast<RankedTensorType>(dot.getB().getType()).getElementType();
298-
299-
auto resType = cast<RankedTensorType>(dot.getD().getType());
300-
auto resShape = resType.getShape();
301-
auto rank = resShape.size();
302-
auto M = resShape[rank - 2];
303-
auto N = resShape[rank - 1];
304-
305-
unsigned mDim = 0;
306-
unsigned nDim = 0;
307-
if (enforcedNonKDim != 0) {
308-
mDim = enforcedNonKDim;
309-
nDim = enforcedNonKDim;
310-
} else {
311-
int minSize = std::min(M, N);
312-
if (minSize >= 32) {
313-
mDim = 32;
314-
nDim = 32;
315-
}
316-
if (minSize >= 16 && minSize < 32) {
317-
mDim = 16;
318-
nDim = 16;
319-
}
320-
if (minSize < 16) {
321-
if (M < 16 && N >= 64) {
322-
mDim = 4;
323-
nDim = 64;
324-
} else if (M >= 64 && N < 16) {
325-
mDim = 64;
326-
nDim = 4;
327-
} else {
328-
assert(opType.getShape()[rank - 1] >= 64 &&
329-
"k should be at least 64 to use this layout");
330-
mDim = 4;
331-
nDim = 4;
332-
}
333-
}
334-
}
335-
assert(mDim != 0 && nDim != 0);
336-
337-
auto maybeMfmaInsn =
338-
MfmaInsn::selectMfma(mDim, nDim, dataTypeA, dataTypeB, mfmaVersion);
339-
if (failed(maybeMfmaInsn))
340-
llvm::report_fatal_error("No match found in MFMA database\n");
341-
342-
kDim = maybeMfmaInsn->getKDim();
343-
assert(kDim != 0);
344-
assert(M % mDim == 0 && N % nDim == 0);
345-
assert(opType.getShape()[rank - 1] % kDim == 0);
346-
return maybeMfmaInsn;
347-
}
348-
349-
LogicalResult matchAndRewrite(Operation *op,
355+
LogicalResult matchAndRewrite(tt::DotOp dotOp,
350356
PatternRewriter &rewriter) const override {
351-
auto dotOp = cast<tt::DotOp>(op);
352-
353357
RankedTensorType oldRetType = dotOp.getType();
354358
if (!oldRetType.getEncoding() ||
355359
!isa<ttg::BlockedEncodingAttr>(oldRetType.getEncoding()))
356360
return failure();
361+
if (!isa_and_nonnull<BlockedEncodingAttr>(dotOp.getType().getEncoding()))
362+
return rewriter.notifyMatchFailure(
363+
dotOp, "expected blocked encoding result tensor");
357364

358365
if (!supportMFMA(dotOp))
359366
return failure();
@@ -362,7 +369,7 @@ class BlockedToMFMA : public RewritePattern {
362369

363370
// get MFMA encoding for the given number of warps
364371
auto retShape = oldRetType.getShape();
365-
auto mod = op->getParentOfType<ModuleOp>();
372+
auto mod = dotOp->getParentOfType<ModuleOp>();
366373
int numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
367374

368375
// operands
@@ -374,7 +381,7 @@ class BlockedToMFMA : public RewritePattern {
374381

375382
ttg::AMDMfmaEncodingAttr mfmaEnc;
376383

377-
auto mfmaInstr = chooseMfmaInstruction(dotOp);
384+
auto mfmaInstr = chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim);
378385
auto mDim = mfmaInstr.value().getMDim();
379386
auto nDim = mfmaInstr.value().getNDim();
380387
auto kDim = mfmaInstr.value().getKDim();
@@ -397,7 +404,7 @@ class BlockedToMFMA : public RewritePattern {
397404
mfmaAccType = rewriter.getF32Type();
398405

399406
// convert accumulator
400-
auto oldAcc = dotOp.getOperand(2);
407+
auto oldAcc = dotOp.getC();
401408
auto newAcc = convertAndCastTensor(rewriter, oldAcc, mfmaEnc, mfmaAccType);
402409

403410
// Here is a brief explanation of kWidth, kBase, and kDim
@@ -456,11 +463,12 @@ class BlockedToMFMA : public RewritePattern {
456463
convertAndCastTensor(rewriter, newDot, oldRetType.getEncoding(),
457464
oldRetType.getElementType());
458465

459-
rewriter.replaceOp(op, dotOutput);
466+
rewriter.replaceOp(dotOp, dotOutput);
460467

461468
return success();
462469
}
463470
};
471+
464472
static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
465473
Type promotedType) {
466474
Type tensorPromotedType = cast<RankedTensorType>(operand.getType())
@@ -566,18 +574,17 @@ static void decomposeMixedModeDotOp(ModuleOp mod) {
566574
});
567575
}
568576

569-
class BlockedToWMMA : public RewritePattern {
577+
class BlockedToWMMA : public OpRewritePattern<tt::DotOp> {
570578
int wmmaVersion;
571579

572580
public:
573-
BlockedToWMMA(MLIRContext *context, int wmmaVersion)
574-
: RewritePattern(tt::DotOp::getOperationName(), 2, context),
575-
wmmaVersion(wmmaVersion) {}
581+
BlockedToWMMA(MLIRContext *context, int wmmaVersion,
582+
PatternBenefit benefit = 1)
583+
: OpRewritePattern(context, benefit), wmmaVersion(wmmaVersion) {}
576584

577-
LogicalResult matchAndRewrite(Operation *op,
585+
LogicalResult matchAndRewrite(tt::DotOp dotOp,
578586
PatternRewriter &rewriter) const override {
579-
auto ctx = op->getContext();
580-
auto dotOp = cast<tt::DotOp>(op);
587+
auto ctx = dotOp->getContext();
581588

582589
Value a = dotOp.getA();
583590
Value b = dotOp.getB();
@@ -603,7 +610,7 @@ class BlockedToWMMA : public RewritePattern {
603610

604611
if (wmmaVersion == 2 && llvm::isa<FloatType>(oldAType) &&
605612
oldAType.getIntOrFloatBitWidth() == 8) {
606-
return rewriter.notifyMatchFailure(op, "not supported yet");
613+
return rewriter.notifyMatchFailure(dotOp, "not supported yet");
607614
}
608615

609616
// get operand types
@@ -612,7 +619,7 @@ class BlockedToWMMA : public RewritePattern {
612619
return failure();
613620

614621
// get WMMA encoding for the given number of warps
615-
auto mod = op->getParentOfType<ModuleOp>();
622+
auto mod = dotOp->getParentOfType<ModuleOp>();
616623
int numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
617624

618625
ttg::AMDWmmaEncodingAttr wmmaEnc;
@@ -626,7 +633,7 @@ class BlockedToWMMA : public RewritePattern {
626633
auto newRetType = RankedTensorType::get(retShape, operandTypes[3], wmmaEnc);
627634

628635
// convert accumulator
629-
auto oldAcc = dotOp.getOperand(2);
636+
auto oldAcc = dotOp.getC();
630637
auto newAcc =
631638
convertAndCastTensor(rewriter, oldAcc, wmmaEnc, operandTypes[2]);
632639

@@ -653,7 +660,7 @@ class BlockedToWMMA : public RewritePattern {
653660

654661
Value dotOutput = convertAndCastTensor(rewriter, newDot, oldRetEncoding,
655662
oldRetType.getElementType());
656-
rewriter.replaceOp(op, dotOutput);
663+
rewriter.replaceOp(dotOp, dotOutput);
657664
return success();
658665
}
659666
};

0 commit comments

Comments
 (0)