Skip to content

Conversation

@matthias-springer
Copy link
Member

This commit adds support for 1:N result type conversions for func.call ops. In that case, argument materializations to the original result type should be inserted (via replaceOpWithMultiple).

This commit is in preparation of merging the 1:1 and 1:N conversion drivers.

@llvmbot
Copy link
Member

llvmbot commented Nov 23, 2024

@llvm/pr-subscribers-mlir-func

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

This commit adds support for 1:N result type conversions for func.call ops. In that case, argument materializations to the original result type should be inserted (via replaceOpWithMultiple).

This commit is in preparation of merging the 1:1 and 1:N conversion drivers.


Full diff: https://github.com/llvm/llvm-project/pull/117413.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp (+24-11)
  • (modified) mlir/test/Transforms/test-legalizer.mlir (+22)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+12)
diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
index eb444d665ff260..b1cde6ca5d2fca 100644
--- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
@@ -23,21 +23,34 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
   LogicalResult
   matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // Convert the original function results.
+    // Convert the original function results. Keep track of how many result
+    // types an original result type is converted into.
+    SmallVector<size_t> numResultsReplacments;
     SmallVector<Type, 1> convertedResults;
-    if (failed(typeConverter->convertTypes(callOp.getResultTypes(),
-                                           convertedResults)))
-      return failure();
-
-    // If this isn't a one-to-one type mapping, we don't know how to aggregate
-    // the results.
-    if (callOp->getNumResults() != convertedResults.size())
-      return failure();
+    size_t numFlattenedResults = 0;
+    for (auto [idx, type] : llvm::enumerate(callOp.getResultTypes())) {
+      if (failed(typeConverter->convertTypes(type, convertedResults)))
+        return failure();
+      numResultsReplacments.push_back(convertedResults.size() -
+                                      numFlattenedResults);
+      numFlattenedResults = convertedResults.size();
+    }
 
     // Substitute with the new result types from the corresponding FuncType
     // conversion.
-    rewriter.replaceOpWithNewOp<CallOp>(
-        callOp, callOp.getCallee(), convertedResults, adaptor.getOperands());
+    auto newCallOp =
+        rewriter.create<CallOp>(callOp.getLoc(), callOp.getCallee(),
+                                convertedResults, adaptor.getOperands());
+    SmallVector<ValueRange> replacements;
+    size_t offset = 0;
+    for (int i = 0, e = callOp->getNumResults(); i < e; ++i) {
+      replacements.push_back(
+          newCallOp->getResults().slice(offset, numResultsReplacments[i]));
+      offset += numResultsReplacments[i];
+    }
+    assert(offset == convertedResults.size() &&
+           "expected that all converted results are used");
+    rewriter.replaceOpWithMultiple(callOp, replacements);
     return success();
   }
 };
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index e5503ee8920424..7c6e3c5c3a6c53 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -463,3 +463,25 @@ func.func @circular_mapping() {
   %0 = "test.erase_op"() : () -> (i64)
   "test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
 }
+
+// -----
+
+module {
+// CHECK-LABEL: func.func private @foo() -> (i23, i23)
+func.func private @foo() -> (i22, i24)
+
+// CHECK: func.func @bar()
+func.func @bar() {
+  // i22 is converted to (i23, i23).
+  // i24 is converted to ().
+  // CHECK: %[[call:.*]]:2 = call @foo() : () -> (i23, i23)
+  %0:2 = func.call @foo() : () -> (i22, i24)
+
+  // CHECK: %[[cast1:.*]] = "test.cast"() : () -> i24
+  // CHECK: %[[cast0:.*]] = "test.cast"(%[[call]]#0, %[[call]]#1) : (i23, i23) -> i22
+  // CHECK: "test.some_user"(%[[cast0]], %[[cast1]]) : (i22, i24) -> ()
+  // expected-remark @below{{'test.some_user' is not legalizable}}
+  "test.some_user"(%0#0, %0#1) : (i22, i24) -> ()
+  "test.return"() : () -> ()
+}
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 3df6cff3c0a60b..912173f391086e 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1215,6 +1215,18 @@ struct TestTypeConverter : public TypeConverter {
       return success();
     }
 
+    // Convert I22 to multiple I23.
+    if (t.isInteger(22)) {
+      results.push_back(IntegerType::get(t.getContext(), 23));
+      results.push_back(IntegerType::get(t.getContext(), 23));
+      return success();
+    }
+
+    // Drop I24 types.
+    if (t.isInteger(24)) {
+      return success();
+    }
+
     // Otherwise, convert the type directly.
     results.push_back(t);
     return success();

@matthias-springer matthias-springer force-pushed the users/matthias-springer/1_n_call_op branch from 2b8a426 to 85cec06 Compare November 23, 2024 05:18
Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you!

@matthias-springer matthias-springer merged commit 08e6566 into main Nov 23, 2024
8 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/1_n_call_op branch November 23, 2024 11:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants