99#include " mlir/Dialect/EmitC/Transforms/Transforms.h"
1010#include " mlir/Dialect/EmitC/IR/EmitC.h"
1111#include " mlir/IR/IRMapping.h"
12+ #include " mlir/IR/Location.h"
1213#include " mlir/IR/PatternMatch.h"
14+ #include " llvm/ADT/STLExtras.h"
1315
1416namespace mlir {
1517namespace emitc {
@@ -24,20 +26,24 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
2426 Location loc = op->getLoc ();
2527
2628 builder.setInsertionPointAfter (op);
27- auto expressionOp = emitc::ExpressionOp::create (builder, loc, resultType);
29+ auto expressionOp =
30+ emitc::ExpressionOp::create (builder, loc, resultType, op->getOperands ());
2831
2932 // Replace all op's uses with the new expression's result.
3033 result.replaceAllUsesWith (expressionOp.getResult ());
3134
32- // Create an op to yield op's value.
33- Region &region = expressionOp.getRegion ();
34- Block &block = region.emplaceBlock ();
35+ Block &block = expressionOp.createBody ();
36+ IRMapping mapper;
37+ for (auto [operand, arg] :
38+ llvm::zip (expressionOp.getOperands (), block.getArguments ()))
39+ mapper.map (operand, arg);
3540 builder.setInsertionPointToEnd (&block);
36- auto yieldOp = emitc::YieldOp::create (builder, loc, result);
3741
38- // Move op into the new expression.
39- op->moveBefore (yieldOp );
42+ Operation *rootOp = builder. clone (*op, mapper);
43+ op->erase ( );
4044
45+ // Create an op to yield op's value.
46+ emitc::YieldOp::create (builder, loc, rootOp->getResults ()[0 ]);
4147 return expressionOp;
4248}
4349
@@ -53,51 +59,93 @@ struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
5359 using OpRewritePattern<ExpressionOp>::OpRewritePattern;
5460 LogicalResult matchAndRewrite (ExpressionOp expressionOp,
5561 PatternRewriter &rewriter) const override {
56- bool anythingFolded = false ;
57- for (Operation &op : llvm::make_early_inc_range (
58- expressionOp.getBody ()->without_terminator ())) {
59- // Don't fold expressions whose result value has its address taken.
60- auto applyOp = dyn_cast<emitc::ApplyOp>(op);
61- if (applyOp && applyOp.getApplicableOperator () == " &" )
62- continue ;
63-
64- for (Value operand : op.getOperands ()) {
65- auto usedExpression = operand.getDefiningOp <ExpressionOp>();
66- if (!usedExpression)
67- continue ;
68-
69- // Don't fold expressions with multiple users: assume any
70- // re-materialization was done separately.
71- if (!usedExpression.getResult ().hasOneUse ())
72- continue ;
73-
74- // Don't fold expressions with side effects.
75- if (usedExpression.hasSideEffects ())
76- continue ;
77-
78- // Fold the used expression into this expression by cloning all
79- // instructions in the used expression just before the operation using
80- // its value.
81- rewriter.setInsertionPoint (&op);
82- IRMapping mapper;
83- for (Operation &opToClone :
84- usedExpression.getBody ()->without_terminator ()) {
85- Operation *clone = rewriter.clone (opToClone, mapper);
86- mapper.map (&opToClone, clone);
87- }
88-
89- Operation *expressionRoot = usedExpression.getRootOp ();
90- Operation *clonedExpressionRootOp = mapper.lookup (expressionRoot);
91- assert (clonedExpressionRootOp &&
92- " Expected cloned expression root to be in mapper" );
93- assert (clonedExpressionRootOp->getNumResults () == 1 &&
94- " Expected cloned root to have a single result" );
95-
96- rewriter.replaceOp (usedExpression, clonedExpressionRootOp);
97- anythingFolded = true ;
98- }
62+ Block *expressionBody = expressionOp.getBody ();
63+ ExpressionOp usedExpression;
64+ SetVector<Value> foldedOperands;
65+
66+ auto takesItsOperandsAddress = [](Operation *user) {
67+ auto applyOp = dyn_cast<emitc::ApplyOp>(user);
68+ return applyOp && applyOp.getApplicableOperator () == " &" ;
69+ };
70+
71+ // Select as expression to fold the first operand expression that
72+ // - doesn't have its result value's address taken,
73+ // - has a single user: assume any re-materialization was done separately,
74+ // - has no side effects,
75+ // and save all other operands to be used later as operands in the folded
76+ // expression.
77+ for (auto [operand, arg] : llvm::zip (expressionOp.getOperands (),
78+ expressionBody->getArguments ())) {
79+ ExpressionOp operandExpression = operand.getDefiningOp <ExpressionOp>();
80+ if (usedExpression || !operandExpression ||
81+ llvm::any_of (arg.getUsers (), takesItsOperandsAddress) ||
82+ !operandExpression.getResult ().hasOneUse () ||
83+ operandExpression.hasSideEffects ())
84+ foldedOperands.insert (operand);
85+ else
86+ usedExpression = operandExpression;
9987 }
100- return anythingFolded ? success () : failure ();
88+
89+ // If no operand expression was selected, bail out.
90+ if (!usedExpression)
91+ return failure ();
92+
93+ // Collect additional operands from the folded expression.
94+ for (Value operand : usedExpression.getOperands ())
95+ foldedOperands.insert (operand);
96+
97+ // Create a new expression to hold the folding result.
98+ rewriter.setInsertionPointAfter (expressionOp);
99+ auto foldedExpression = emitc::ExpressionOp::create (
100+ rewriter, expressionOp.getLoc (), expressionOp.getResult ().getType (),
101+ foldedOperands.getArrayRef (), expressionOp.getDoNotInline ());
102+ Block &foldedExpressionBody = foldedExpression.createBody ();
103+
104+ // Map each operand of the new expression to its matching block argument.
105+ IRMapping mapper;
106+ for (auto [operand, arg] : llvm::zip (foldedExpression.getOperands (),
107+ foldedExpressionBody.getArguments ()))
108+ mapper.map (operand, arg);
109+
110+ // Prepare to fold the used expression and the matched expression into the
111+ // newly created folded expression.
112+ auto foldExpression = [&rewriter, &mapper](ExpressionOp expressionToFold,
113+ bool withTerminator) {
114+ Block *expressionToFoldBody = expressionToFold.getBody ();
115+ for (auto [operand, arg] :
116+ llvm::zip (expressionToFold.getOperands (),
117+ expressionToFoldBody->getArguments ())) {
118+ mapper.map (arg, mapper.lookup (operand));
119+ }
120+
121+ for (Operation &opToClone : expressionToFoldBody->without_terminator ())
122+ rewriter.clone (opToClone, mapper);
123+
124+ if (withTerminator)
125+ rewriter.clone (*expressionToFoldBody->getTerminator (), mapper);
126+ };
127+ rewriter.setInsertionPointToStart (&foldedExpressionBody);
128+
129+ // First, fold the used expression into the new expression and map its
130+ // result to the clone of its root operation within the new expression.
131+ foldExpression (usedExpression, /* withTerminator=*/ false );
132+ Operation *expressionRoot = usedExpression.getRootOp ();
133+ Operation *clonedExpressionRootOp = mapper.lookup (expressionRoot);
134+ assert (clonedExpressionRootOp &&
135+ " Expected cloned expression root to be in mapper" );
136+ assert (clonedExpressionRootOp->getNumResults () == 1 &&
137+ " Expected cloned root to have a single result" );
138+ mapper.map (usedExpression.getResult (),
139+ clonedExpressionRootOp->getResults ()[0 ]);
140+
141+ // Now fold the matched expression into the new expression.
142+ foldExpression (expressionOp, /* withTerminator=*/ true );
143+
144+ // Complete the rewrite.
145+ rewriter.replaceOp (expressionOp, foldedExpression);
146+ rewriter.eraseOp (usedExpression);
147+
148+ return success ();
101149 }
102150};
103151
0 commit comments