Skip to content

Commit 1814386

Browse files
authored
Revert "Revert "[BACKEND] Fix a missed transpose optimization during refactor (#5236)"" (#2922)
This reverts commit d1aaea2. --------- Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 8f50525 commit 1814386

File tree

3 files changed

+45
-2
lines changed

3 files changed

+45
-2
lines changed

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,11 @@ class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
148148
LogicalResult matchAndRewrite(ConvertLayoutOp cvtOp,
149149
PatternRewriter &rewriter) const override {
150150
// Match outerCvt(trans(innerCvt(x))).
151-
auto trans = cvtOp.getSrc().getDefiningOp<MemDescTransOp>();
151+
auto trans = cvtOp.getSrc().getDefiningOp<TransOp>();
152152
if (!trans || trans.getOrder() != ArrayRef<int32_t>{1, 0})
153153
return failure();
154154

155-
auto srcTy = dyn_cast<RankedTensorType>(trans.getSrc().getType());
155+
RankedTensorType srcTy = trans.getSrc().getType();
156156

157157
if (auto srcCvt = trans.getSrc().getDefiningOp<ConvertLayoutOp>()) {
158158
srcTy = srcCvt.getSrc().getType();

test/TritonGPU/dot-operands.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,3 +282,22 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
282282
tt.return %r : tensor<128x64xf32, #mma>
283283
}
284284
}
285+
286+
// -----
287+
288+
#blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
289+
#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
290+
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
291+
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
292+
// CHECK-LABEL: mmav2_reorder_transpose
293+
// CHECK: ttg.local_alloc
294+
// CHECK: ttg.memdesc_trans
295+
// CHECK: ttg.local_load
296+
// CHECK: tt.dot
297+
tt.func @mmav2_reorder_transpose(%t: tensor<32x128xf16, #blocked1>, %dotb: tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{
298+
%a = tt.trans %t {order = array<i32: 1, 0>} : tensor<32x128xf16, #blocked1> -> tensor<128x32xf16, #blocked>
299+
%cv = ttg.convert_layout %a : tensor<128x32xf16, #blocked> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
300+
%r = tt.dot %cv, %dotb, %dotc, inputPrecision = tf32 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
301+
tt.return %r : tensor<128x64xf32, #mma>
302+
}
303+
}

third_party/intel/lib/TritonIntelGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,29 @@ struct ExpandDimsOpConversion : public ConvertOpToLLVMPattern<ExpandDimsOp> {
272272
return success();
273273
}
274274
};
275+
struct MemDescTransOpConversion
276+
: public ConvertOpToLLVMPattern<MemDescTransOp> {
277+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
278+
LogicalResult
279+
matchAndRewrite(MemDescTransOp op, OpAdaptor adaptor,
280+
ConversionPatternRewriter &rewriter) const override {
281+
Location loc = op->getLoc();
282+
auto resultTy = cast<TensorOrMemDesc>(op.getType());
283+
auto enc = cast<SharedEncodingAttr>(resultTy.getEncoding());
284+
auto llvmElemTy =
285+
getTypeConverter()->convertType(resultTy.getElementType());
286+
auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
287+
llvmElemTy, rewriter);
288+
auto dstSmemObj = SharedMemoryObject(
289+
srcSmemObj.base, srcSmemObj.baseElemType,
290+
/*strides=*/applyPermutation(srcSmemObj.strides, op.getOrder()),
291+
/*offsets=*/applyPermutation(srcSmemObj.offsets, op.getOrder()));
292+
auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter);
293+
rewriter.replaceOp(op, retVal);
294+
return success();
295+
}
296+
};
297+
275298
struct TransOpConversion : public ConvertOpToLLVMPattern<TransOp> {
276299
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
277300
LogicalResult
@@ -413,6 +436,7 @@ void mlir::triton::intel::populateViewOpToLLVMPatterns(
413436
patterns.add<CatOpConversion>(typeConverter, benefit);
414437
patterns.add<JoinOpConversion>(typeConverter, benefit);
415438
patterns.add<SplitOpConversion>(typeConverter, benefit);
439+
patterns.add<MemDescTransOpConversion>(typeConverter, benefit);
416440
patterns.add<TransOpConversion>(typeConverter, benefit);
417441
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
418442
patterns.add<MemDescSubviewOpConversion>(typeConverter, benefit);

0 commit comments

Comments (0)