Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: mlir-opt -allow-unregistered-dialect -test-dialect-conversion-without-type-changes -verify-diagnostics %s | FileCheck %s

// Test that SSA values are properly replaced in dialect conversion even when
// types are not changed.

// CHECK-LABEL: @test1
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
func.func @test1(%arg0: i32, %arg1 : i32) -> i32 {
%0 = "test.bgv_mul"(%arg0, %arg1) : (i32, i32) -> (i64)
%1 = "test.bgv_relin"(%0) : (i64) -> (i32)
%2 = "test.bgv_sub"(%1, %arg0) : (i32, i32) -> (i32)
%3 = "test.bgv_sub"(%2, %arg1) : (i32, i32) -> (i32)
func.return %3 : i32
}
30 changes: 30 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3067,4 +3067,34 @@ def TestOpOptionallyImplementingInterface
let arguments = (ins BoolAttr:$implementsInterface);
}

def TestOpBgvMul : TEST_Op<"bgv_mul", [Pure, Commutative]> {
let arguments = (ins I32:$lhs, I32:$rhs);
let results = (outs I64:$output);
}

def TestOpBgvRelin : TEST_Op<"bgv_relin", [Pure]> {
let arguments = (ins I64:$input);
let results = (outs I32:$output);
}

def TestOpBgvSub : TEST_Op<"bgv_sub", [Pure, SameOperandsAndResultType]> {
let arguments = (ins I32:$lhs, I32:$rhs);
let results = (outs I32:$output);
}

def TestOpOpenfheMul : TEST_Op<"openfhe_mul", [Pure, Commutative]> {
let arguments = (ins I8:$ctx, I32:$lhs, I32:$rhs);
let results = (outs I64:$output);
}

def TestOpOpenfheRelin : TEST_Op<"openfhe_relin", [Pure]> {
let arguments = (ins I8:$ctx, I64:$input);
let results = (outs I32:$output);
}

def TestOpOpenfheSub : TEST_Op<"openfhe_sub", [Pure]> {
let arguments = (ins I8:$ctx, I32:$lhs, I32:$rhs);
let results = (outs I32:$output);
}

#endif // TEST_OPS
47 changes: 47 additions & 0 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1887,6 +1887,52 @@ struct TestSelectiveReplacementPatternDriver
};
} // namespace

namespace {
struct DialectConversionBugPattern
: public OpConversionPattern<TestOpDialectConversionBug1> {
using OpConversionPattern<TestOpDialectConversionBug1>::OpConversionPattern;

LogicalResult
matchAndRewrite(TestOpDialectConversionBug1 op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
// Just swap order of operands, no type changes
rewriter.replaceOpWithNewOp<TestOpDialectConversionBug2>(
op, adaptor.getRhs(), adaptor.getLhs());
return success();
}
};

struct TestDialectConversionWithoutResultTypeChanges
: public PassWrapper<TestDialectConversionWithoutResultTypeChanges,
OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestDialectConversionWithoutResultTypeChanges)

StringRef getArgument() const final {
return "test-dialect-conversion-without-type-changes";
}
StringRef getDescription() const final {
return "Test a bug in DialectConversion when op results don't change types";
}

void runOnOperation() override {
TypeConverter converter;
converter.addConversion([](Type t) { return t; });

ConversionTarget target(getContext());
target.addIllegalOp<TestOpDialectConversionBug1>();
target.addLegalOp<TestOpDialectConversionBug2>();

RewritePatternSet patterns(&getContext());
patterns.add<DialectConversionBugPattern>(converter, &getContext());

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
};
} // namespace

//===----------------------------------------------------------------------===//
// PassRegistration
//===----------------------------------------------------------------------===//
Expand All @@ -1911,6 +1957,7 @@ void registerPatternsTestPass() {

PassRegistration<TestTypeConversionDriver>();
PassRegistration<TestTargetMaterializationWithNoUses>();
PassRegistration<TestDialectConversionWithoutResultTypeChanges>();

PassRegistration<TestRewriteDynamicOpDriver>();

Expand Down