 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "triton/Analysis/Utility.h"
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include <memory>

@@ -36,16 +38,15 @@ int getWmmaVersion(StringRef archGen) {
   return 0;
 }

-SmallVector<unsigned, 2> warpsPerTile(tt::DotOp dotOp,
-                                      const ArrayRef<int64_t> shape,
-                                      int numWarps,
-                                      SmallVector<int64_t, 2> shapePerWarp) {
+SmallVector<unsigned, 3>
+warpsPerTile(Operation *dotOp, ArrayRef<int64_t> shape, int numWarps,
+             std::pair<int64_t, int64_t> shapePerWarp) {
   auto rank = shape.size();
   // Early exit for batched matmul
   if (rank == 3)
     return {(unsigned)numWarps, 1, 1};

-  auto filter = [&dotOp](Operation *op) {
+  auto filter = [dotOp](Operation *op) {
     return op->getParentRegion() == dotOp->getParentRegion();
   };
   ForwardSliceOptions fwdOpt;
@@ -55,17 +56,17 @@ SmallVector<unsigned, 2> warpsPerTile(tt::DotOp dotOp,
   bwdOpt.filter = filter;
   auto slices = getSlice(dotOp, bwdOpt, fwdOpt);
   for (Operation *op : slices)
-    if (isa<tt::DotOp>(op) && (op != dotOp))
+    if (op->hasTrait<OpTrait::DotLike>() && (op != dotOp))
       return {(unsigned)numWarps, 1};

   SmallVector<int64_t, 2> tensorShape = {shape[0], shape[1]};
   SmallVector<unsigned, 2> ret = {1, 1};
   do {
     if (ret[0] * ret[1] >= numWarps)
       break;
-    if (tensorShape[0] / (shapePerWarp[0] * 2) / ret[0] >=
-        tensorShape[1] / shapePerWarp[1] / ret[1]) {
-      if (ret[0] < tensorShape[0] / shapePerWarp[0]) {
+    if (tensorShape[0] / (shapePerWarp.first * 2) / ret[0] >=
+        tensorShape[1] / shapePerWarp.second / ret[1]) {
+      if (ret[0] < tensorShape[0] / shapePerWarp.first) {
         ret[0] *= 2;
       } else
         ret[1] *= 2;
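
Reviewer note: the warp-tiling loop above (it continues in the next hunk, including the final swap of ret[0]/ret[1]) now takes the per-warp tile sizes as a std::pair instead of a SmallVector. For reference, the whole heuristic is equivalent to the standalone sketch below; warpsPerTileSketch and the main driver are illustrative names with plain std types, not part of the patch.

#include <array>
#include <cstdint>
#include <cstdio>
#include <utility>

// Standalone sketch of the 2-D warp-tiling heuristic (simplified, no MLIR types).
std::array<unsigned, 2>
warpsPerTileSketch(std::array<int64_t, 2> shape, int numWarps,
                   std::pair<int64_t, int64_t> shapePerWarp) {
  std::array<unsigned, 2> ret = {1, 1};
  while (ret[0] * ret[1] < static_cast<unsigned>(numWarps)) {
    // Prefer doubling along M while the remaining M extent (in units of
    // double-height per-warp tiles) is at least the remaining N extent;
    // otherwise double along N.
    if (shape[0] / (shapePerWarp.first * 2) / ret[0] >=
        shape[1] / shapePerWarp.second / ret[1]) {
      if (ret[0] < shape[0] / shapePerWarp.first)
        ret[0] *= 2;
      else
        ret[1] *= 2;
    } else {
      ret[1] *= 2;
    }
  }
  // If the N split overshoots the tensor, swap the two factors.
  if (ret[1] * shapePerWarp.second > shape[1])
    return {ret[1], ret[0]};
  return ret;
}

int main() {
  // e.g. a 128x128 dot with 32x32 per-warp tiles and 8 warps -> 2 x 4
  auto r = warpsPerTileSketch({128, 128}, 8, {32, 32});
  std::printf("%u x %u\n", r[0], r[1]);
}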
@@ -74,24 +75,89 @@ SmallVector<unsigned, 2> warpsPerTile(tt::DotOp dotOp,
     }
   } while (true);

-  if (ret[1] * shapePerWarp[1] > tensorShape[1]) {
+  if (ret[1] * shapePerWarp.second > tensorShape[1]) {
     return {ret[1], ret[0]};
   }

   return ret;
 }

-SmallVector<unsigned, 2>
-warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
-                 SmallVector<int64_t, 2> shapePerWarp) {
+SmallVector<unsigned, 3>
+warpsPerTileMFMA(Operation *dotOp, ArrayRef<int64_t> shape, int numWarps,
+                 std::pair<int64_t, int64_t> shapePerWarp) {
   return warpsPerTile(dotOp, shape, numWarps, shapePerWarp);
 }

-SmallVector<unsigned, 2>
-warpsPerTileWMMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
-  return warpsPerTile(dotOp, shape, numWarps,
-                      {ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr()[0],
-                       ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr()[1]});
+SmallVector<unsigned, 3>
+warpsPerTileWMMA(Operation *dotOp, ArrayRef<int64_t> shape, int numWarps) {
+  auto mnk = ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr();
+  return warpsPerTile(dotOp, shape, numWarps, {mnk[0], mnk[1]});
+}
+
+// Chooses a proper MFMA instruction that can be used to compute the given dot
+// op. If enforcedNonKDim is not zero, it will be used to overwrite the default
+// logic to choose an MFMA with matching M/N dim.
+FailureOr<MfmaInsn> chooseMfmaInstruction(RankedTensorType cType,
+                                          Type aElemType, Type bElemType,
+                                          int inputKSize, int mfmaVersion,
+                                          int enforcedNonKDim) {
+  // number of matrix elements along k dim per one MFMA instruction
+  unsigned kDim = 0;
+
+  auto resShape = cType.getShape();
+  auto rank = resShape.size();
+  auto M = resShape[rank - 2];
+  auto N = resShape[rank - 1];
+
+  unsigned mDim = 0;
+  unsigned nDim = 0;
+  if (enforcedNonKDim != 0) {
+    mDim = nDim = enforcedNonKDim;
+  } else {
+    int minSize = std::min(M, N);
+    if (minSize >= 32) {
+      mDim = 32;
+      nDim = 32;
+    }
+    if (minSize >= 16 && minSize < 32) {
+      mDim = 16;
+      nDim = 16;
+    }
+    if (minSize < 16) {
+      if (M < 16 && N >= 64) {
+        mDim = 4;
+        nDim = 64;
+      } else if (M >= 64 && N < 16) {
+        mDim = 64;
+        nDim = 4;
+      } else {
+        assert(inputKSize >= 64 &&
+               "k should be at least 64 to use this layout");
+        mDim = 4;
+        nDim = 4;
+      }
+    }
+  }
+  assert(mDim != 0 && nDim != 0);
+
+  auto maybeMfmaInsn =
+      MfmaInsn::selectMfma(mDim, nDim, aElemType, bElemType, mfmaVersion);
+  if (failed(maybeMfmaInsn))
+    llvm::report_fatal_error("No match found in MFMA database\n");
+
+  kDim = maybeMfmaInsn->getKDim();
+  assert(kDim != 0);
+  assert(M % mDim == 0 && N % nDim == 0);
+  assert(inputKSize % kDim == 0);
+  return maybeMfmaInsn;
+}
+
+FailureOr<MfmaInsn> chooseMfmaInstruction(tt::DotOp dot, int mfmaVersion,
+                                          int nonKDim) {
+  RankedTensorType aType = dot.getA().getType();
+  return chooseMfmaInstruction(dot.getC().getType(), aType.getElementType(),
+                               dot.getB().getType().getElementType(),
+                               aType.getShape().back(), mfmaVersion, nonKDim);
 }

 using OperandTypesVector = SmallVector<Type, 4>;
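
Reviewer note: the freestanding chooseMfmaInstruction added above derives the MFMA M/N dimensions purely from the result shape unless enforcedNonKDim overrides them, then asks MfmaInsn::selectMfma for a matching instruction. A condensed sketch of just the dimension-selection logic is shown below; pickMfmaMNDims is a hypothetical name used only for illustration.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>

// Sketch of the M/N-dimension selection used by chooseMfmaInstruction.
std::pair<unsigned, unsigned> pickMfmaMNDims(int64_t M, int64_t N, int64_t K) {
  int64_t minSize = std::min(M, N);
  if (minSize >= 32)
    return {32, 32};   // large tiles: 32x32 MFMA
  if (minSize >= 16)
    return {16, 16};   // medium tiles: 16x16 MFMA
  if (M < 16 && N >= 64)
    return {4, 64};    // skinny M
  if (M >= 64 && N < 16)
    return {64, 4};    // skinny N
  assert(K >= 64 && "k should be at least 64 to use this layout");
  return {4, 4};       // both dims small: 4x4 MFMA, requires large K
}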
@@ -259,15 +325,16 @@ Value convertAndCastTensor(PatternRewriter &rewriter, Value value,
   return castedTensor;
 }

-class BlockedToMFMA : public RewritePattern {
+class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
   int mfmaVersion;
-  int enforcedNonKDim;
+  int nonKDim;
   int kPack;

 public:
-  BlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim, int kPack)
-      : RewritePattern(tt::DotOp::getOperationName(), 2, context),
-        mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim), kPack(kPack) {}
+  BlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim, int kPack,
+                PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion),
+        nonKDim(nonKDim), kPack(kPack) {}

   bool isSecondDot(tt::DotOp &dotOp) const {
     auto filter = [&dotOp](Operation *op) {
@@ -285,75 +352,15 @@ class BlockedToMFMA : public RewritePattern {
     return false;
   }

-  /// @brief Choose MFMA instruction parameters
-  /// @param dot target dot operation
-  /// @return MfmaInsn or failure
-  FailureOr<MfmaInsn> chooseMfmaInstruction(tt::DotOp dot) const {
-    // number of matrix elements along k dim per one MFMA intruction
-    unsigned kDim = 0;
-    auto opType = cast<RankedTensorType>(dot.getA().getType());
-    auto dataTypeA = opType.getElementType();
-    auto dataTypeB =
-        cast<RankedTensorType>(dot.getB().getType()).getElementType();
-
-    auto resType = cast<RankedTensorType>(dot.getD().getType());
-    auto resShape = resType.getShape();
-    auto rank = resShape.size();
-    auto M = resShape[rank - 2];
-    auto N = resShape[rank - 1];
-
-    unsigned mDim = 0;
-    unsigned nDim = 0;
-    if (enforcedNonKDim != 0) {
-      mDim = enforcedNonKDim;
-      nDim = enforcedNonKDim;
-    } else {
-      int minSize = std::min(M, N);
-      if (minSize >= 32) {
-        mDim = 32;
-        nDim = 32;
-      }
-      if (minSize >= 16 && minSize < 32) {
-        mDim = 16;
-        nDim = 16;
-      }
-      if (minSize < 16) {
-        if (M < 16 && N >= 64) {
-          mDim = 4;
-          nDim = 64;
-        } else if (M >= 64 && N < 16) {
-          mDim = 64;
-          nDim = 4;
-        } else {
-          assert(opType.getShape()[rank - 1] >= 64 &&
-                 "k should be at least 64 to use this layout");
-          mDim = 4;
-          nDim = 4;
-        }
-      }
-    }
-    assert(mDim != 0 && nDim != 0);
-
-    auto maybeMfmaInsn =
-        MfmaInsn::selectMfma(mDim, nDim, dataTypeA, dataTypeB, mfmaVersion);
-    if (failed(maybeMfmaInsn))
-      llvm::report_fatal_error("No match found in MFMA database\n");
-
-    kDim = maybeMfmaInsn->getKDim();
-    assert(kDim != 0);
-    assert(M % mDim == 0 && N % nDim == 0);
-    assert(opType.getShape()[rank - 1] % kDim == 0);
-    return maybeMfmaInsn;
-  }
-
-  LogicalResult matchAndRewrite(Operation *op,
+  LogicalResult matchAndRewrite(tt::DotOp dotOp,
                                 PatternRewriter &rewriter) const override {
-    auto dotOp = cast<tt::DotOp>(op);
-
     RankedTensorType oldRetType = dotOp.getType();
     if (!oldRetType.getEncoding() ||
         !isa<ttg::BlockedEncodingAttr>(oldRetType.getEncoding()))
       return failure();
+    if (!isa_and_nonnull<BlockedEncodingAttr>(dotOp.getType().getEncoding()))
+      return rewriter.notifyMatchFailure(
+          dotOp, "expected blocked encoding result tensor");

     if (!supportMFMA(dotOp))
       return failure();
@@ -362,7 +369,7 @@ class BlockedToMFMA : public RewritePattern {

     // get MFMA encoding for the given number of warps
     auto retShape = oldRetType.getShape();
-    auto mod = op->getParentOfType<ModuleOp>();
+    auto mod = dotOp->getParentOfType<ModuleOp>();
     int numWarps = ttg::TritonGPUDialect::getNumWarps(mod);

     // operands
@@ -374,7 +381,7 @@ class BlockedToMFMA : public RewritePattern {

     ttg::AMDMfmaEncodingAttr mfmaEnc;

-    auto mfmaInstr = chooseMfmaInstruction(dotOp);
+    auto mfmaInstr = chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim);
     auto mDim = mfmaInstr.value().getMDim();
     auto nDim = mfmaInstr.value().getNDim();
     auto kDim = mfmaInstr.value().getKDim();
@@ -397,7 +404,7 @@ class BlockedToMFMA : public RewritePattern {
       mfmaAccType = rewriter.getF32Type();

     // convert accumulator
-    auto oldAcc = dotOp.getOperand(2);
+    auto oldAcc = dotOp.getC();
     auto newAcc = convertAndCastTensor(rewriter, oldAcc, mfmaEnc, mfmaAccType);

     // Here is a brief explanation of kWidth, kBase, and kDim
@@ -456,11 +463,12 @@ class BlockedToMFMA : public RewritePattern {
         convertAndCastTensor(rewriter, newDot, oldRetType.getEncoding(),
                              oldRetType.getElementType());

-    rewriter.replaceOp(op, dotOutput);
+    rewriter.replaceOp(dotOp, dotOutput);

     return success();
   }
 };
+
 static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
                             Type promotedType) {
   Type tensorPromotedType = cast<RankedTensorType>(operand.getType())
@@ -566,18 +574,17 @@ static void decomposeMixedModeDotOp(ModuleOp mod) {
   });
 }

-class BlockedToWMMA : public RewritePattern {
+class BlockedToWMMA : public OpRewritePattern<tt::DotOp> {
   int wmmaVersion;

 public:
-  BlockedToWMMA(MLIRContext *context, int wmmaVersion)
-      : RewritePattern(tt::DotOp::getOperationName(), 2, context),
-        wmmaVersion(wmmaVersion) {}
+  BlockedToWMMA(MLIRContext *context, int wmmaVersion,
+                PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), wmmaVersion(wmmaVersion) {}

-  LogicalResult matchAndRewrite(Operation *op,
+  LogicalResult matchAndRewrite(tt::DotOp dotOp,
                                 PatternRewriter &rewriter) const override {
-    auto ctx = op->getContext();
-    auto dotOp = cast<tt::DotOp>(op);
+    auto ctx = dotOp->getContext();

     Value a = dotOp.getA();
     Value b = dotOp.getB();
@@ -603,7 +610,7 @@ class BlockedToWMMA : public RewritePattern {

     if (wmmaVersion == 2 && llvm::isa<FloatType>(oldAType) &&
         oldAType.getIntOrFloatBitWidth() == 8) {
-      return rewriter.notifyMatchFailure(op, "not supported yet");
+      return rewriter.notifyMatchFailure(dotOp, "not supported yet");
     }

     // get operand types
@@ -612,7 +619,7 @@ class BlockedToWMMA : public RewritePattern {
       return failure();

     // get WMMA encoding for the given number of warps
-    auto mod = op->getParentOfType<ModuleOp>();
+    auto mod = dotOp->getParentOfType<ModuleOp>();
     int numWarps = ttg::TritonGPUDialect::getNumWarps(mod);

     ttg::AMDWmmaEncodingAttr wmmaEnc;
@@ -626,7 +633,7 @@ class BlockedToWMMA : public RewritePattern {
     auto newRetType = RankedTensorType::get(retShape, operandTypes[3], wmmaEnc);

     // convert accumulator
-    auto oldAcc = dotOp.getOperand(2);
+    auto oldAcc = dotOp.getC();
     auto newAcc =
         convertAndCastTensor(rewriter, oldAcc, wmmaEnc, operandTypes[2]);

@@ -653,7 +660,7 @@ class BlockedToWMMA : public RewritePattern {

     Value dotOutput = convertAndCastTensor(rewriter, newDot, oldRetEncoding,
                                            oldRetType.getElementType());
-    rewriter.replaceOp(op, dotOutput);
+    rewriter.replaceOp(dotOp, dotOutput);
     return success();
   }
 };
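
Reviewer note: with both patterns now deriving from OpRewritePattern<tt::DotOp> and taking an explicit PatternBenefit, registration in the surrounding pass (not shown in this diff) would look roughly like the sketch below. The function name and the benefit value of 2 (mirroring the value hard-coded in the old RewritePattern constructors) are assumptions, not part of the patch.

// Hypothetical registration sketch; the pass body is outside this diff.
void populateAccelerateMatmulPatterns(mlir::RewritePatternSet &patterns,
                                      int mfmaVersion, int wmmaVersion,
                                      int nonKDim, int kPack) {
  mlir::MLIRContext *ctx = patterns.getContext();
  // Benefit 2 keeps the old relative priority of these patterns.
  patterns.add<BlockedToMFMA>(ctx, mfmaVersion, nonKDim, kPack,
                              /*benefit=*/2);
  patterns.add<BlockedToWMMA>(ctx, wmmaVersion, /*benefit=*/2);
}

// The set would then be applied with the greedy driver already included at
// the top of this file, e.g.:
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));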