Skip to content

Commit 7326995

Browse files
authored
Fix yield conversion of scf.if/scf.for to emitc (#401)
* Fix conversion for scf.for and scf.if
1 parent 72cbeca commit 7326995

File tree

3 files changed

+123
-25
lines changed

3 files changed

+123
-25
lines changed

mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,25 +79,31 @@ createVariablesForResults(T op, const TypeConverter *typeConverter,
7979

8080
// Create a series of assign ops assigning given values to given variables at
8181
// the current insertion point of given rewriter.
82-
static void assignValues(ValueRange values, SmallVector<Value> &variables,
82+
static void assignValues(ValueRange values, ValueRange variables,
8383
ConversionPatternRewriter &rewriter, Location loc) {
8484
for (auto [value, var] : llvm::zip(values, variables))
8585
rewriter.create<emitc::AssignOp>(loc, var, value);
8686
}
8787

88-
static void lowerYield(SmallVector<Value> &resultVariables,
89-
ConversionPatternRewriter &rewriter,
90-
scf::YieldOp yield) {
88+
static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
89+
ConversionPatternRewriter &rewriter,
90+
scf::YieldOp yield) {
9191
Location loc = yield.getLoc();
92-
ValueRange operands = yield.getOperands();
9392

9493
OpBuilder::InsertionGuard guard(rewriter);
9594
rewriter.setInsertionPoint(yield);
9695

97-
assignValues(operands, resultVariables, rewriter, loc);
96+
SmallVector<Value> yieldOperands;
97+
if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) {
98+
return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
99+
}
100+
101+
assignValues(yieldOperands, resultVariables, rewriter, loc);
98102

99103
rewriter.create<emitc::YieldOp>(loc);
100104
rewriter.eraseOp(yield);
105+
106+
return success();
101107
}
102108

103109
LogicalResult
@@ -118,22 +124,32 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
118124
emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
119125
loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep());
120126

121-
// Propagate any attributes from the ODS forOp to the lowered emitc::for op.
122-
loweredFor->setAttrs(forOp->getAttrs());
123-
124127
Block *loweredBody = loweredFor.getBody();
125128

126129
// Erase the auto-generated terminator for the lowered for op.
127130
rewriter.eraseOp(loweredBody->getTerminator());
128131

132+
// Convert the original region types into the new types by adding unrealized
133+
// casts in the beginning of the loop. This performs the conversion in place.
134+
if (failed(rewriter.convertRegionTypes(&forOp.getRegion(),
135+
*getTypeConverter(), nullptr))) {
136+
return rewriter.notifyMatchFailure(forOp, "region types conversion failed");
137+
}
138+
139+
// Register the replacements for the block arguments and inline the body of
140+
// the scf.for loop into the body of the emitc::for loop.
141+
Block *scfBody = &(forOp.getRegion().front());
129142
SmallVector<Value> replacingValues;
130143
replacingValues.push_back(loweredFor.getInductionVar());
131144
replacingValues.append(resultVariables.begin(), resultVariables.end());
145+
rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);
146+
147+
auto result = lowerYield(forOp, resultVariables, rewriter,
148+
cast<scf::YieldOp>(loweredBody->getTerminator()));
132149

133-
Block *adaptorBody = &(adaptor.getRegion().front());
134-
rewriter.mergeBlocks(adaptorBody, loweredBody, replacingValues);
135-
lowerYield(resultVariables, rewriter,
136-
cast<scf::YieldOp>(loweredBody->getTerminator()));
150+
if (failed(result)) {
151+
return result;
152+
}
137153

138154
rewriter.replaceOp(forOp, resultVariables);
139155
return success();
@@ -169,11 +185,16 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
169185
// emitc::if regions, but the scf::yield is replaced not only with an
170186
// emitc::yield, but also with a sequence of emitc::assign ops that set the
171187
// yielded values into the result variables.
172-
auto lowerRegion = [&resultVariables, &rewriter](Region &region,
173-
Region &loweredRegion) {
188+
auto lowerRegion = [&resultVariables, &rewriter,
189+
&ifOp](Region &region, Region &loweredRegion) {
174190
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
175191
Operation *terminator = loweredRegion.back().getTerminator();
176-
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
192+
auto result = lowerYield(ifOp, resultVariables, rewriter,
193+
cast<scf::YieldOp>(terminator));
194+
if (failed(result)) {
195+
return result;
196+
}
197+
return success();
177198
};
178199

179200
Region &thenRegion = adaptor.getThenRegion();
@@ -185,11 +206,17 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
185206
rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(), false, false);
186207

