-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][Func] Support 1:N result type conversions in func.call conversion
#117413
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir-func @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThis commit adds support for 1:N result type conversions for This commit is in preparation of merging the 1:1 and 1:N conversion drivers. Full diff: https://github.com/llvm/llvm-project/pull/117413.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
index eb444d665ff260..b1cde6ca5d2fca 100644
--- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
@@ -23,21 +23,34 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
LogicalResult
matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // Convert the original function results.
+ // Convert the original function results. Keep track of how many result
+ // types an original result type is converted into.
+ SmallVector<size_t> numResultsReplacments;
SmallVector<Type, 1> convertedResults;
- if (failed(typeConverter->convertTypes(callOp.getResultTypes(),
- convertedResults)))
- return failure();
-
- // If this isn't a one-to-one type mapping, we don't know how to aggregate
- // the results.
- if (callOp->getNumResults() != convertedResults.size())
- return failure();
+ size_t numFlattenedResults = 0;
+ for (auto [idx, type] : llvm::enumerate(callOp.getResultTypes())) {
+ if (failed(typeConverter->convertTypes(type, convertedResults)))
+ return failure();
+ numResultsReplacments.push_back(convertedResults.size() -
+ numFlattenedResults);
+ numFlattenedResults = convertedResults.size();
+ }
// Substitute with the new result types from the corresponding FuncType
// conversion.
- rewriter.replaceOpWithNewOp<CallOp>(
- callOp, callOp.getCallee(), convertedResults, adaptor.getOperands());
+ auto newCallOp =
+ rewriter.create<CallOp>(callOp.getLoc(), callOp.getCallee(),
+ convertedResults, adaptor.getOperands());
+ SmallVector<ValueRange> replacements;
+ size_t offset = 0;
+ for (int i = 0, e = callOp->getNumResults(); i < e; ++i) {
+ replacements.push_back(
+ newCallOp->getResults().slice(offset, numResultsReplacments[i]));
+ offset += numResultsReplacments[i];
+ }
+ assert(offset == convertedResults.size() &&
+ "expected that all converted results are used");
+ rewriter.replaceOpWithMultiple(callOp, replacements);
return success();
}
};
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index e5503ee8920424..7c6e3c5c3a6c53 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -463,3 +463,25 @@ func.func @circular_mapping() {
%0 = "test.erase_op"() : () -> (i64)
"test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
}
+
+// -----
+
+module {
+// CHECK-LABEL: func.func private @foo() -> (i23, i23)
+func.func private @foo() -> (i22, i24)
+
+// CHECK: func.func @bar()
+func.func @bar() {
+ // i22 is converted to (i23, i23).
+ // i24 is converted to ().
+ // CHECK: %[[call:.*]]:2 = call @foo() : () -> (i23, i23)
+ %0:2 = func.call @foo() : () -> (i22, i24)
+
+ // CHECK: %[[cast1:.*]] = "test.cast"() : () -> i24
+ // CHECK: %[[cast0:.*]] = "test.cast"(%[[call]]#0, %[[call]]#1) : (i23, i23) -> i22
+ // CHECK: "test.some_user"(%[[cast0]], %[[cast1]]) : (i22, i24) -> ()
+ // expected-remark @below{{'test.some_user' is not legalizable}}
+ "test.some_user"(%0#0, %0#1) : (i22, i24) -> ()
+ "test.return"() : () -> ()
+}
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 3df6cff3c0a60b..912173f391086e 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1215,6 +1215,18 @@ struct TestTypeConverter : public TypeConverter {
return success();
}
+ // Convert I22 to multiple I23.
+ if (t.isInteger(22)) {
+ results.push_back(IntegerType::get(t.getContext(), 23));
+ results.push_back(IntegerType::get(t.getContext(), 23));
+ return success();
+ }
+
+ // Drop I24 types.
+ if (t.isInteger(24)) {
+ return success();
+ }
+
// Otherwise, convert the type directly.
results.push_back(t);
return success();
|
2b8a426 to
85cec06
Compare
zero9178
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thank you!
This commit adds support for 1:N result type conversions for
func.callops. In that case, argument materializations to the original result type should be inserted (viareplaceOpWithMultiple).This commit is in preparation of merging the 1:1 and 1:N conversion drivers.