@@ -63,9 +63,10 @@ static SmallVector<Value> createVariablesForResults(T op,
6363
6464 for (OpResult result : op.getResults ()) {
6565 Type resultType = result.getType ();
66+ Type varType = emitc::LValueType::get (resultType);
6667 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get (context, " " );
6768 emitc::VariableOp var =
68- rewriter.create <emitc::VariableOp>(loc, resultType , noInit);
69+ rewriter.create <emitc::VariableOp>(loc, varType , noInit);
6970 resultVariables.push_back (var);
7071 }
7172
@@ -76,57 +77,98 @@ static SmallVector<Value> createVariablesForResults(T op,
7677// the current insertion point of given rewriter.
7778static void assignValues (ValueRange values, SmallVector<Value> &variables,
7879 PatternRewriter &rewriter, Location loc) {
79- for (auto [value, var] : llvm::zip (values, variables))
80- rewriter.create <emitc::AssignOp>(loc, var, value);
80+ for (auto [value, var] : llvm::zip (values, variables)) {
81+ assert (isa<emitc::LValueType>(var.getType ()) &&
82+ " expected var to be an lvalue type" );
83+ assert (!isa<emitc::LValueType>(value.getType ()) &&
84+ " expected value to not be an lvalue type" );
85+ auto assign = rewriter.create <emitc::AssignOp>(loc, var, value);
86+
87+ // TODO: Make sure this is safe, as this moves operations with memory
88+ // effects.
89+ if (auto op = dyn_cast_if_present<emitc::LValueToRValueOp>(
90+ value.getDefiningOp ())) {
91+ rewriter.moveOpBefore (op, assign);
92+ }
93+ }
8194}
8295
83- static void lowerYield (SmallVector<Value> &resultVariables ,
84- PatternRewriter &rewriter, scf::YieldOp yield) {
96+ static void lowerYield (SmallVector<Value> &variables, PatternRewriter &rewriter ,
97+ scf::YieldOp yield) {
8598 Location loc = yield.getLoc ();
8699 ValueRange operands = yield.getOperands ();
87100
88101 OpBuilder::InsertionGuard guard (rewriter);
89102 rewriter.setInsertionPoint (yield);
90103
91- assignValues (operands, resultVariables , rewriter, loc);
104+ assignValues (operands, variables , rewriter, loc);
92105
93106 rewriter.create <emitc::YieldOp>(loc);
94107 rewriter.eraseOp (yield);
95108}
96109
110+ static void replaceUsers (PatternRewriter &rewriter,
111+ SmallVector<Value> fromValues,
112+ SmallVector<Value> toValues) {
113+ OpBuilder::InsertionGuard guard (rewriter);
114+ for (auto [from, to] : llvm::zip (fromValues, toValues)) {
115+ assert (from.getType () == cast<emitc::LValueType>(to.getType ()).getValue () &&
116+ " expected types to match" );
117+
118+ for (OpOperand &operand : llvm::make_early_inc_range (from.getUses ())) {
119+ Operation *op = operand.getOwner ();
120+ // Skip yield ops, as these get rewritten anyways.
121+ if (isa<scf::YieldOp>(op)) {
122+ continue ;
123+ }
124+ Location loc = op->getLoc ();
125+
126+ rewriter.setInsertionPoint (op);
127+ Value rValue =
128+ rewriter.create <emitc::LValueToRValueOp>(loc, from.getType (), to);
129+ operand.set (rValue);
130+ }
131+ }
132+ }
133+
97134LogicalResult ForLowering::matchAndRewrite (ForOp forOp,
98135 PatternRewriter &rewriter) const {
99136 Location loc = forOp.getLoc ();
100137
101- // Create an emitc::variable op for each result. These variables will be
102- // assigned to by emitc::assign ops within the loop body.
103- SmallVector<Value> resultVariables =
104- createVariablesForResults (forOp, rewriter);
105- SmallVector<Value> iterArgsVariables =
106- createVariablesForResults (forOp, rewriter);
138+ // Create an emitc::variable op for each result. These variables will be used
139+ // for the results of the operations as well as the iter_args. They are
140+ // assigned to by emitc::assign ops before the loop and at the end of the loop
141+ // body.
142+ SmallVector<Value> variables = createVariablesForResults (forOp, rewriter);
107143
108- assignValues (forOp.getInits (), iterArgsVariables, rewriter, loc);
144+ // Assign initial values to the iter arg variables.
145+ assignValues (forOp.getInits (), variables, rewriter, loc);
109146
110- emitc::ForOp loweredFor = rewriter.create <emitc::ForOp>(
111- loc, forOp.getLowerBound (), forOp.getUpperBound (), forOp.getStep ());
147+ // Replace users of the iter args with variables.
148+ SmallVector<Value> iterArgs;
149+ for (BlockArgument arg : forOp.getRegionIterArgs ()) {
150+ iterArgs.push_back (arg);
151+ }
112152
113- Block *loweredBody = loweredFor. getBody ( );
153+ replaceUsers (rewriter, iterArgs, variables );
114154
115- // Erase the auto-generated terminator for the lowered for op.
116- rewriter.eraseOp (loweredBody->getTerminator ());
155+ emitc::ForOp loweredFor = rewriter.create <emitc::ForOp>(
156+ loc, forOp.getLowerBound (), forOp.getUpperBound (), forOp.getStep ());
157+ rewriter.eraseBlock (loweredFor.getBody ());
117158
118- SmallVector<Value> replacingValues;
119- replacingValues.push_back (loweredFor.getInductionVar ());
120- replacingValues.append (iterArgsVariables.begin (), iterArgsVariables.end ());
159+ rewriter.inlineRegionBefore (forOp.getRegion (), loweredFor.getRegion (),
160+ loweredFor.getRegion ().end ());
161+ Operation *terminator = loweredFor.getRegion ().back ().getTerminator ();
162+ lowerYield (variables, rewriter, cast<scf::YieldOp>(terminator));
121163
122- rewriter.mergeBlocks (forOp.getBody (), loweredBody, replacingValues);
123- lowerYield (iterArgsVariables, rewriter,
124- cast<scf::YieldOp>(loweredBody->getTerminator ()));
164+ // Erase block arguments for iter_args.
165+ loweredFor.getRegion ().back ().eraseArguments (1 , variables.size ());
125166
126- // Copy iterArgs into results after the for loop.
127- assignValues (iterArgsVariables, resultVariables, rewriter, loc);
167+ // Replace all users of the results with lazily created lvalue-to-rvalue
168+ // ops.
169+ replaceUsers (rewriter, forOp.getResults (), variables);
128170
129- rewriter.replaceOp (forOp, resultVariables );
171+ rewriter.eraseOp (forOp);
130172 return success ();
131173}
132174
@@ -167,7 +209,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
167209
168210 bool hasElseBlock = !elseRegion.empty ();
169211
170- auto loweredIf =
212+ emitc::IfOp loweredIf =
171213 rewriter.create <emitc::IfOp>(loc, ifOp.getCondition (), false , false );
172214
173215 Region &loweredThenRegion = loweredIf.getThenRegion ();
@@ -178,7 +220,11 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
178220 lowerRegion (elseRegion, loweredElseRegion);
179221 }
180222
181- rewriter.replaceOp (ifOp, resultVariables);
223+ // Replace all users of the results with lazily created lvalue-to-rvalue
224+ // ops.
225+ replaceUsers (rewriter, ifOp.getResults (), resultVariables);
226+
227+ rewriter.eraseOp (ifOp);
182228 return success ();
183229}
184230
0 commit comments