Skip to content

Commit ca055da

Browse files
committed
Fix conversion
1 parent 2113e3c commit ca055da

File tree

3 files changed

+110
-23
lines changed

3 files changed

+110
-23
lines changed

mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/IR/IRMapping.h"
2222
#include "mlir/IR/MLIRContext.h"
2323
#include "mlir/IR/PatternMatch.h"
24+
#include "mlir/IR/Value.h"
2425
#include "mlir/Transforms/DialectConversion.h"
2526
#include "mlir/Transforms/OneToNTypeConversion.h"
2627
#include "mlir/Transforms/Passes.h"
@@ -79,22 +80,36 @@ createVariablesForResults(T op, const TypeConverter *typeConverter,
7980

8081
// Create a series of assign ops assigning given values to given variables at
8182
// the current insertion point of given rewriter.
82-
static void assignValues(ValueRange values, SmallVector<Value> &variables,
83-
ConversionPatternRewriter &rewriter, Location loc) {
83+
static void assignValues(ValueRange values, ValueRange variables,
84+
ConversionPatternRewriter &rewriter, Location loc,
85+
const TypeConverter *typeConverter = nullptr) {
8486
for (auto [value, var] : llvm::zip(values, variables))
8587
rewriter.create<emitc::AssignOp>(loc, var, value);
8688
}
8789

88-
static void lowerYield(SmallVector<Value> &resultVariables,
89-
ConversionPatternRewriter &rewriter,
90-
scf::YieldOp yield) {
90+
static void lowerYield(ValueRange resultVariables,
91+
ConversionPatternRewriter &rewriter, scf::YieldOp yield,
92+
const TypeConverter *typeConverter) {
9193
Location loc = yield.getLoc();
92-
ValueRange operands = yield.getOperands();
9394

9495
OpBuilder::InsertionGuard guard(rewriter);
9596
rewriter.setInsertionPoint(yield);
9697

97-
assignValues(operands, resultVariables, rewriter, loc);
98+
SmallVector<Value> yieldOperands;
99+
for (auto originalOperand : yield.getOperands()) {
100+
Value operand = originalOperand;
101+
102+
if (typeConverter && !typeConverter->isLegal(operand.getType())) {
103+
Type resultType = typeConverter->convertType(operand.getType());
104+
auto castToTarget =
105+
rewriter.create<UnrealizedConversionCastOp>(loc, resultType, operand);
106+
operand = castToTarget.getResult(0);
107+
}
108+
109+
yieldOperands.push_back(operand);
110+
}
111+
112+
assignValues(yieldOperands, resultVariables, rewriter, loc);
98113

99114
rewriter.create<emitc::YieldOp>(loc);
100115
rewriter.eraseOp(yield);
@@ -118,22 +133,29 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
118133
emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
119134
loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep());
120135

121-
// Propagate any attributes from the ODS forOp to the lowered emitc::for op.
122-
loweredFor->setAttrs(forOp->getAttrs());
123-
124136
Block *loweredBody = loweredFor.getBody();
125137

126138
// Erase the auto-generated terminator for the lowered for op.
127139
rewriter.eraseOp(loweredBody->getTerminator());
128140

141+
// Convert the original region types into the new types by adding unrealized
142+
// casts in the begginning of the loop. This performs the conversion in place.
143+
if (failed(rewriter.convertRegionTypes(&forOp.getRegion(),
144+
*getTypeConverter(), nullptr))) {
145+
return rewriter.notifyMatchFailure(forOp, "region types conversion failed");
146+
}
147+
148+
// Register the replacements for the block arguments and inline the body of
149+
// the scf.for loop into the body of the emitc::for loop.
150+
Block *scfBody = &(forOp.getRegion().front());
129151
SmallVector<Value> replacingValues;
130152
replacingValues.push_back(loweredFor.getInductionVar());
131153
replacingValues.append(resultVariables.begin(), resultVariables.end());
154+
rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);
132155

133-
Block *adaptorBody = &(adaptor.getRegion().front());
134-
rewriter.mergeBlocks(adaptorBody, loweredBody, replacingValues);
135156
lowerYield(resultVariables, rewriter,
136-
cast<scf::YieldOp>(loweredBody->getTerminator()));
157+
cast<scf::YieldOp>(loweredBody->getTerminator()),
158+
getTypeConverter());
137159

138160
rewriter.replaceOp(forOp, resultVariables);
139161
return success();
@@ -169,11 +191,12 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
169191
// emitc::if regions, but the scf::yield is replaced not only with an
170192
// emitc::yield, but also with a sequence of emitc::assign ops that set the
171193
// yielded values into the result variables.
172-
auto lowerRegion = [&resultVariables, &rewriter](Region &region,
173-
Region &loweredRegion) {
194+
auto lowerRegion = [&resultVariables, &rewriter,
195+
this](Region &region, Region &loweredRegion) {
174196
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
175197
Operation *terminator = loweredRegion.back().getTerminator();
176-
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
198+
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator),
199+
getTypeConverter());
177200
};
178201

179202
Region &thenRegion = adaptor.getThenRegion();

