@@ -37,45 +37,6 @@ namespace {
3737//
3838// -----------------------------------------------------------------------------
3939
40- // dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0))
//
// Rewrite pattern rooted at ConvertLayoutOp. It matches a layout conversion
// whose source is a DotOp, where the dot's operand at index 2 is itself a
// ConvertLayoutOp of a value produced by a LoadOp. On a match, the dot is
// rebuilt with a zero splat as operand 2, and the loaded value is added to
// the converted result of the new dot instead.
//
// The pattern only applies when all of the following hold:
//   * the matched conversion and the dot each have exactly one user;
//   * the result type of the matched conversion equals the type of the
//     loaded value (the source of the inner conversion).
//
// NOTE(review): the final add is always an arith::AddFOp and no check on the
// element type is visible here -- presumably only floating-point element
// types reach this pattern; confirm against callers.
41- class ConvertDotConvert : public RewritePattern {
42- public:
    // Registers the pattern on ConvertLayoutOp's operation name, passing 1 as
    // the second RewritePattern constructor argument.
43- ConvertDotConvert (MLIRContext *context)
44- : RewritePattern(ConvertLayoutOp::getOperationName(), 1 , context) {}
45-
    // Returns failure() without touching the IR when any precondition below is
    // not met; otherwise replaces `op` and returns success().
46- LogicalResult matchAndRewrite (Operation *op,
47- PatternRewriter &rewriter) const override {
48- auto dstOp = cast<ConvertLayoutOp>(op);
    // The converted value must come directly from a DotOp.
49- auto dotOp = dstOp.getSrc ().getDefiningOp <DotOp>();
50- if (!dotOp)
51- return failure ();
    // Both the conversion and the dot must have exactly one user.
52- if (std::distance (dstOp->user_begin (), dstOp->user_end ()) != 1 ||
53- std::distance (dotOp->user_begin (), dotOp->user_end ()) != 1 )
54- return failure ();
    // Operand 2 of the dot (presumably the accumulator -- see the header
    // comment) must be a layout conversion of a loaded value.
55- auto cvtOp = dotOp.getOperand (2 ).getDefiningOp <ConvertLayoutOp>();
56- if (!cvtOp)
57- return failure ();
58- if (!cvtOp.getSrc ().getDefiningOp <LoadOp>())
59- return failure ();
    // The loaded value is added to the final result directly, so its type
    // must equal the type produced by the matched conversion.
60- RankedTensorType dstTy = dstOp.getType ();
61- RankedTensorType srcTy = cvtOp.getSrc ().getType ();
62- if (dstTy != srcTy)
63- return failure ();
64-
    // Build a scalar zero of the element type and splat it to the dot's type.
65- auto _0f = rewriter.create <arith::ConstantOp>(
66- op->getLoc (), dstTy.getElementType (),
67- rewriter.getZeroAttr (dstTy.getElementType ()));
68- auto _0 = rewriter.create <SplatOp>(op->getLoc (), dotOp.getType (), _0f);
    // Recreate the dot with the same first two operands, input precision and
    // max-imprecise-accumulation setting, but with the zero splat as operand 2.
69- auto newDot = rewriter.create <DotOp>(
70- op->getLoc (), dotOp.getType (), dotOp.getOperand (0 ), dotOp.getOperand (1 ),
71- _0, dotOp.getInputPrecision (), dotOp.getMaxNumImpreciseAcc ());
    // Convert the new dot's result to the destination type, then replace the
    // matched conversion with (converted dot + loaded value).
72- auto newCvt = rewriter.create <ConvertLayoutOp>(op->getLoc (), dstTy,
73- newDot.getResult ());
74- rewriter.replaceOpWithNewOp <arith::AddFOp>(op, newCvt, cvtOp.getSrc ());
75- return success ();
76- }
77- };
78-
7940// The current algorithm works by analyzing the IR and doing a one-shot rewrite
8041// based on the analysis. The algorithm is as follows.
8142//
@@ -1459,16 +1420,6 @@ class TritonIntelGPURemoveLayoutConversionsPass
14591420 m.dump ();
14601421 });
14611422
1462- RewritePatternSet decomposePatterns (context);
1463- decomposePatterns.add <ConvertDotConvert>(context);
1464- if (applyPatternsGreedily (m, std::move (decomposePatterns)).failed ()) {
1465- signalPassFailure ();
1466- }
1467- LLVM_DEBUG ({
1468- DBGS () << " Module after decomposing dot-converts:\n " ;
1469- m.dump ();
1470- });
1471-
14721423 // 4. Apply clean up patterns to remove dead convert and dead code
14731424 // generated by the previous transformations.
14741425 RewritePatternSet cleanUpPatterns2 (context);
0 commit comments