diff --git a/mlir/test/Transforms/test-dialect-conversion-without-type-changes.mlir b/mlir/test/Transforms/test-dialect-conversion-without-type-changes.mlir new file mode 100644 index 0000000000000..39e185401821d --- /dev/null +++ b/mlir/test/Transforms/test-dialect-conversion-without-type-changes.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt -allow-unregistered-dialect -test-dialect-conversion-without-type-changes -verify-diagnostics %s | FileCheck %s + +// Test that SSA values are properly replaced in dialect conversion even when +// types are not changed. + +// CHECK-LABEL: @test1 +// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32 +func.func @test1(%arg0: i32, %arg1 : i32) -> i32 { + %0 = "test.bgv_mul"(%arg0, %arg1) : (i32, i32) -> (i64) + %1 = "test.bgv_relin"(%0) : (i64) -> (i32) + %2 = "test.bgv_sub"(%1, %arg0) : (i32, i32) -> (i32) + %3 = "test.bgv_sub"(%2, %arg1) : (i32, i32) -> (i32) + func.return %3 : i32 +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index e6c3601d08dad..412c7297903ba 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -3067,4 +3067,34 @@ def TestOpOptionallyImplementingInterface let arguments = (ins BoolAttr:$implementsInterface); } +def TestOpBgvMul : TEST_Op<"bgv_mul", [Pure, Commutative]> { + let arguments = (ins I32:$lhs, I32:$rhs); + let results = (outs I64:$output); +} + +def TestOpBgvRelin : TEST_Op<"bgv_relin", [Pure]> { + let arguments = (ins I64:$input); + let results = (outs I32:$output); +} + +def TestOpBgvSub : TEST_Op<"bgv_sub", [Pure, SameOperandsAndResultType]> { + let arguments = (ins I32:$lhs, I32:$rhs); + let results = (outs I32:$output); +} + +def TestOpOpenfheMul : TEST_Op<"openfhe_mul", [Pure, Commutative]> { + let arguments = (ins I8:$ctx, I32:$lhs, I32:$rhs); + let results = (outs I64:$output); +} + +def TestOpOpenfheRelin : TEST_Op<"openfhe_relin", [Pure]> { + let arguments = (ins I8:$ctx, I64:$input); + let results = (outs I32:$output); +} + +def TestOpOpenfheSub : TEST_Op<"openfhe_sub", [Pure]> { + let arguments = (ins I8:$ctx, I32:$lhs, I32:$rhs); + let results = (outs I32:$output); +} + #endif // TEST_OPS diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 2da184bc3d85b..e95f3c4332270 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1887,6 +1887,52 @@ struct TestSelectiveReplacementPatternDriver }; } // namespace +namespace { +struct DialectConversionBugPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(TestOpDialectConversionBug1 op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + // Just swap order of operands, no type changes + rewriter.replaceOpWithNewOp( + op, adaptor.getRhs(), adaptor.getLhs()); + return success(); + } +}; + +struct TestDialectConversionWithoutResultTypeChanges + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TestDialectConversionWithoutResultTypeChanges) + + StringRef getArgument() const final { + return "test-dialect-conversion-without-type-changes"; + } + StringRef getDescription() const final { + return "Test a bug in DialectConversion when op results don't change types"; + } + + void runOnOperation() override { + TypeConverter converter; + converter.addConversion([](Type t) { return t; }); + + ConversionTarget target(getContext()); + target.addIllegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(&getContext()); + patterns.add(converter, &getContext()); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + //===----------------------------------------------------------------------===// // PassRegistration //===----------------------------------------------------------------------===// @@ -1911,6 +1957,7 @@ void registerPatternsTestPass() { PassRegistration(); PassRegistration(); + PassRegistration(); PassRegistration();