diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp index eb444d665ff26..b1cde6ca5d2fc 100644 --- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp +++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp @@ -23,21 +23,34 @@ struct CallOpSignatureConversion : public OpConversionPattern { LogicalResult matchAndRewrite(CallOp callOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // Convert the original function results. + // Convert the original function results. Keep track of how many result + // types an original result type is converted into. + SmallVector numResultsReplacments; SmallVector convertedResults; - if (failed(typeConverter->convertTypes(callOp.getResultTypes(), - convertedResults))) - return failure(); - - // If this isn't a one-to-one type mapping, we don't know how to aggregate - // the results. - if (callOp->getNumResults() != convertedResults.size()) - return failure(); + size_t numFlattenedResults = 0; + for (auto [idx, type] : llvm::enumerate(callOp.getResultTypes())) { + if (failed(typeConverter->convertTypes(type, convertedResults))) + return failure(); + numResultsReplacments.push_back(convertedResults.size() - + numFlattenedResults); + numFlattenedResults = convertedResults.size(); + } // Substitute with the new result types from the corresponding FuncType // conversion. - rewriter.replaceOpWithNewOp( - callOp, callOp.getCallee(), convertedResults, adaptor.getOperands()); + auto newCallOp = + rewriter.create(callOp.getLoc(), callOp.getCallee(), + convertedResults, adaptor.getOperands()); + SmallVector replacements; + size_t offset = 0; + for (int i = 0, e = callOp->getNumResults(); i < e; ++i) { + replacements.push_back( + newCallOp->getResults().slice(offset, numResultsReplacments[i])); + offset += numResultsReplacments[i]; + } + assert(offset == convertedResults.size() && + "expected that all converted results are used"); + rewriter.replaceOpWithMultiple(callOp, replacements); return success(); } }; diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index e5503ee892042..e05f444afa68f 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -379,15 +379,24 @@ builtin.module { // ----- -// expected-remark @below {{applyPartialConversion failed}} module { - func.func private @callee(%0 : f32) -> f32 - - func.func @caller( %arg: f32) { - // expected-error @below {{failed to legalize}} - %1 = func.call @callee(%arg) : (f32) -> f32 - return - } +// CHECK-LABEL: func.func private @callee() -> (f16, f16) +func.func private @callee() -> (f32, i24) + +// CHECK: func.func @caller() +func.func @caller() { + // f32 is converted to (f16, f16). + // i24 is converted to (). + // CHECK: %[[call:.*]]:2 = call @callee() : () -> (f16, f16) + %0:2 = func.call @callee() : () -> (f32, i24) + + // CHECK: %[[cast1:.*]] = "test.cast"() : () -> i24 + // CHECK: %[[cast0:.*]] = "test.cast"(%[[call]]#0, %[[call]]#1) : (f16, f16) -> f32 + // CHECK: "test.some_user"(%[[cast0]], %[[cast1]]) : (f32, i24) -> () + // expected-remark @below{{'test.some_user' is not legalizable}} + "test.some_user"(%0#0, %0#1) : (f32, i24) -> () + "test.return"() : () -> () +} } // ----- diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 3df6cff3c0a60..bbd55938718fe 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1215,6 +1215,11 @@ struct TestTypeConverter : public TypeConverter { return success(); } + // Drop I24 types. + if (t.isInteger(24)) { + return success(); + } + // Otherwise, convert the type directly. results.push_back(t); return success();