Skip to content

Commit 0c57792

Browse files
Fix decompose call graph test
1 parent aa0196a commit 0c57792

File tree

5 files changed

+73
-218
lines changed

5 files changed

+73
-218
lines changed

mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -23,70 +23,10 @@
2323

2424
namespace mlir {
2525

26-
/// This class provides a hook that expands one Value into multiple Value's,
27-
/// with a TypeConverter-inspired callback registration mechanism.
28-
///
29-
/// For folks that are familiar with the dialect conversion framework /
30-
/// TypeConverter, this is effectively the inverse of a source/argument
31-
/// materialization. A target materialization is not what we want here because
32-
/// it always produces a single Value, but in this case the whole point is to
33-
/// decompose a Value into multiple Value's.
34-
///
35-
/// The reason we need this inverse is easily understood by looking at what we
36-
/// need to do for decomposing types for a return op. When converting a return
37-
/// op, the dialect conversion framework will give the list of converted
38-
/// operands, and will ensure that each converted operand, even if it expanded
39-
/// into multiple types, is materialized as a single result. We then need to
40-
/// undo that materialization to a single result, which we do with the
41-
/// decomposeValue hooks registered on this object.
42-
///
43-
/// TODO: Eventually, the type conversion infra should have this hook built-in.
44-
/// See
45-
/// https://llvm.discourse.group/t/extending-type-conversion-infrastructure/779/2
46-
class ValueDecomposer {
47-
public:
48-
/// This method tries to decompose a value of a certain type using provided
49-
/// decompose callback functions. If it is unable to do so, the original value
50-
/// is returned.
51-
void decomposeValue(OpBuilder &, Location, Type, Value,
52-
SmallVectorImpl<Value> &);
53-
54-
/// This method registers a callback function that will be called to decompose
55-
/// a value of a certain type into 0, 1, or multiple values.
56-
template <typename FnT, typename T = typename llvm::function_traits<
57-
std::decay_t<FnT>>::template arg_t<2>>
58-
void addDecomposeValueConversion(FnT &&callback) {
59-
decomposeValueConversions.emplace_back(
60-
wrapDecomposeValueConversionCallback<T>(std::forward<FnT>(callback)));
61-
}
62-
63-
private:
64-
using DecomposeValueConversionCallFn =
65-
std::function<std::optional<LogicalResult>(
66-
OpBuilder &, Location, Type, Value, SmallVectorImpl<Value> &)>;
67-
68-
/// Generate a wrapper for the given decompose value conversion callback.
69-
template <typename T, typename FnT>
70-
DecomposeValueConversionCallFn
71-
wrapDecomposeValueConversionCallback(FnT &&callback) {
72-
return
73-
[callback = std::forward<FnT>(callback)](
74-
OpBuilder &builder, Location loc, Type type, Value value,
75-
SmallVectorImpl<Value> &newValues) -> std::optional<LogicalResult> {
76-
if (T derivedType = dyn_cast<T>(type))
77-
return callback(builder, loc, derivedType, value, newValues);
78-
return std::nullopt;
79-
};
80-
}
81-
82-
SmallVector<DecomposeValueConversionCallFn, 2> decomposeValueConversions;
83-
};
84-
8526
/// Populates the patterns needed to drive the conversion process for
8627
/// decomposing call graph types with the given `ValueDecomposer`.
8728
void populateDecomposeCallGraphTypesPatterns(MLIRContext *context,
8829
const TypeConverter &typeConverter,
89-
ValueDecomposer &decomposer,
9030
RewritePatternSet &patterns);
9131

9232
} // namespace mlir

mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp

Lines changed: 30 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -13,53 +13,15 @@
1313
using namespace mlir;
1414
using namespace mlir::func;
1515

16-
//===----------------------------------------------------------------------===//
17-
// ValueDecomposer
18-
//===----------------------------------------------------------------------===//
19-
20-
void ValueDecomposer::decomposeValue(OpBuilder &builder, Location loc,
21-
Type type, Value value,
22-
SmallVectorImpl<Value> &results) {
23-
for (auto &conversion : decomposeValueConversions)
24-
if (conversion(builder, loc, type, value, results))
25-
return;
26-
results.push_back(value);
27-
}
28-
29-
//===----------------------------------------------------------------------===//
30-
// DecomposeCallGraphTypesOpConversionPattern
31-
//===----------------------------------------------------------------------===//
32-
33-
namespace {
34-
/// Base OpConversionPattern class to make a ValueDecomposer available to
35-
/// inherited patterns.
36-
template <typename SourceOp>
37-
class DecomposeCallGraphTypesOpConversionPattern
38-
: public OpConversionPattern<SourceOp> {
39-
public:
40-
DecomposeCallGraphTypesOpConversionPattern(const TypeConverter &typeConverter,
41-
MLIRContext *context,
42-
ValueDecomposer &decomposer,
43-
PatternBenefit benefit = 1)
44-
: OpConversionPattern<SourceOp>(typeConverter, context, benefit),
45-
decomposer(decomposer) {}
46-
47-
protected:
48-
ValueDecomposer &decomposer;
49-
};
50-
} // namespace
51-
5216
//===----------------------------------------------------------------------===//
5317
// DecomposeCallGraphTypesForFuncArgs
5418
//===----------------------------------------------------------------------===//
5519