mlir/test/Conversion/SCFToEmitC/for.mlir

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,49 @@ func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32
9999
// CHECK-NEXT: return %[[VAL_4]] : f32
100100
// CHECK-NEXT: }
101101

102-
func.func @loop_with_attr(%arg0 : index, %arg1 : index, %arg2 : index) {
103-
scf.for %i0 = %arg0 to %arg1 step %arg2 {
104-
%c1 = arith.constant 1 : index
105-
} {test.value = 5 : index}
106-
return
102+
func.func @for_yield_index(%arg0 : index, %arg1 : index, %arg2 : index) -> index {
103+
%zero = arith.constant 0 : index
104+
%r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%acc = %zero) -> index {
105+
scf.yield %acc : index
106+
}
107+
return %r : index
107108
}
108-
// CHECK-LABEL: func.func @loop_with_attr
109-
// CHECK: {test.value = 5 : index}
109+
110+
// CHECK: func.func @for_yield_index(%arg0: index, %arg1: index, %arg2: index) -> index {
111+
// CHECK: %0 = builtin.unrealized_conversion_cast %arg2 : index to !emitc.size_t
112+
// CHECK: %1 = builtin.unrealized_conversion_cast %arg1 : index to !emitc.size_t
113+
// CHECK: %2 = builtin.unrealized_conversion_cast %arg0 : index to !emitc.size_t
114+
// CHECK: %c0 = arith.constant 0 : index
115+
// CHECK: %3 = builtin.unrealized_conversion_cast %c0 : index to !emitc.size_t
116+
// CHECK: %4 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t
117+
// CHECK: emitc.assign %3 : !emitc.size_t to %4 : !emitc.size_t
118+
// CHECK: emitc.for %arg3 = %2 to %1 step %0 {
119+
// CHECK: %6 = builtin.unrealized_conversion_cast %4 : !emitc.size_t to index
120+
// CHECK: %7 = builtin.unrealized_conversion_cast %6 : index to !emitc.size_t
121+
// CHECK: emitc.assign %7 : !emitc.size_t to %4 : !emitc.size_t
122+
// CHECK: }
123+
// CHECK: %5 = builtin.unrealized_conversion_cast %4 : !emitc.size_t to index
124+
// CHECK: return %5 : index
125+
// CHECK: }
126+
127+
128+
func.func @for_yield_i32(%arg0 : index, %arg1 : index, %arg2 : index) -> i32 {
129+
%zero = arith.constant 0 : i32
130+
%r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%acc = %zero) -> i32 {
131+
scf.yield %acc : i32
132+
}
133+
return %r : i32
134+
}
135+
136+
// CHECK: func.func @for_yield_i32(%arg0: index, %arg1: index, %arg2: index) -> i32 {
137+
// CHECK: %0 = builtin.unrealized_conversion_cast %arg2 : index to !emitc.size_t
138+
// CHECK: %1 = builtin.unrealized_conversion_cast %arg1 : index to !emitc.size_t
139+
// CHECK: %2 = builtin.unrealized_conversion_cast %arg0 : index to !emitc.size_t
140+
// CHECK: %c0_i32 = arith.constant 0 : i32
141+
// CHECK: %3 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
142+
// CHECK: emitc.assign %c0_i32 : i32 to %3 : i32
143+
// CHECK: emitc.for %arg3 = %2 to %1 step %0 {
144+
// CHECK: emitc.assign %3 : i32 to %3 : i32
145+
// CHECK: }
146+
// CHECK: return %3 : i32
147+
// CHECK: }

mlir/test/Conversion/SCFToEmitC/if.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,29 @@ func.func @test_if_yield(%arg0: i1, %arg1: f32) {
6868
// CHECK-NEXT: }
6969
// CHECK-NEXT: return
7070
// CHECK-NEXT: }
71+
72+
73+
func.func @test_if_yield_index(%arg0: i1, %arg1: f32) {
74+
%0 = arith.constant 0 : index
75+
%1 = arith.constant 0 : index
76+
%x = scf.if %arg0 -> (index) {
77+
scf.yield %0 : index
78+
} else {
79+
scf.yield %1 : index
80+
}
81+
return
82+
}
83+
84+
// CHECK:func.func @test_if_yield_index(%arg0: i1, %arg1: f32) {
85+
// CHECK: %c0 = arith.constant 0 : index
86+
// CHECK: %c0_0 = arith.constant 0 : index
87+
// CHECK: %0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t
88+
// CHECK: emitc.if %arg0 {
89+
// CHECK: %1 = builtin.unrealized_conversion_cast %c0 : index to !emitc.size_t
90+
// CHECK: emitc.assign %1 : !emitc.size_t to %0 : !emitc.size_t
91+
// CHECK: } else {
92+
// CHECK: %1 = builtin.unrealized_conversion_cast %c0_0 : index to !emitc.size_t
93+
// CHECK: emitc.assign %1 : !emitc.size_t to %0 : !emitc.size_t
94+
// CHECK: }
95+
// CHECK: return
96+
// CHECK: }

0 commit comments

Comments
 (0)