@@ -195,6 +195,28 @@ static bool bwdFilter(Operation *op) {
               mlir::TypeID::get<arith::ArithDialect>());
 }
 
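+// Identity permutation on the leading (batch) dimensions with the last two
+// dimensions swapped, e.g. rank 2 -> [1, 0], rank 3 -> [0, 2, 1].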
+static SmallVector<int, 2> getTransposeOrder(int rank) {
+  assert(rank >= 2);
+  auto transOrder = llvm::to_vector<2>(llvm::seq<int>(rank - 2));
+  transOrder.push_back(rank - 1);
+  transOrder.push_back(rank - 2);
+  return transOrder;
+}
+
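+// Rewrites dot(a, b, c) as dot(b^T, a^T, c^T), using (A x B)^T == B^T x A^T;
+// the caller transposes the result back. E.g. a 16x32 by 32x64 dot becomes a
+// 64x32 by 32x16 dot whose 64x16 result is transposed back to 16x64.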
+static DotOp transposeDotOp(PatternRewriter &rewriter, DotOp dotOp) {
+  auto rank = dotOp.getResult().getType().getRank();
+  Value a = dotOp.getA();
+  Value b = dotOp.getB();
+  Value c = dotOp.getC();
+  auto transOrder = getTransposeOrder(rank);
+  a = rewriter.create<TransOp>(a.getLoc(), a, transOrder);
+  b = rewriter.create<TransOp>(b.getLoc(), b, transOrder);
+  c = rewriter.create<TransOp>(c.getLoc(), c, transOrder);
+  return rewriter.create<DotOp>(dotOp.getLoc(), c.getType(), b, a, c,
+                                dotOp.getInputPrecision(),
+                                dotOp.getMaxNumImpreciseAcc());
+}
+
 // Finds the first different bitwidth in the chain of shape-preserving
 // unary ops that x depends on.
 // There are two primary scenarios:
@@ -249,29 +271,69 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
       return failure();
     }
     // TODO: Check data-types and SM compatibility
-    RankedTensorType oldRetType = dotOp.getType();
-    if (!oldRetType.getEncoding() ||
-        mlir::isa<NvidiaMmaEncodingAttr>(oldRetType.getEncoding()))
+    if (!dotOp.getType().getEncoding() ||
+        mlir::isa<NvidiaMmaEncodingAttr>(dotOp.getType().getEncoding()))
       return failure();
 
-    // get MMA encoding for the given number of warps
-    auto retShapePerCTA = getShapePerCTA(oldRetType);
     auto mod = dotOp->getParentOfType<mlir::ModuleOp>();
     int numWarps = TritonGPUDialect::getNumWarps(mod);
-    auto CTALayout = getCTALayout(oldRetType.getEncoding());
-
     int versionMajor = getMMAVersionSafe(computeCapability, dotOp);
     if (!(versionMajor >= 1 && versionMajor <= 3))
       return failure();
 
-    auto instrShape = mmaVersionToInstrShape(
-        versionMajor, retShapePerCTA, dotOp.getA().getType().getElementType(),
-        numWarps);
-    // operands
+    // If both of the operands are not loads, we fall back to MMAv2;
+    // otherwise the reg-smem roundtrip will tank the MMAv3 performance.
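+    // (wgmma allows the A operand to be sourced from registers, but the B
+    // operand must always come from shared memory.)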
+    auto comesFromLoadOrBlockArg = [](Value v) -> bool {
+      // Peel out the original cvt dot_op<..., #blocked>
+      // and any other potential cvt/trans ops
+      while (true) {
+        if (auto cvtOp = v.getDefiningOp<ConvertLayoutOp>()) {
+          v = cvtOp.getSrc();
+          continue;
+        }
+        if (auto transOp = v.getDefiningOp<TransOp>()) {
+          v = transOp.getSrc();
+          continue;
+        }
+        break;
+      }
+      // We also accept block arguments as they appear in many MLIR tests.
+      // If this is problematic we can totally drop them.
+      return isa<BlockArgument>(v) ||
+             (v.getDefiningOp() &&
+              isa<LoadOp, ExperimentalDescriptorLoadOp>(v.getDefiningOp()));
+    };
+
+    bool aFromLoad = comesFromLoadOrBlockArg(dotOp.getA());
+    bool bFromLoad = comesFromLoadOrBlockArg(dotOp.getB());
+    bool transpose = false;
+    auto origDotOp = dotOp;
+    if (aFromLoad && !bFromLoad) {
+      // If the rhs does not come from a load but the lhs does, transpose the
+      // inputs and the result, provided this allows us to use MMAv3.
+      // The result is transposed back at the end of the rewrite.
+      DotOp transDot = transposeDotOp(rewriter, dotOp);
+      if (getMMAVersionSafe(computeCapability, transDot) == 3) {
+        dotOp = transDot;
+        versionMajor = 3;
+        transpose = true;
+      }
+      std::swap(aFromLoad, bFromLoad);
+    }
+    // If !aFromLoad && !bFromLoad, we just accept a shmem roundtrip
+    // for versionMajor == 3.
+
     Value a = dotOp.getA();
     Value b = dotOp.getB();
-    auto oldAType = dotOp.getA().getType();
-    auto oldBType = dotOp.getB().getType();
+    auto oldAType = cast<RankedTensorType>(a.getType());
+    auto oldBType = cast<RankedTensorType>(b.getType());
+    auto oldRetType = cast<RankedTensorType>(dotOp.getType());
+
+    // get MMA encoding for the given number of warps
+    auto CTALayout = getCTALayout(oldRetType.getEncoding());
+    auto retShapePerCTA = getShapePerCTA(oldRetType);
+    auto instrShape = mmaVersionToInstrShape(
+        versionMajor, retShapePerCTA, oldAType.getElementType(), numWarps);
 
     assert(versionMajor == 2 || versionMajor == 3);
     int versionMinor = computeCapability == 75 ? 1 : 0;
@@ -287,12 +349,28 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
     auto newAcc =
         rewriter.create<ConvertLayoutOp>(oldAcc.getLoc(), newRetType, oldAcc);
 
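+    // Converts v to a dot-operand layout attached to the new MMA encoding;
+    // when `bitwidth` is non-zero, an integer type of that width is passed as
+    // the element type used to build the encoding.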
+    auto getDotOperand = [&](Value v, int opIdx, int bitwidth) {
+      auto minType =
+          bitwidth > 0 ? rewriter.getIntegerType(bitwidth) : v.getType();
+      auto vType = cast<RankedTensorType>(v.getType());
+      auto newVEncoding = DotOperandEncodingAttr::get(
+          v.getContext(), opIdx, newRetType.getEncoding(), minType);
+      auto newVType = RankedTensorType::get(
+          vType.getShape(), vType.getElementType(), newVEncoding);
+      return rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
+    };
+
     Operation *newDot = nullptr;
     if (versionMajor == 3) {
       auto eltType = dotOp.getA().getType().getElementType();
       // In MMAV3 transpose is only supported for f16 and bf16.
       bool allowTranspose = eltType.isF16() || eltType.isBF16();
-      a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose);
+      if (!aFromLoad) {
+        int bitwidth = getElementTypeOrSelf(a).getIntOrFloatBitWidth();
+        a = getDotOperand(a, 0, bitwidth);
+      } else {
+        a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose);
+      }
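+      // The B operand always goes through shared memory, as wgmma requires.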
       b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose);
       newDot = rewriter.create<triton::nvidia_gpu::WarpGroupDotOp>(
           dotOp.getLoc(), newRetType, a, b, newAcc, nullptr,
@@ -301,27 +379,21 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
       // convert operands
       int minBitwidth =
           std::min(computeOrigBitWidth(a), computeOrigBitWidth(b));
-      Type minType = rewriter.getIntegerType(minBitwidth);
-      // convert A operand
-      auto newAEncoding = DotOperandEncodingAttr::get(
-          oldAType.getContext(), 0, newRetType.getEncoding(),
-          minBitwidth > 0 ? minType : oldAType.getElementType());
-      auto newAType = RankedTensorType::get(
-          oldAType.getShape(), oldAType.getElementType(), newAEncoding);
-      a = rewriter.create<ConvertLayoutOp>(a.getLoc(), newAType, a);
-      // convert B operand
-      auto newBEncoding = DotOperandEncodingAttr::get(
-          oldBType.getContext(), 1, newRetType.getEncoding(),
-          minBitwidth > 0 ? minType : oldBType.getElementType());
-      auto newBType = RankedTensorType::get(
-          oldBType.getShape(), oldBType.getElementType(), newBEncoding);
-      b = rewriter.create<ConvertLayoutOp>(b.getLoc(), newBType, b);
+
+      a = getDotOperand(a, 0, minBitwidth);
+      b = getDotOperand(b, 1, minBitwidth);
       newDot = rewriter.create<DotOp>(dotOp.getLoc(), newRetType, a, b, newAcc,
                                       dotOp.getInputPrecision(),
                                       dotOp.getMaxNumImpreciseAcc());
     }
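+    // If the operands were swapped, transpose the new dot's result back to
+    // the orientation expected by the users of the original dot.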
+    if (transpose) {
+      auto rank = dotOp.getResult().getType().getRank();
+      auto transOrder = getTransposeOrder(rank);
+      newDot = rewriter.create<TransOp>(newDot->getLoc(), newDot->getResult(0),
+                                        transOrder);
+    }
     // convert dot instruction
-    rewriter.replaceOpWithNewOp<ConvertLayoutOp>(dotOp, oldRetType,
+    rewriter.replaceOpWithNewOp<ConvertLayoutOp>(origDotOp, origDotOp.getType(),
                                                  newDot->getResult(0));
     return success();
   }