Skip to content

Commit 0fd0c2c

Browse files
do not build argument materializations anymore
1 parent a44f0b8 commit 0fd0c2c

File tree

6 files changed

+574
-342
lines changed

6 files changed

+574
-342
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -583,8 +583,8 @@ class ConversionPattern : public RewritePattern {
583583
: RewritePattern(std::forward<Args>(args)...),
584584
typeConverter(&typeConverter) {}
585585

586-
static SmallVector<Value>
587-
getOneToOneAdaptorOperands(ArrayRef<ArrayRef<Value>> operands);
586+
SmallVector<Value>
587+
getOneToOneAdaptorOperands(ArrayRef<ArrayRef<Value>> operands) const;
588588

589589
protected:
590590
/// An optional type converter for use by this pattern.
@@ -858,6 +858,9 @@ class ConversionPatternRewriter final : public PatternRewriter {
858858
/// PatternRewriter hook for replacing an operation.
859859
void replaceOp(Operation *op, Operation *newOp) override;
860860

861+
/// PatternRewriter hook for replacing an operation.
862+
void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
863+
861864
/// PatternRewriter hook for erasing a dead operation. The uses of this
862865
/// operation *must* be made dead by the end of the conversion process,
863866
/// otherwise an assert will be issued.

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153153
type.isVarArg());
154154
});
155155

156+
/*
156157
// Argument materializations convert from the new block argument types
157158
// (multiple SSA values that make up a memref descriptor) back to the
158159
// original block argument type. The dialect conversion framework will then
@@ -198,16 +199,62 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
198199
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
199200
.getResult(0);
200201
});
202+
203+
*/
201204
// Add generic source and target materializations to handle cases where
202205
// non-LLVM types persist after an LLVM conversion.
203206
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
204207
ValueRange inputs, Location loc) {
205-
if (inputs.size() != 1)
206-
return Value();
208+
//if (inputs.size() != 1)
209+
// return Value();
207210

208211
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
209212
.getResult(0);
210213
});
214+
addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
215+
ValueRange inputs, Location loc) {
216+
if (inputs.size()== 1 && isa<LLVM::LLVMStructType>(inputs.front().getType())) return Value();
217+
218+
Value desc;
219+
if (inputs.size() == 1 && isa<LLVM::LLVMPointerType>(inputs.front().getType())) {
220+
// This is a bare pointer. We allow bare pointers only for function entry
221+
// blocks.
222+
BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
223+
if (!barePtr)
224+
return Value();
225+
Block *block = barePtr.getOwner();
226+
if (!block->isEntryBlock() ||
227+
!isa<FunctionOpInterface>(block->getParentOp()))
228+
return Value();
229+
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
230+
inputs[0]);
231+
} else {
232+
//llvm::errs() << "pack elems: " << inputs.size() << "\n";
233+
//llvm::errs() << inputs[0] << "\n";
234+
desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
235+
//llvm::errs() << "done packing\n";
236+
}
237+
// An argument materialization must return a value of type `resultType`,
238+
// so insert a cast from the memref descriptor type (!llvm.struct) to the
239+
// original memref type.
240+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
241+
.getResult(0);
242+
});
243+
addSourceMaterialization([&](OpBuilder &builder, UnrankedMemRefType resultType,
244+
ValueRange inputs, Location loc) {
245+
if (inputs.size() == 1) {
246+
// Bare pointers are not supported for unranked memrefs because a
247+
// memref descriptor cannot be built just from a bare pointer.
248+
return Value();
249+
}
250+
Value desc =
251+
UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
252+
// An argument materialization must return a value of type
253+
// `resultType`, so insert a cast from the memref descriptor type
254+
// (!llvm.struct) to the original memref type.
255+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
256+
.getResult(0);
257+
});
211258
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
212259
ValueRange inputs, Location loc) {
213260
if (inputs.size() != 1)
@@ -216,6 +263,51 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
216263
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
217264
.getResult(0);
218265
});
266+
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
267+
ValueRange inputs,
268+
Location loc, Type originalType) -> Value {
269+
llvm::errs() << "TARGET MAT: -> " << resultType << "\n";
270+
if (!originalType) {
271+
llvm::errs() << " -- no orig\n";
272+
return Value();
273+
}
274+
if (auto memrefType = dyn_cast<MemRefType>(originalType)) {
275+
assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type");
276+
if (inputs.size() == 1) {
277+
Value input = inputs.front();
278+
if (auto castOp =input.getDefiningOp<UnrealizedConversionCastOp>()) {
279+
if (castOp.getInputs().size() == 1 && isa<LLVM::LLVMPointerType>(castOp.getInputs()[0].getType())) {
280+
input = castOp.getInputs()[0];
281+
}
282+
}
283+
if (!isa<LLVM::LLVMPointerType>(input.getType()))
284+
return Value();
285+
BlockArgument barePtr = dyn_cast<BlockArgument>(input);
286+
if (!barePtr)
287+
return Value();
288+
Block *block = barePtr.getOwner();
289+
if (!block->isEntryBlock() ||
290+
!isa<FunctionOpInterface>(block->getParentOp()))
291+
return Value();
292+
// Bare ptr
293+
return MemRefDescriptor::fromStaticShape(builder, loc, *this, memrefType,
294+
input);
295+
}
296+
return MemRefDescriptor::pack(builder, loc, *this, memrefType, inputs);
297+
}
298+
if (auto memrefType = dyn_cast<UnrankedMemRefType>(originalType)) {
299+
assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type");
300+
if (inputs.size() == 1) {
301+
// Bare pointers are not supported for unranked memrefs because a
302+
// memref descriptor cannot be built just from a bare pointer.
303+
return Value();
304+
}
305+
return UnrankedMemRefDescriptor::pack(builder, loc, *this,
306+
memrefType, inputs);
307+
}
308+
309+
return Value();
310+
});
219311