187208
Region &loweredThenRegion = loweredIf.getThenRegion();
188-
lowerRegion(thenRegion, loweredThenRegion);
209+
auto result = lowerRegion(thenRegion, loweredThenRegion);
210+
if (failed(result)) {
211+
return result;
212+
}
189213

190214
if (hasElseBlock) {
191215
Region &loweredElseRegion = loweredIf.getElseRegion();
192-
lowerRegion(elseRegion, loweredElseRegion);
216+
auto result = lowerRegion(elseRegion, loweredElseRegion);
217+
if (failed(result)) {
218+
return result;
219+
}
193220
}
194221

195222
rewriter.replaceOp(ifOp, resultVariables);

mlir/test/Conversion/SCFToEmitC/for.mlir

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,55 @@ func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32
9999
// CHECK-NEXT: return %[[VAL_4]] : f32
100100
// CHECK-NEXT: }
101101

102-
func.func @loop_with_attr(%arg0 : index, %arg1 : index, %arg2 : index) {
103-
scf.for %i0 = %arg0 to %arg1 step %arg2 {
104-
%c1 = arith.constant 1 : index
105-
} {test.value = 5 : index}
106-
return
102+
func.func @for_yield_index(%arg0 : index, %arg1 : index, %arg2 : index) -> index {
103+
%zero = arith.constant 0 : index
104+
%r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%acc = %zero) -> index {
105+
scf.yield %acc : index
106+
}
107+
return %r : index
107108
}
108-
// CHECK-LABEL: func.func @loop_with_attr
109-
// CHECK: {test.value = 5 : index}
109+
110+
// CHECK-LABEL: func.func @for_yield_index(
111+
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> index {
112+
// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t
113+
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t
114+
// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t
115+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
116+
// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t
117+
// CHECK: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t
118+
// CHECK: emitc.assign %[[VAL_3]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t
119+
// CHECK: emitc.for %[[VAL_5:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_0]] {
120+
// CHECK: emitc.assign %[[VAL_4]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t
121+
// CHECK: }
122+
// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !emitc.size_t to index
123+
// CHECK: return %[[VAL_8]] : index
124+
// CHECK: }
125+
126+
127+
func.func @for_yield_update_loop_carried_var(%arg0 : index, %arg1 : index, %arg2 : index) -> index {
128+
%zero = arith.constant 0 : index
129+
%r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%acc = %zero) -> index {
130+
%sn = arith.addi %acc, %acc : index
131+
scf.yield %sn: index
132+
}
133+
return %r : index
134+
}
135+
136+
// CHECK-LABEL: func.func @for_yield_update_loop_carried_var(
137+
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> index {
138+
// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t
139+
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t
140+
// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t
141+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
142+
// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t
143+
// CHECK: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t
144+
// CHECK: emitc.assign %[[VAL_3]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t
145+
// CHECK: emitc.for %[[ARG_3:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_0]] {
146+
// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !emitc.size_t to index
147+
// CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_5]] : index
148+
// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : index to !emitc.size_t
149+
// CHECK: emitc.assign %[[VAL_8]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t
150+
// CHECK: }
151+
// CHECK: %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !emitc.size_t to index
152+
// CHECK: return %[[VAL_9]] : index
153+
// CHECK: }

mlir/test/Conversion/SCFToEmitC/if.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,30 @@ func.func @test_if_yield(%arg0: i1, %arg1: f32) {
6868
// CHECK-NEXT: }
6969
// CHECK-NEXT: return
7070
// CHECK-NEXT: }
71+
72+
73+
func.func @test_if_yield_index(%arg0: i1, %arg1: f32) {
74+
%0 = arith.constant 0 : index
75+
%1 = arith.constant 1 : index
76+
%x = scf.if %arg0 -> (index) {
77+
scf.yield %0 : index
78+
} else {
79+
scf.yield %1 : index
80+
}
81+
return
82+
}
83+
84+
// CHECK: func.func @test_if_yield_index(
85+
// CHECK-SAME: %[[ARG_0:.*]]: i1, %[[ARG_1:.*]]: f32) {
86+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
87+
// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t
88+
// CHECK: %[[C1:.*]] = arith.constant 1 : index
89+
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[C1]] : index to !emitc.size_t
90+
// CHECK: %[[VAL_2:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t
91+
// CHECK: emitc.if %[[ARG_0]] {
92+
// CHECK: emitc.assign %[[VAL_0]] : !emitc.size_t to %[[VAL_2]] : !emitc.size_t
93+
// CHECK: } else {
94+
// CHECK: emitc.assign %[[VAL_1]] : !emitc.size_t to %[[VAL_2]] : !emitc.size_t
95+
// CHECK: }
96+
// CHECK: return
97+
// CHECK: }

0 commit comments

Comments
 (0)