Skip to content

Commit 76ccdee

Browse files
Fix decompose call graph test
1 parent 204d89f commit 76ccdee

File tree

5 files changed

+45
-134
lines changed

5 files changed

+45
-134
lines changed

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,10 +284,8 @@ static void restoreByValRefArgumentType(
284284
for (const auto &[arg, oldArg, byValRefAttr] :
285285
llvm::zip(funcOp.getArguments(), oldBlockArgs, byValRefNonPtrAttrs)) {
286286
// Skip argument if no `llvm.byval` or `llvm.byref` attribute.
287-
if (!byValRefAttr) {
288-
llvm::errs() << "NO ATTR!\n";
287+
if (!byValRefAttr)
289288
continue;
290-
}
291289

292290
// Insert load to retrieve the actual argument passed by value/reference.
293291
assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&

mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp

Lines changed: 17 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -13,40 +13,6 @@
1313
using namespace mlir;
1414
using namespace mlir::func;
1515

16-
//===----------------------------------------------------------------------===//
17-
// Helper functions
18-
//===----------------------------------------------------------------------===//
19-
20-
/// If the given value can be decomposed with the type converter, decompose it.
21-
/// Otherwise, return the given value.
22-
// TODO: Value decomposition should happen automatically through a 1:N adaptor.
23-
// This function will disappear when the 1:1 and 1:N drivers are merged.
24-
static SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc,
25-
Value value,
26-
const TypeConverter *converter) {
27-
// Try to convert the given value's type. If that fails, just return the
28-
// given value.
29-
SmallVector<Type> convertedTypes;
30-
if (failed(converter->convertType(value.getType(), convertedTypes)))
31-
return {value};
32-
if (convertedTypes.empty())
33-
return {};
34-
35-
// If the given value's type is already legal, just return the given value.
36-
TypeRange convertedTypeRange(convertedTypes);
37-
if (convertedTypeRange == TypeRange(value.getType()))
38-
return {value};
39-
40-
// Try to materialize a target conversion. If the materialization did not
41-
// produce values of the requested type, the materialization failed. Just
42-
// return the given value in that case.
43-
SmallVector<Value> result = converter->materializeTargetConversion(
44-
builder, loc, convertedTypeRange, value);
45-
if (result.empty())
46-
return {value};
47-
return result;
48-
}
49-
5016
//===----------------------------------------------------------------------===//
5117
// DecomposeCallGraphTypesForFuncArgs
5218
//===----------------------------------------------------------------------===//
@@ -102,16 +68,11 @@ struct DecomposeCallGraphTypesForReturnOp
10268
using OpConversionPattern::OpConversionPattern;
10369

10470
LogicalResult
105-
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
71+
matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
10672
ConversionPatternRewriter &rewriter) const final {
10773
SmallVector<Value, 2> newOperands;
108-
for (Value operand : adaptor.getOperands()) {
109-
// TODO: We can directly take the values from the adaptor once this is a
110-
// 1:N conversion pattern.
111-
llvm::append_range(newOperands,
112-
decomposeValue(rewriter, operand.getLoc(), operand,
113-
getTypeConverter()));
114-
}
74+
for (ValueRange operand : adaptor.getOperands())
75+
llvm::append_range(newOperands, operand);
11576
rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
11677
return success();
11778
}
@@ -128,60 +89,38 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
12889
using OpConversionPattern::OpConversionPattern;
12990