220312
// Integer memory spaces map to themselves.
221313
addTypeAttributeConversion(

mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp

Lines changed: 38 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,16 @@ using namespace mlir::scf;
1616

1717
namespace {
1818

19-
// Unpacks the single unrealized_conversion_cast using the list of inputs
20-
// e.g., return [%b, %c, %d] for %a = unrealized_conversion_cast(%b, %c, %d)
21-
static void unpackUnrealizedConversionCast(Value v,
22-
SmallVectorImpl<Value> &unpacked) {
23-
if (auto cast =
24-
dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp())) {
25-
if (cast.getInputs().size() != 1) {
26-
// 1 : N type conversion.
27-
unpacked.append(cast.getInputs().begin(), cast.getInputs().end());
28-
return;
29-
}
30-
}
31-
// 1 : 1 type conversion.
32-
unpacked.push_back(v);
19+
static SmallVector<Value> flattenValues(ArrayRef<ArrayRef<Value>> values) {
20+
SmallVector<Value> result;
21+
for (ArrayRef<Value> v : values)
22+
llvm::append_range(result, v);
23+
return result;
24+
}
25+
26+
static Value getSingleValue(ArrayRef<Value> values) {
27+
assert(values.size() == 1 && "expected single value");
28+
return values.front();
3329
}
3430

3531
// CRTP
@@ -40,19 +36,21 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
4036
public:
4137
using OpConversionPattern<SourceOp>::typeConverter;
4238
using OpConversionPattern<SourceOp>::OpConversionPattern;
43-
using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor;
39+
using OneToNOpAdaptor =
40+
typename OpConversionPattern<SourceOp>::OneToNOpAdaptor;
4441

4542
//
4643
// Derived classes should provide the following method which performs the
4744
// actual conversion. It should return std::nullopt upon conversion failure
4845
// and return the converted operation upon success.
4946
//
50-
// std::optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor,
51-
// ConversionPatternRewriter &rewriter,
52-
// TypeRange dstTypes) const;
47+
// std::optional<SourceOp> convertSourceOp(
48+
// SourceOp op, OneToNOpAdaptor adaptor,
49+
// ConversionPatternRewriter &rewriter,
50+
// TypeRange dstTypes) const;
5351

5452
LogicalResult
55-
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
53+
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
5654
ConversionPatternRewriter &rewriter) const override {
5755
SmallVector<Type> dstTypes;
5856
SmallVector<unsigned> offsets;
@@ -73,28 +71,15 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
7371
return rewriter.notifyMatchFailure(op, "could not convert operation");
7472

7573
// Packs the return value.
76-
SmallVector<Value> packedRets;
74+
SmallVector<ValueRange> packedRets;
7775
for (unsigned i = 1, e = offsets.size(); i < e; i++) {
7876
unsigned start = offsets[i - 1], end = offsets[i];
7977
unsigned len = end - start;
8078
ValueRange mappedValue = newOp->getResults().slice(start, len);
81-
if (len != 1) {
82-
// 1 : N type conversion.
83-
Type origType = op.getResultTypes()[i - 1];
84-
Value mat = typeConverter->materializeSourceConversion(
85-
rewriter, op.getLoc(), origType, mappedValue);
86-
if (!mat) {
87-
return rewriter.notifyMatchFailure(
88-
op, "Failed to materialize 1:N type conversion");
89-
}
90-
packedRets.push_back(mat);
91-
} else {
92-
// 1 : 1 type conversion.
93-
packedRets.push_back(mappedValue.front());
94-
}
79+
packedRets.push_back(mappedValue);
9580
}
9681

97-
rewriter.replaceOp(op, packedRets);
82+
rewriter.replaceOpWithMultiple(op, packedRets);
9883
return success();
9984
}
10085
};
@@ -105,7 +90,7 @@ class ConvertForOpTypes
10590
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
10691

