Skip to content

Commit 820769d

Browse files
do not build argument materializations anymore
1 parent 23e9da2 commit 820769d

File tree

15 files changed

+599
-358
lines changed

15 files changed

+599
-358
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -549,8 +549,8 @@ class ConversionPattern : public RewritePattern {
549549
: RewritePattern(std::forward<Args>(args)...),
550550
typeConverter(&typeConverter) {}
551551

552-
static SmallVector<Value>
553-
getOneToOneAdaptorOperands(ArrayRef<ArrayRef<Value>> operands);
552+
SmallVector<Value>
553+
getOneToOneAdaptorOperands(ArrayRef<ArrayRef<Value>> operands) const;
554554

555555
protected:
556556
/// An optional type converter for use by this pattern.
@@ -824,6 +824,9 @@ class ConversionPatternRewriter final : public PatternRewriter {
824824
/// PatternRewriter hook for replacing an operation.
825825
void replaceOp(Operation *op, Operation *newOp) override;
826826

827+
/// PatternRewriter hook for replacing an operation.
828+
void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
829+
827830
/// PatternRewriter hook for erasing a dead operation. The uses of this
828831
/// operation *must* be made dead by the end of the conversion process,
829832
/// otherwise an assert will be issued.

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153153
type.isVarArg());
154154
});
155155

156+
/*
156157
// Argument materializations convert from the new block argument types
157158
// (multiple SSA values that make up a memref descriptor) back to the
158159
// original block argument type. The dialect conversion framework will then
@@ -199,26 +200,67 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
199200
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
200201
.getResult(0);
201202
});
203+
204+
*/
202205
// Add generic source and target materializations to handle cases where
203206
// non-LLVM types persist after an LLVM conversion.
204207
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
205208
ValueRange inputs,
206209
Location loc) -> std::optional<Value> {
207-
if (inputs.size() != 1)
208-
return std::nullopt;
209-
210210
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
211211
.getResult(0);
212212
});
213213
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
214214
ValueRange inputs,
215-
Location loc) -> std::optional<Value> {
216-
if (inputs.size() != 1)
217-
return std::nullopt;
218-
215+
Location loc, Type originalType) -> std::optional<Value> {
219216
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
220217
.getResult(0);
221218
});
219+
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
220+
ValueRange inputs,
221+
Location loc, Type originalType) -> std::optional<Value> {
222+
llvm::errs() << "TARGET MAT: -> " << resultType << "\n";
223+
if (!originalType) {
224+
llvm::errs() << " -- no orig\n";
225+
return std::nullopt;
226+
}
227+
if (auto memrefType = dyn_cast<MemRefType>(originalType)) {
228+
assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type");
229+
if (inputs.size() == 1) {
230+
Value input = inputs.front();
231+
if (auto castOp =input.getDefiningOp<UnrealizedConversionCastOp>()) {
232+
if (castOp.getInputs().size() == 1 && isa<LLVM::LLVMPointerType>(castOp.getInputs()[0].getType())) {
233+
input = castOp.getInputs()[0];
234+
}
235+
}
236+
if (!isa<LLVM::LLVMPointerType>(input.getType()))
237+
return std::nullopt;
238+
BlockArgument barePtr = dyn_cast<BlockArgument>(input);
239+
if (!barePtr)
240+
return std::nullopt;
241+
Block *block = barePtr.getOwner();
242+
if (!block->isEntryBlock() ||
243+
!isa<FunctionOpInterface>(block->getParentOp()))
244+
return std::nullopt;
245+
// Bare ptr
246+
return MemRefDescriptor::fromStaticShape(builder, loc, *this, memrefType,
247+
input);
248+
}
249+
return MemRefDescriptor::pack(builder, loc, *this, memrefType, inputs);
250+
}
251+
if (auto memrefType = dyn_cast<UnrankedMemRefType>(originalType)) {
252+
assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type");
253+
if (inputs.size() == 1) {
254+
// Bare pointers are not supported for unranked memrefs because a
255+
// memref descriptor cannot be built just from a bare pointer.
256+
return std::nullopt;
257+
}
258+
return UnrankedMemRefDescriptor::pack(builder, loc, *this,
259+
memrefType, inputs);
260+
}
261+
262+
return std::nullopt;
263+
});
222264

223265
// Integer memory spaces map to themselves.
224266
addTypeAttributeConversion(

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1664,6 +1664,80 @@ class ExtractStridedMetadataOpLowering
16641664
}
16651665
};
16661666

1667+
struct UnrealizedConversionCastOpLowering
1668+
: public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> {
1669+
public:
1670+
using ConvertOpToLLVMPattern<
1671+
UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
1672+
1673+
static bool canLower(const LLVMTypeConverter &converter, BaseMemRefType type, ValueRange inputs) {
1674+
if (inputs.size() == 1 && isa<LLVM::LLVMPointerType>(inputs.front().getType())) {
1675+
// Bare pointer.
1676+
return true;
1677+
}
1678+
1679+
Type converted = converter.convertType(type);
1680+
if (!converted) return false;
1681+
1682+
auto llvmTy = cast<LLVM::LLVMStructType>(converted);
1683+
SmallVector<Type> flattened;
1684+
std::function<void(Type)> flattenType = [&](Type t) {
1685+
if (auto structTy = dyn_cast<LLVM::LLVMStructType>(t)) {
1686+
for (Type t : structTy.getBody())
1687+
flattenType(t);
1688+
} else if (auto arrayTy = dyn_cast<LLVM::LLVMArrayType>(t)) {
1689+
for (uint64_t i =0, e=arrayTy.getNumElements(); i < e; ++i) {
1690+
flattenType(arrayTy.getElementType());
1691+
}
1692+
} else {
1693+
flattened.push_back(t);
1694+
}
1695+
};
1696+
1697+
flattenType(converted);
1698+
return TypeRange(flattened) == TypeRange(inputs);
1699+
}
1700+
1701+
LogicalResult
1702+
matchAndRewrite(UnrealizedConversionCastOp op,
1703+
OpAdaptor adaptor,
1704+
ConversionPatternRewriter &rewriter) const override {
1705+
if (op->getNumResults() != 1)
1706+
return failure();
1707+
1708+
op.dump();
1709+
Location loc = op.getLoc();
1710+
Value result = op->getResult(0);
1711+
1712+
if (auto rankedMemrefType = dyn_cast<MemRefType>(result.getType())){
1713+
if (op.getInputs().size() == 1 && isa<LLVM::LLVMPointerType>(op.getInputs().front().getType())) {
1714+
// Converting a bare pointer.
1715+
Value repl = MemRefDescriptor::fromStaticShape(rewriter, loc, *getTypeConverter(), rankedMemrefType,
1716+
op.getInputs().front());
1717+
rewriter.replaceOp(op, repl);
1718+
return success();
1719+
}
1720+
1721+
// Converting memref descriptor elements.
1722+
// TODO: Check types.
1723+
Value repl = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(), rankedMemrefType, op.getInputs());
1724+
rewriter.replaceOp(op, repl);
1725+
return success();
1726+
} else if(auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(result.getType())) {
1727+
if (op.getInputs().size() == 1 && isa<LLVM::LLVMPointerType>(op.getInputs().front().getType())) {
1728+
return failure();
1729+
}
1730+
1731+
// TODO: Check types.
1732+
Value repl = UnrankedMemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
1733+
unrankedMemrefType, op.getInputs());
1734+
rewriter.replaceOp(op, repl);
1735+
return success();
1736+
}
1737+
1738+
return failure();
1739+
}
1740+
};
16671741
} // namespace
16681742

16691743
void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
@@ -1693,6 +1767,7 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
16931767
StoreOpLowering,
16941768
SubViewOpLowering,
16951769
TransposeOpLowering,
1770+
UnrealizedConversionCastOpLowering,
16961771
ViewOpLowering>(converter);
16971772
// clang-format on
16981773
auto allocLowering = converter.getOptions().allocLowering;
@@ -1728,6 +1803,12 @@ struct FinalizeMemRefToLLVMConversionPass
17281803
RewritePatternSet patterns(&getContext());
17291804
populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
17301805
LLVMConversionTarget target(getContext());
1806+
target.addDynamicallyLegalOp<UnrealizedConversionCastOp>([&](UnrealizedConversionCastOp op) {
1807+
// TODO: Better checking
1808+
auto memrefType =dyn_cast<BaseMemRefType>(op->getResult(0).getType());
1809+
bool canConvert = op->getNumResults() == 1 && memrefType && UnrealizedConversionCastOpLowering::canLower(typeConverter, memrefType, op.getInputs());
1810+
return !canConvert;
1811+
});
17311812
target.addLegalOp<func::FuncOp>();
17321813
if (failed(applyPartialConversion(op, target, std::move(patterns))))
17331814
signalPassFailure();

mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp

Lines changed: 38 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,16 @@ using namespace mlir::scf;
1616

1717
namespace {
1818

19-
// Unpacks the single unrealized_conversion_cast using the list of inputs
20-
// e.g., return [%b, %c, %d] for %a = unrealized_conversion_cast(%b, %c, %d)
21-
static void unpackUnrealizedConversionCast(Value v,
22-
SmallVectorImpl<Value> &unpacked) {
23-
if (auto cast =
24-
dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp())) {
25-
if (cast.getInputs().size() != 1) {
26-
// 1 : N type conversion.
27-
unpacked.append(cast.getInputs().begin(), cast.getInputs().end());
28-
return;
29-
}
30-
}
31-
// 1 : 1 type conversion.
32-
unpacked.push_back(v);
19+
static SmallVector<Value> flattenValues(ArrayRef<ArrayRef<Value>> values) {
20+
SmallVector<Value> result;
21+
for (ArrayRef<Value> v : values)
22+
llvm::append_range(result, v);
23+
return result;
24+
}
25+
26+
static Value getSingleValue(ArrayRef<Value> values) {
27+
assert(values.size() == 1 && "expected single value");
28+
return values.front();
3329
}
3430

3531
// CRTP
@@ -40,19 +36,21 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
4036
public:
4137
using OpConversionPattern<SourceOp>::typeConverter;
4238
using OpConversionPattern<SourceOp>::OpConversionPattern;
43-
using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor;
39+
using OneToNOpAdaptor =
40+
typename OpConversionPattern<SourceOp>::OneToNOpAdaptor;
4441

4542
//
4643
// Derived classes should provide the following method which performs the
4744
// actual conversion. It should return std::nullopt upon conversion failure
4845
// and return the converted operation upon success.
4946
//
50-
// std::optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor,
51-
// ConversionPatternRewriter &rewriter,
52-
// TypeRange dstTypes) const;
47+
// std::optional<SourceOp> convertSourceOp(
48+
// SourceOp op, OneToNOpAdaptor adaptor,
49+
// ConversionPatternRewriter &rewriter,
50+
// TypeRange dstTypes) const;
5351

5452
LogicalResult
55-
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
53+
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
5654
ConversionPatternRewriter &rewriter) const override {
5755
SmallVector<Type> dstTypes;
5856
SmallVector<unsigned> offsets;
@@ -73,28 +71,15 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
7371
return rewriter.notifyMatchFailure(op, "could not convert operation");
7472

7573
// Packs the return value.
76-
SmallVector<Value> packedRets;
74+
SmallVector<ValueRange> packedRets;
7775
for (unsigned i = 1, e = offsets.size(); i < e; i++) {
7876
unsigned start = offsets[i - 1], end = offsets[i];
7977
unsigned len = end - start;
8078
ValueRange mappedValue = newOp->getResults().slice(start, len);
81-
if (len != 1) {
82-
// 1 : N type conversion.
83-
Type origType = op.getResultTypes()[i - 1];
84-
Value mat = typeConverter->materializeSourceConversion(
85-
rewriter, op.getLoc(), origType, mappedValue);
86-
if (!mat) {
87-
return rewriter.notifyMatchFailure(
88-
op, "Failed to materialize 1:N type conversion");
89-
}
90-
packedRets.push_back(mat);
91-
} else {
92-
// 1 : 1 type conversion.
93-
packedRets.push_back(mappedValue.front());
94-
}
79+
packedRets.push_back(mappedValue);
9580
}
9681

97-
rewriter.replaceOp(op, packedRets);
82+
rewriter.replaceOpWithMultiple(op, packedRets);
9883
return success();
9984
}
10085
};
@@ -105,7 +90,7 @@ class ConvertForOpTypes
10590
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
10691

10792
// The callback required by CRTP.
108-
std::optional<ForOp> convertSourceOp(ForOp op, OpAdaptor adaptor,
93+
std::optional<ForOp> convertSourceOp(ForOp op, OneToNOpAdaptor adaptor,
10994
ConversionPatternRewriter &rewriter,
11095
TypeRange dstTypes) const {
11196
// Create a empty new op and inline the regions from the old op.
@@ -129,16 +114,13 @@ class ConvertForOpTypes
129114
if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter)))
130115
return std::nullopt;
131116

132-
// Unpacked the iteration arguments.
133-
SmallVector<Value> flatArgs;
134-
for (Value arg : adaptor.getInitArgs())
135-
unpackUnrealizedConversionCast(arg, flatArgs);
136-
137117
// We can not do clone as the number of result types after conversion
138118
// might be different.
139-
ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(),
140-
adaptor.getUpperBound(),
141-
adaptor.getStep(), flatArgs);
119+
ForOp newOp = rewriter.create<ForOp>(
120+
op.getLoc(), getSingleValue(adaptor.getLowerBound()),
121+
getSingleValue(adaptor.getUpperBound()),
122+
getSingleValue(adaptor.getStep()),
123+
flattenValues(adaptor.getInitArgs()));
142124

143125
// Reserve whatever attributes in the original op.
144126
newOp->setAttrs(op->getAttrs());
@@ -160,12 +142,12 @@ class ConvertIfOpTypes
160142
public:
161143
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
162144

163-
std::optional<IfOp> convertSourceOp(IfOp op, OpAdaptor adaptor,
145+
std::optional<IfOp> convertSourceOp(IfOp op, OneToNOpAdaptor adaptor,
164146
ConversionPatternRewriter &rewriter,
165147
TypeRange dstTypes) const {
166148

167-
IfOp newOp = rewriter.create<IfOp>(op.getLoc(), dstTypes,
168-
adaptor.getCondition(), true);
149+
IfOp newOp = rewriter.create<IfOp>(
150+
op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true);
169151
newOp->setAttrs(op->getAttrs());
170152

171153
// We do not need the empty blocks created by rewriter.
@@ -189,15 +171,11 @@ class ConvertWhileOpTypes
189171
public:
190172
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
191173

192-
std::optional<WhileOp> convertSourceOp(WhileOp op, OpAdaptor adaptor,
174+
std::optional<WhileOp> convertSourceOp(WhileOp op, OneToNOpAdaptor adaptor,
193175
ConversionPatternRewriter &rewriter,
194176
TypeRange dstTypes) const {
195-
// Unpacked the iteration arguments.
196-
SmallVector<Value> flatArgs;
197-
for (Value arg : adaptor.getOperands())
198-
unpackUnrealizedConversionCast(arg, flatArgs);
199-
200-
auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, flatArgs);
177+
auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes,
178+
flattenValues(adaptor.getOperands()));
201179

202180
for (auto i : {0u, 1u}) {
203181
if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter)))
@@ -218,13 +196,10 @@ class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
218196
public:
219197
using OpConversionPattern::OpConversionPattern;
220198
LogicalResult
221-
matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
199+
matchAndRewrite(scf::YieldOp op, OneToNOpAdaptor adaptor,
222200
ConversionPatternRewriter &rewriter) const override {
223-
SmallVector<Value> unpackedYield;
224-
for (Value operand : adaptor.getOperands())
225-
unpackUnrealizedConversionCast(operand, unpackedYield);
226-
227-
rewriter.replaceOpWithNewOp<scf::YieldOp>(op, unpackedYield);
201+
rewriter.replaceOpWithNewOp<scf::YieldOp>(
202+
op, flattenValues(adaptor.getOperands()));
228203
return success();
229204
}
230205
};
@@ -235,13 +210,10 @@ class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
235210
public:
236211
using OpConversionPattern<ConditionOp>::OpConversionPattern;
237212
LogicalResult
238-
matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
213+
matchAndRewrite(ConditionOp op, OneToNOpAdaptor adaptor,
239214
ConversionPatternRewriter &rewriter) const override {
240-
SmallVector<Value> unpackedYield;
241-
for (Value operand : adaptor.getOperands())
242-
unpackUnrealizedConversionCast(operand, unpackedYield);
243-
244-
rewriter.modifyOpInPlace(op, [&]() { op->setOperands(unpackedYield); });
215+
rewriter.modifyOpInPlace(
216+
op, [&]() { op->setOperands(flattenValues(adaptor.getOperands())); });
245217
return success();
246218
}
247219
};

0 commit comments

Comments
 (0)