5620
namespace {
57-
/// Expand function arguments according to the provided TypeConverter and
58-
/// ValueDecomposer.
21+
/// Expand function arguments according to the provided TypeConverter.
5922
struct DecomposeCallGraphTypesForFuncArgs
60-
: public DecomposeCallGraphTypesOpConversionPattern<func::FuncOp> {
61-
using DecomposeCallGraphTypesOpConversionPattern::
62-
DecomposeCallGraphTypesOpConversionPattern;
23+
: public OpConversionPattern<func::FuncOp> {
24+
using OpConversionPattern::OpConversionPattern;
6325

6426
LogicalResult
6527
matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
@@ -100,19 +62,17 @@ struct DecomposeCallGraphTypesForFuncArgs
10062
//===----------------------------------------------------------------------===//
10163

10264
namespace {
103-
/// Expand return operands according to the provided TypeConverter and
104-
/// ValueDecomposer.
65+
/// Expand return operands according to the provided TypeConverter.
10566
struct DecomposeCallGraphTypesForReturnOp
106-
: public DecomposeCallGraphTypesOpConversionPattern<ReturnOp> {
107-
using DecomposeCallGraphTypesOpConversionPattern::
108-
DecomposeCallGraphTypesOpConversionPattern;
67+
: public OpConversionPattern<ReturnOp> {
68+
using OpConversionPattern::OpConversionPattern;
69+
10970
LogicalResult
110-
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
71+
matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
11172
ConversionPatternRewriter &rewriter) const final {
11273
SmallVector<Value, 2> newOperands;
113-
for (Value operand : adaptor.getOperands())
114-
decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
115-
operand, newOperands);
74+
for (ValueRange operand : adaptor.getOperands())
75+
llvm::append_range(newOperands, operand);
11676
rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
11777
return success();
11878
}
@@ -124,74 +84,53 @@ struct DecomposeCallGraphTypesForReturnOp
12484
//===----------------------------------------------------------------------===//
12585

12686
namespace {
127-
/// Expand call op operands and results according to the provided TypeConverter
128-
/// and ValueDecomposer.
87+
/// Expand call op operands and results according to the provided TypeConverter.
12988
struct DecomposeCallGraphTypesForCallOp
130-
: public DecomposeCallGraphTypesOpConversionPattern<CallOp> {
131-
using DecomposeCallGraphTypesOpConversionPattern::
132-
DecomposeCallGraphTypesOpConversionPattern;
89+
: public OpConversionPattern<CallOp> {
90+
using OpConversionPattern::OpConversionPattern;
13391

13492
LogicalResult
135-
matchAndRewrite(CallOp op, OpAdaptor adaptor,
93+
matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor,
13694
ConversionPatternRewriter &rewriter) const final {
13795

13896
// Create the operands list of the new `CallOp`.
13997
SmallVector<Value, 2> newOperands;
140-
for (Value operand : adaptor.getOperands())
141-
decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
142-
operand, newOperands);
143-
144-
// Create the new result types for the new `CallOp` and track the indices in
145-
// the new call op's results that correspond to the old call op's results.
146-
//
147-
// expandedResultIndices[i] = "list of new result indices that old result i
148-
// expanded to".
98+
for (ValueRange operand : adaptor.getOperands())
99+
llvm::append_range(newOperands, operand);
100+
101+
// Create the new result types for the new `CallOp` and track the number of
102+
// replacement types for each original op result.
149103
SmallVector<Type, 2> newResultTypes;
150-
SmallVector<SmallVector<unsigned, 2>, 4> expandedResultIndices;
104+
SmallVector<unsigned> expandedResultSizes;
151105
for (Type resultType : op.getResultTypes()) {
152106
unsigned oldSize = newResultTypes.size();
153107
if (failed(typeConverter->convertType(resultType, newResultTypes)))
154108
return failure();
155-
auto &resultMapping = expandedResultIndices.emplace_back();
156-
for (unsigned i = oldSize, e = newResultTypes.size(); i < e; i++)
157-
resultMapping.push_back(i);
109+
expandedResultSizes.push_back(newResultTypes.size() - oldSize);
158110
}
159111

160112
CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(),
161113
newResultTypes, newOperands);
162114

163-
// Build a replacement value for each result to replace its uses. If a
164-
// result has multiple mapping values, it needs to be materialized as a
165-
// single value.
166-
SmallVector<Value, 2> replacedValues;
115+
// Build a replacement value for each result to replace its uses.
116+
SmallVector<ValueRange> replacedValues;
167117
replacedValues.reserve(op.getNumResults());
118+
unsigned startIdx = 0;
168119
for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
169-
auto decomposedValues = llvm::to_vector<6>(
170-
llvm::map_range(expandedResultIndices[i],
171-
[&](unsigned i) { return newCallOp.getResult(i); }));
172-
if (decomposedValues.empty()) {
173-
// No replacement is required.
174-
replacedValues.push_back(nullptr);
175-
} else if (decomposedValues.size() == 1) {
176-
replacedValues.push_back(decomposedValues.front());
177-
} else {
178-
// Materialize a single Value to replace the original Value.
179-
Value materialized = getTypeConverter()->materializeArgumentConversion(
180-
rewriter, op.getLoc(), op.getType(i), decomposedValues);
181-
replacedValues.push_back(materialized);
182-
}
120+
ValueRange repl = newCallOp.getResults().slice(startIdx, expandedResultSizes[i]);
121+
replacedValues.push_back(repl);
122+
startIdx += expandedResultSizes[i];
183123
}
184-
rewriter.replaceOp(op, replacedValues);
124+
rewriter.replaceOpWithMultiple(op, replacedValues);
185125
return success();
186126
}
187127
};
188128
} // namespace
189129

190130
void mlir::populateDecomposeCallGraphTypesPatterns(
191131
MLIRContext *context, const TypeConverter &typeConverter,
192-
ValueDecomposer &decomposer, RewritePatternSet &patterns) {
132+
RewritePatternSet &patterns) {
193133
patterns
194134
.add<DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
195-
DecomposeCallGraphTypesForReturnOp>(typeConverter, context,
196-
decomposer);
135+
DecomposeCallGraphTypesForReturnOp>(typeConverter, context);
197136
}

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,6 +1277,11 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
12771277
}
12781278

12791279
// Try to find a mapped value with the desired type.
1280+
if (legalTypes.empty()) {
1281+
remapped.push_back({});
1282+
continue;
1283+
}
1284+
12801285
SmallVector<Value, 1> mat = mapping.lookupOrDefault(operand, legalTypes);
12811286
if (!mat.empty()) {
12821287
// Mapped value has the correct type or there is an existing
@@ -2577,34 +2582,29 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
25772582
assert(!op.use_empty() &&
25782583
"expected that dead materializations have already been DCE'd");
25792584
Operation::operand_range inputOperands = op.getOperands();
2580-
Type outputType = op.getResultTypes()[0];
25812585

25822586
// Try to materialize the conversion.
25832587
if (const TypeConverter *converter = rewrite->getConverter()) {
25842588
rewriter.setInsertionPoint(op);
2585-
Value newMaterialization;
2589+
SmallVector<Value> newMaterialization;
25862590
switch (rewrite->getMaterializationKind()) {
25872591
case MaterializationKind::Argument:
2588-
// Try to materialize an argument conversion.
2589-
newMaterialization = converter->materializeArgumentConversion(
2590-
rewriter, op->getLoc(), outputType, inputOperands);
2591-
if (newMaterialization)
2592-
break;
2593-
// If an argument materialization failed, fallback to trying a target
2594-
// materialization.
2595-
[[fallthrough]];
2592+
llvm_unreachable("argument materializations have been removed");
25962593
case MaterializationKind::Target:
25972594
newMaterialization = converter->materializeTargetConversion(
2598-
rewriter, op->getLoc(), outputType, inputOperands,
2595+
rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
25992596
rewrite->getOriginalType());
26002597
break;
26012598
case MaterializationKind::Source:
2602-
newMaterialization = converter->materializeSourceConversion(
2603-
rewriter, op->getLoc(), outputType, inputOperands);
2599+
assert(op.getNumResults() == 1 && "*:N source materializations are not supported");
2600+
Value sourceMat = converter->materializeSourceConversion(
2601+
rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
2602+
if (sourceMat)
2603+
newMaterialization.push_back(sourceMat);
26042604
break;
26052605
}
2606-
if (newMaterialization) {
2607-
assert(newMaterialization.getType() == outputType &&
2606+
if (!newMaterialization.empty()) {
2607+
assert(TypeRange(newMaterialization) == op.getResultTypes() &&
26082608
"materialization callback produced value of incorrect type");
26092609
rewriter.replaceOp(op, newMaterialization);
26102610
return success();
@@ -2614,8 +2614,8 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
26142614
InFlightDiagnostic diag = op->emitError()
26152615
<< "failed to legalize unresolved materialization "
26162616
"from ("
2617-
<< inputOperands.getTypes() << ") to " << outputType
2618-
<< " that remained live after conversion";
2617+
<< inputOperands.getTypes() << ") to (" << op.getResultTypes()
2618+
<< ") that remained live after conversion";
26192619
diag.attachNote(op->getUsers().begin()->getLoc())
26202620
<< "see existing live user here: " << *op->getUsers().begin();
26212621
return failure();

0 commit comments

Comments
 (0)