10792
// The callback required by CRTP.
108-
std::optional<ForOp> convertSourceOp(ForOp op, OpAdaptor adaptor,
93+
std::optional<ForOp> convertSourceOp(ForOp op, OneToNOpAdaptor adaptor,
10994
ConversionPatternRewriter &rewriter,
11095
TypeRange dstTypes) const {
11196
// Create a empty new op and inline the regions from the old op.
@@ -129,16 +114,13 @@ class ConvertForOpTypes
129114
if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter)))
130115
return std::nullopt;
131116

132-
// Unpacked the iteration arguments.
133-
SmallVector<Value> flatArgs;
134-
for (Value arg : adaptor.getInitArgs())
135-
unpackUnrealizedConversionCast(arg, flatArgs);
136-
137117
// We can not do clone as the number of result types after conversion
138118
// might be different.
139-
ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(),
140-
adaptor.getUpperBound(),
141-
adaptor.getStep(), flatArgs);
119+
ForOp newOp = rewriter.create<ForOp>(
120+
op.getLoc(), getSingleValue(adaptor.getLowerBound()),
121+
getSingleValue(adaptor.getUpperBound()),
122+
getSingleValue(adaptor.getStep()),
123+
flattenValues(adaptor.getInitArgs()));
142124

143125
// Reserve whatever attributes in the original op.
144126
newOp->setAttrs(op->getAttrs());
@@ -160,12 +142,12 @@ class ConvertIfOpTypes
160142
public:
161143
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
162144

163-
std::optional<IfOp> convertSourceOp(IfOp op, OpAdaptor adaptor,
145+
std::optional<IfOp> convertSourceOp(IfOp op, OneToNOpAdaptor adaptor,
164146
ConversionPatternRewriter &rewriter,
165147
TypeRange dstTypes) const {
166148

167-
IfOp newOp = rewriter.create<IfOp>(op.getLoc(), dstTypes,
168-
adaptor.getCondition(), true);
149+
IfOp newOp = rewriter.create<IfOp>(
150+
op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true);
169151
newOp->setAttrs(op->getAttrs());
170152

171153
// We do not need the empty blocks created by rewriter.
@@ -189,15 +171,11 @@ class ConvertWhileOpTypes
189171
public:
190172
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
191173

192-
std::optional<WhileOp> convertSourceOp(WhileOp op, OpAdaptor adaptor,
174+
std::optional<WhileOp> convertSourceOp(WhileOp op, OneToNOpAdaptor adaptor,
193175
ConversionPatternRewriter &rewriter,
194176
TypeRange dstTypes) const {
195-
// Unpacked the iteration arguments.
196-
SmallVector<Value> flatArgs;
197-
for (Value arg : adaptor.getOperands())
198-
unpackUnrealizedConversionCast(arg, flatArgs);
199-
200-
auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, flatArgs);
177+
auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes,
178+
flattenValues(adaptor.getOperands()));
201179

202180
for (auto i : {0u, 1u}) {
203181
if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter)))
@@ -218,13 +196,10 @@ class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
218196
public:
219197
using OpConversionPattern::OpConversionPattern;
220198
LogicalResult
221-
matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
199+
matchAndRewrite(scf::YieldOp op, OneToNOpAdaptor adaptor,
222200
ConversionPatternRewriter &rewriter) const override {
223-
SmallVector<Value> unpackedYield;
224-
for (Value operand : adaptor.getOperands())
225-
unpackUnrealizedConversionCast(operand, unpackedYield);
226-
227-
rewriter.replaceOpWithNewOp<scf::YieldOp>(op, unpackedYield);
201+
rewriter.replaceOpWithNewOp<scf::YieldOp>(
202+
op, flattenValues(adaptor.getOperands()));
228203
return success();
229204
}
230205
};
@@ -235,13 +210,10 @@ class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
235210
public:
236211
using OpConversionPattern<ConditionOp>::OpConversionPattern;
237212
LogicalResult
238-
matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
213+
matchAndRewrite(ConditionOp op, OneToNOpAdaptor adaptor,
239214
ConversionPatternRewriter &rewriter) const override {
240-
SmallVector<Value> unpackedYield;
241-
for (Value operand : adaptor.getOperands())
242-
unpackUnrealizedConversionCast(operand, unpackedYield);
243-
244-
rewriter.modifyOpInPlace(op, [&]() { op->setOperands(unpackedYield); });
215+
rewriter.modifyOpInPlace(
216+
op, [&]() { op->setOperands(flattenValues(adaptor.getOperands())); });
245217
return success();
246218
}
247219
};

0 commit comments

Comments
 (0)