13091
LogicalResult
131-
matchAndRewrite(CallOp op, OpAdaptor adaptor,
92+
matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor,
13293
ConversionPatternRewriter &rewriter) const final {
13394

13495
// Create the operands list of the new `CallOp`.
13596
SmallVector<Value, 2> newOperands;
136-
for (Value operand : adaptor.getOperands()) {
137-
// TODO: We can directly take the values from the adaptor once this is a
138-
// 1:N conversion pattern.
139-
llvm::append_range(newOperands,
140-
decomposeValue(rewriter, operand.getLoc(), operand,
141-
getTypeConverter()));
142-
}
97+
for (ValueRange operand : adaptor.getOperands())
98+
llvm::append_range(newOperands, operand);
14399

144-
// Create the new result types for the new `CallOp` and track the indices in
145-
// the new call op's results that correspond to the old call op's results.
146-
//
147-
// expandedResultIndices[i] = "list of new result indices that old result i
148-
// expanded to".
100+
// Create the new result types for the new `CallOp` and track the number of
101+
// replacement types for each original op result.
149102
SmallVector<Type, 2> newResultTypes;
150-
SmallVector<SmallVector<unsigned, 2>, 4> expandedResultIndices;
103+
SmallVector<unsigned> expandedResultSizes;
151104
for (Type resultType : op.getResultTypes()) {
152105
unsigned oldSize = newResultTypes.size();
153106
if (failed(typeConverter->convertType(resultType, newResultTypes)))
154107
return failure();
155-
auto &resultMapping = expandedResultIndices.emplace_back();
156-
for (unsigned i = oldSize, e = newResultTypes.size(); i < e; i++)
157-
resultMapping.push_back(i);
108+
expandedResultSizes.push_back(newResultTypes.size() - oldSize);
158109
}
159110

160111
CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(),
161112
newResultTypes, newOperands);
162113

163-
// Build a replacement value for each result to replace its uses. If a
164-
// result has multiple mapping values, it needs to be materialized as a
165-
// single value.
166-
SmallVector<Value, 2> replacedValues;
114+
// Build a replacement value for each result to replace its uses.
115+
SmallVector<ValueRange> replacedValues;
167116
replacedValues.reserve(op.getNumResults());
117+
unsigned startIdx = 0;
168118
for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
169-
auto decomposedValues = llvm::to_vector<6>(
170-
llvm::map_range(expandedResultIndices[i],
171-
[&](unsigned i) { return newCallOp.getResult(i); }));
172-
if (decomposedValues.empty()) {
173-
// No replacement is required.
174-
replacedValues.push_back(nullptr);
175-
} else if (decomposedValues.size() == 1) {
176-
replacedValues.push_back(decomposedValues.front());
177-
} else {
178-
// Materialize a single Value to replace the original Value.
179-
Value materialized = getTypeConverter()->materializeArgumentConversion(
180-
rewriter, op.getLoc(), op.getType(i), decomposedValues);
181-
replacedValues.push_back(materialized);
182-
}
119+
ValueRange repl = newCallOp.getResults().slice(startIdx, expandedResultSizes[i]);
120+
replacedValues.push_back(repl);
121+
startIdx += expandedResultSizes[i];
183122
}
184-
rewriter.replaceOp(op, replacedValues);
123+
rewriter.replaceOpWithMultiple(op, replacedValues);
185124
return success();
186125
}
187126
};

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,6 +1277,11 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
12771277
}
12781278

