2121#include " mlir/IR/IRMapping.h"
2222#include " mlir/IR/MLIRContext.h"
2323#include " mlir/IR/PatternMatch.h"
24+ #include " mlir/IR/Value.h"
2425#include " mlir/Transforms/DialectConversion.h"
2526#include " mlir/Transforms/OneToNTypeConversion.h"
2627#include " mlir/Transforms/Passes.h"
@@ -79,22 +80,36 @@ createVariablesForResults(T op, const TypeConverter *typeConverter,
7980
8081// Create a series of assign ops assigning given values to given variables at
8182// the current insertion point of given rewriter.
82- static void assignValues (ValueRange values, SmallVector<Value> &variables,
83- ConversionPatternRewriter &rewriter, Location loc) {
83+ static void assignValues (ValueRange values, ValueRange variables,
84+ ConversionPatternRewriter &rewriter, Location loc,
85+ const TypeConverter *typeConverter = nullptr ) {
8486 for (auto [value, var] : llvm::zip (values, variables))
8587 rewriter.create <emitc::AssignOp>(loc, var, value);
8688}
8789
88- static void lowerYield (SmallVector<Value> & resultVariables,
89- ConversionPatternRewriter &rewriter,
90- scf::YieldOp yield ) {
90+ static void lowerYield (ValueRange resultVariables,
91+ ConversionPatternRewriter &rewriter, scf::YieldOp yield,
92+ const TypeConverter *typeConverter ) {
9193 Location loc = yield.getLoc ();
92- ValueRange operands = yield.getOperands ();
9394
9495 OpBuilder::InsertionGuard guard (rewriter);
9596 rewriter.setInsertionPoint (yield);
9697
97- assignValues (operands, resultVariables, rewriter, loc);
98+ SmallVector<Value> yieldOperands;
99+ for (auto originalOperand : yield.getOperands ()) {
100+ Value operand = originalOperand;
101+
102+ if (typeConverter && !typeConverter->isLegal (operand.getType ())) {
103+ Type resultType = typeConverter->convertType (operand.getType ());
104+ auto castToTarget =
105+ rewriter.create <UnrealizedConversionCastOp>(loc, resultType, operand);
106+ operand = castToTarget.getResult (0 );
107+ }
108+
109+ yieldOperands.push_back (operand);
110+ }
111+
112+ assignValues (yieldOperands, resultVariables, rewriter, loc);
98113
99114 rewriter.create <emitc::YieldOp>(loc);
100115 rewriter.eraseOp (yield);
@@ -118,22 +133,29 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
118133 emitc::ForOp loweredFor = rewriter.create <emitc::ForOp>(
119134 loc, adaptor.getLowerBound (), adaptor.getUpperBound (), adaptor.getStep ());
120135
121- // Propagate any attributes from the ODS forOp to the lowered emitc::for op.
122- loweredFor->setAttrs (forOp->getAttrs ());
123-
124136 Block *loweredBody = loweredFor.getBody ();
125137
126138 // Erase the auto-generated terminator for the lowered for op.
127139 rewriter.eraseOp (loweredBody->getTerminator ());
128140
141+ // Convert the original region types into the new types by adding unrealized
142+ // casts in the begginning of the loop. This performs the conversion in place.
143+ if (failed (rewriter.convertRegionTypes (&forOp.getRegion (),
144+ *getTypeConverter (), nullptr ))) {
145+ return rewriter.notifyMatchFailure (forOp, " region types conversion failed" );
146+ }
147+
148+ // Register the replacements for the block arguments and inline the body of
149+ // the scf.for loop into the body of the emitc::for loop.
150+ Block *scfBody = &(forOp.getRegion ().front ());
129151 SmallVector<Value> replacingValues;
130152 replacingValues.push_back (loweredFor.getInductionVar ());
131153 replacingValues.append (resultVariables.begin (), resultVariables.end ());
154+ rewriter.mergeBlocks (scfBody, loweredBody, replacingValues);
132155
133- Block *adaptorBody = &(adaptor.getRegion ().front ());
134- rewriter.mergeBlocks (adaptorBody, loweredBody, replacingValues);
135156 lowerYield (resultVariables, rewriter,
136- cast<scf::YieldOp>(loweredBody->getTerminator ()));
157+ cast<scf::YieldOp>(loweredBody->getTerminator ()),
158+ getTypeConverter ());
137159
138160 rewriter.replaceOp (forOp, resultVariables);
139161 return success ();
@@ -169,11 +191,12 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
169191 // emitc::if regions, but the scf::yield is replaced not only with an
170192 // emitc::yield, but also with a sequence of emitc::assign ops that set the
171193 // yielded values into the result variables.
172- auto lowerRegion = [&resultVariables, &rewriter](Region ®ion ,
173- Region &loweredRegion) {
194+ auto lowerRegion = [&resultVariables, &rewriter,
195+ this ](Region ®ion, Region &loweredRegion) {
174196 rewriter.inlineRegionBefore (region, loweredRegion, loweredRegion.end ());
175197 Operation *terminator = loweredRegion.back ().getTerminator ();
176- lowerYield (resultVariables, rewriter, cast<scf::YieldOp>(terminator));
198+ lowerYield (resultVariables, rewriter, cast<scf::YieldOp>(terminator),
199+ getTypeConverter ());
177200 };
178201
179202 Region &thenRegion = adaptor.getThenRegion ();
0 commit comments