|
16 | 16 |
|
17 | 17 | namespace { |
18 | 18 |
|
19 | | -using ::mlir::LLVM::getMultiDimOffset; |
20 | | -using ::mlir::LLVM::getSharedMemoryObjectFromStruct; |
21 | | -using ::mlir::LLVM::getWrappedMultiDimOffset; |
22 | | -using ::mlir::LLVM::linearize; |
23 | | - |
| 19 | +using namespace mlir; |
24 | 20 | using namespace mlir::triton::gpu; |
25 | 21 |
|
26 | 22 | // XXX(Keren): A temporary knob to control the use of legacy MMA conversion |
@@ -105,13 +101,14 @@ struct ConvertLayoutOpConversion |
105 | 101 | // of performance issue observed. |
106 | 102 | for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) { |
107 | 103 | SmallVector<Value> multiDimOffset = |
108 | | - getMultiDimOffset(layout, loc, rewriter, targetInfo, elemId, type, |
109 | | - multiDimCTAInRepId, shapePerCTATile); |
110 | | - SmallVector<Value> multiDimOffsetWrapped = getWrappedMultiDimOffset( |
111 | | - rewriter, loc, multiDimOffset, origRepShape, shapePerCTATile, |
112 | | - shapePerCTA); |
113 | | - Value offset = linearize(rewriter, loc, multiDimOffsetWrapped, |
114 | | - paddedRepShape, outOrd); |
| 104 | + LLVM::getMultiDimOffset(layout, loc, rewriter, targetInfo, elemId, |
| 105 | + type, multiDimCTAInRepId, shapePerCTATile); |
| 106 | + SmallVector<Value> multiDimOffsetWrapped = |
| 107 | + LLVM::getWrappedMultiDimOffset(rewriter, loc, multiDimOffset, |
| 108 | + origRepShape, shapePerCTATile, |
| 109 | + shapePerCTA); |
| 110 | + Value offset = LLVM::linearize(rewriter, loc, multiDimOffsetWrapped, |
| 111 | + paddedRepShape, outOrd); |
115 | 112 | auto elemPtrTy = smemBase.getType(); |
116 | 113 | Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset); |
117 | 114 | auto vecTy = vec_ty(llvmElemTy, vec); |
@@ -267,7 +264,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion |
267 | 264 | // conversions. TODO(jlebar): Eventually we want this to be the only pattern. |
268 | 265 | explicit ConvertLayoutOpUsingLinearLayoutsConversion( |
269 | 266 | LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, |
270 | | - PatternBenefit benefit = 2) |
| 267 | + PatternBenefit benefit = 1) |
271 | 268 | : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { |
272 | 269 | } |
273 | 270 |
|
@@ -395,16 +392,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion |
395 | 392 | if (!layoutIsOK(srcTy.getEncoding()) || !layoutIsOK(dstTy.getEncoding())) { |
396 | 393 | return failure(); |
397 | 394 | } |
398 | | - // FIXME [Dot LL] Remove this once we implement this trick in LLs |
399 | | - if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) { |
400 | | - return failure(); |
401 | | - } |
402 | | - |
403 | | - // The following check can be removed when generalized warp shuffle |
404 | | - // conversions are ready: |
405 | | - if (matchMFMAAndDotOperandShuffleCase(srcTy, dstTy)) { |
406 | | - return failure(); |
407 | | - } |
408 | 395 |
|
409 | 396 | assert(cvtNeedsSharedMemory(srcTy, dstTy)); |
410 | 397 |
|
@@ -666,22 +653,17 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion |
666 | 653 |
|
667 | 654 | } // namespace |
668 | 655 |
|
669 | | -void mlir::triton::populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern( |
| 656 | +void mlir::triton::populateConvertLayoutOpToLLVMPatterns( |
670 | 657 | LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, |
671 | 658 | RewritePatternSet &patterns, PatternBenefit benefit) { |
| 659 | + if (useLegacyMMAConversion) { |
| 660 | + // Prioritize the legacy MMA conversion over the LinearLayout conversion. |
| 661 | + // Only for debugging purposes. |
| 662 | + patterns.add<ConvertLayoutOpConversion>(typeConverter, targetInfo, |
| 663 | + benefit.getBenefit() + 1); |
| 664 | + } |
672 | 665 | patterns.add<ConvertLayoutOpUsingLinearLayoutsConversion>( |
673 | 666 | typeConverter, targetInfo, benefit); |
674 | | -} |
675 | | - |
676 | | -void mlir::triton::populateConvertLayoutOpToLLVMPatterns( |
677 | | - LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, |
678 | | - RewritePatternSet &patterns, PatternBenefit benefit) { |
679 | | - // We prefer using the linear layout conversion, so it gets a higher benefit. |
680 | | - // Eventually the LL conversion will subsume all of the others and be the only |
681 | | - // one left. |
682 | | - mlir::triton::populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern( |
683 | | - typeConverter, targetInfo, patterns, benefit.getBenefit() + 1); |
684 | 667 | patterns.add<ConvertLayoutOpBlockedToDotOpShortcutConversion>( |
685 | 668 | typeConverter, targetInfo, benefit); |
686 | | - patterns.add<ConvertLayoutOpConversion>(typeConverter, targetInfo, benefit); |
687 | 669 | } |
0 commit comments