12791279
// Try to find a mapped value with the desired type.
1280+
if (legalTypes.empty()) {
1281+
remapped.push_back({});
1282+
continue;
1283+
}
1284+
12801285
SmallVector<Value, 1> mat = mapping.lookupOrDefault(operand, legalTypes);
12811286
if (!mat.empty()) {
12821287
// Mapped value has the correct type or there is an existing
@@ -2577,45 +2582,40 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
25772582
assert(!op.use_empty() &&
25782583
"expected that dead materializations have already been DCE'd");
25792584
Operation::operand_range inputOperands = op.getOperands();
2580-
Type outputType = op.getResultTypes()[0];
25812585

25822586
// Try to materialize the conversion.
25832587
if (const TypeConverter *converter = rewrite->getConverter()) {
25842588
rewriter.setInsertionPoint(op);
2585-
Value newMaterialization;
2589+
SmallVector<Value> newMaterialization;
25862590
switch (rewrite->getMaterializationKind()) {
25872591
case MaterializationKind::Argument:
2588-
// Try to materialize an argument conversion.
2589-
newMaterialization = converter->materializeArgumentConversion(
2590-
rewriter, op->getLoc(), outputType, inputOperands);
2591-
if (newMaterialization)
2592-
break;
2593-
// If an argument materialization failed, fallback to trying a target
2594-
// materialization.
2595-
[[fallthrough]];
2592+
llvm_unreachable("argument materializations have been removed");
25962593
case MaterializationKind::Target:
25972594
newMaterialization = converter->materializeTargetConversion(
2598-
rewriter, op->getLoc(), outputType, inputOperands,
2595+
rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
25992596
rewrite->getOriginalType());
26002597
break;
26012598
case MaterializationKind::Source:
2602-
newMaterialization = converter->materializeSourceConversion(
2603-
rewriter, op->getLoc(), outputType, inputOperands);
2599+
assert(op.getNumResults() == 1 && "*:N source materializations are not supported");
2600+
Value sourceMat = converter->materializeSourceConversion(
2601+
rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
2602+
if (sourceMat)
2603+
newMaterialization.push_back(sourceMat);
26042604
break;
26052605
}
2606-
if (newMaterialization) {
2607-
assert(newMaterialization.getType() == outputType &&
2606+
if (!newMaterialization.empty()) {
2607+
assert(TypeRange(newMaterialization) == op.getResultTypes() &&
26082608
"materialization callback produced value of incorrect type");
26092609
rewriter.replaceOp(op, newMaterialization);
26102610
return success();
26112611
}
26122612
}
26132613

2614-
InFlightDiagnostic diag =
2615-
op->emitError() << "failed to legalize unresolved materialization "
2616-
"from ("
2617-
<< inputOperands.getTypes() << ") to (" << outputType
2618-
<< ") that remained live after conversion";
2614+
InFlightDiagnostic diag = op->emitError()
2615+
<< "failed to legalize unresolved materialization "
2616+
"from ("
2617+
<< inputOperands.getTypes() << ") to (" << op.getResultTypes()
2618+
<< ") that remained live after conversion";
26192619
diag.attachNote(op->getUsers().begin()->getLoc())
26202620
<< "see existing live user here: " << *op->getUsers().begin();
26212621
return failure();

mlir/test/Transforms/decompose-call-graph-types.mlir

Lines changed: 6 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@
99
// CHECK-LABEL: func @identity(
1010
// CHECK-SAME: %[[ARG0:.*]]: i1,
1111
// CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
12-
// CHECK: %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32>
13-
// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
14-
// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
15-
// CHECK: return %[[RET0]], %[[RET1]] : i1, i32
12+
// CHECK: return %[[ARG0]], %[[ARG1]] : i1, i32
1613
// CHECK-12N-LABEL: func @identity(
1714
// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
1815
// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
@@ -56,18 +53,7 @@ func.func @recursive_decomposition(%arg0: tuple<tuple<tuple<i1>>>) -> tuple<tupl
5653
// CHECK-LABEL: func @mixed_recursive_decomposition(
5754
// CHECK-SAME: %[[ARG0:.*]]: i1,
5855
// CHECK-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) {
59-
// CHECK: %[[V0:.*]] = "test.make_tuple"() : () -> tuple<>
60-
// CHECK: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]]) : (i1) -> tuple<i1>
61-
// CHECK: %[[V2:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2>
62-
// CHECK: %[[V3:.*]] = "test.make_tuple"(%[[V2]]) : (tuple<i2>) -> tuple<tuple<i2>>
63-
// CHECK: %[[V4:.*]] = "test.make_tuple"(%[[V0]], %[[V1]], %[[V3]]) : (tuple<>, tuple<i1>, tuple<tuple<i2>>) -> tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>
64-
// CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 0 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<>
65-
// CHECK: %[[V6:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 1 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<i1>
66-
// CHECK: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<i1>) -> i1
67-
// CHECK: %[[V8:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 2 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
68-
// CHECK: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2>
69-
// CHECK: %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) <{index = 0 : i32}> : (tuple<i2>) -> i2
70-
// CHECK: return %[[V7]], %[[V10]] : i1, i2
56+
// CHECK: return %[[ARG0]], %[[ARG1]] : i1, i2
7157
// CHECK-12N-LABEL: func @mixed_recursive_decomposition(
7258
// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
7359
// CHECK-12N-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) {
@@ -87,14 +73,8 @@ func.func private @callee(tuple<i1, i32>) -> tuple<i1, i32>
8773
// CHECK-LABEL: func @caller(
8874
// CHECK-SAME: %[[ARG0:.*]]: i1,
8975
// CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
90-
// CHECK: %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32>
91-
// CHECK: %[[CALL_ARG0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
92-
// CHECK: %[[CALL_ARG1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
93-
// CHECK: %[[DECOMPOSED:.*]]:2 = call @callee(%[[CALL_ARG0]], %[[CALL_ARG1]]) : (i1, i32) -> (i1, i32)
94-
// CHECK: %[[CALL_RESULT_RECOMPOSED:.*]] = "test.make_tuple"(%[[DECOMPOSED]]#0, %[[DECOMPOSED]]#1) : (i1, i32) -> tuple<i1, i32>
95-
// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
96-
// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
97-
// CHECK: return %[[RET0]], %[[RET1]] : i1, i32
76+
// CHECK: %[[V0:.*]]:2 = call @callee(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> (i1, i32)
77+
// CHECK: return %[[V0]]#0, %[[V0]]#1 : i1, i32
9878
// CHECK-12N-LABEL: func @caller(
9979
// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
10080
// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
@@ -190,14 +170,8 @@ func.func private @callee(tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tup
190170
// CHECK-SAME: %[[I4:.*]]: i4,
191171
// CHECK-SAME: %[[I5:.*]]: i5,
192172
// CHECK-SAME: %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) {
193-
// CHECK: %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[I4]], %[[I5]]) : (i4, i5) -> tuple<i4, i5>
194-
// CHECK: %[[ARG_TUPLE_0:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 0 : i32}> : (tuple<i4, i5>) -> i4
195-
// CHECK: %[[ARG_TUPLE_1:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 1 : i32}> : (tuple<i4, i5>) -> i5
196-
// CHECK: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[ARG_TUPLE_0]], %[[ARG_TUPLE_1]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
197-
// CHECK: %[[RET_TUPLE:.*]] = "test.make_tuple"(%[[CALL]]#3, %[[CALL]]#4) : (i4, i5) -> tuple<i4, i5>
198-
// CHECK: %[[RET_TUPLE_0:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 0 : i32}> : (tuple<i4, i5>) -> i4
199-
// CHECK: %[[RET_TUPLE_1:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 1 : i32}> : (tuple<i4, i5>) -> i5
200-
// CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[RET_TUPLE_0]], %[[RET_TUPLE_1]], %[[CALL]]#5 : i1, i2, i3, i4, i5, i6
173+
// CHECK: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
174+
// CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[CALL]]#3, %[[CALL]]#4, %[[CALL]]#5 : i1, i2, i3, i4, i5, i6
201175
// CHECK-12N-LABEL: func @caller(
202176
// CHECK-12N-SAME: %[[I1:.*]]: i1,
203177
// CHECK-12N-SAME: %[[I2:.*]]: i2,

mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ struct TestDecomposeCallGraphTypes
139139
tupleType.getFlattenedTypes(types);
140140
return success();
141141
});
142-
typeConverter.addArgumentMaterialization(buildMakeTupleOp);
142+
typeConverter.addSourceMaterialization(buildMakeTupleOp);
143143
typeConverter.addTargetMaterialization(buildDecomposeTuple);
144144

145145
populateDecomposeCallGraphTypesPatterns(context, typeConverter, patterns);

0 commit comments

Comments
 (0)