Skip to content

Commit b980165

Browse files
authored
[intel] Sync RemoveLayoutConversions.cpp with Triton using d86ae7b commit (#3357)
triton-lang/triton@d86ae7b Signed-off-by: Anatoly Myachev <[email protected]>
1 parent d11660f commit b980165

File tree

1 file changed

+0
-49
lines changed

1 file changed

+0
-49
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -37,45 +37,6 @@ namespace {
3737
//
3838
// -----------------------------------------------------------------------------
3939

40-
// dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0))
41-
class ConvertDotConvert : public RewritePattern {
42-
public:
43-
ConvertDotConvert(MLIRContext *context)
44-
: RewritePattern(ConvertLayoutOp::getOperationName(), 1, context) {}
45-
46-
LogicalResult matchAndRewrite(Operation *op,
47-
PatternRewriter &rewriter) const override {
48-
auto dstOp = cast<ConvertLayoutOp>(op);
49-
auto dotOp = dstOp.getSrc().getDefiningOp<DotOp>();
50-
if (!dotOp)
51-
return failure();
52-
if (std::distance(dstOp->user_begin(), dstOp->user_end()) != 1 ||
53-
std::distance(dotOp->user_begin(), dotOp->user_end()) != 1)
54-
return failure();
55-
auto cvtOp = dotOp.getOperand(2).getDefiningOp<ConvertLayoutOp>();
56-
if (!cvtOp)
57-
return failure();
58-
if (!cvtOp.getSrc().getDefiningOp<LoadOp>())
59-
return failure();
60-
RankedTensorType dstTy = dstOp.getType();
61-
RankedTensorType srcTy = cvtOp.getSrc().getType();
62-
if (dstTy != srcTy)
63-
return failure();
64-
65-
auto _0f = rewriter.create<arith::ConstantOp>(
66-
op->getLoc(), dstTy.getElementType(),
67-
rewriter.getZeroAttr(dstTy.getElementType()));
68-
auto _0 = rewriter.create<SplatOp>(op->getLoc(), dotOp.getType(), _0f);
69-
auto newDot = rewriter.create<DotOp>(
70-
op->getLoc(), dotOp.getType(), dotOp.getOperand(0), dotOp.getOperand(1),
71-
_0, dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc());
72-
auto newCvt = rewriter.create<ConvertLayoutOp>(op->getLoc(), dstTy,
73-
newDot.getResult());
74-
rewriter.replaceOpWithNewOp<arith::AddFOp>(op, newCvt, cvtOp.getSrc());
75-
return success();
76-
}
77-
};
78-
7940
// The current algorithm works by analyzing the IR and doing a one-shot rewrite
8041
// based on the analysis. The algorithm is as follows.
8142
//
@@ -1459,16 +1420,6 @@ class TritonIntelGPURemoveLayoutConversionsPass
14591420
m.dump();
14601421
});
14611422

1462-
RewritePatternSet decomposePatterns(context);
1463-
decomposePatterns.add<ConvertDotConvert>(context);
1464-
if (applyPatternsGreedily(m, std::move(decomposePatterns)).failed()) {
1465-
signalPassFailure();
1466-
}
1467-
LLVM_DEBUG({
1468-
DBGS() << "Module after decomposing dot-converts:\n";
1469-
m.dump();
1470-
});
1471-
14721423
// 4. Apply clean up patterns to remove remove dead convert and dead code
14731424
// generated by the previous transformations.
14741425
RewritePatternSet cleanUpPatterns2(context);

0 commit comments

Comments
 (0)