diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td index c3e84bccc5dee..1bd8eb04714c5 100644 --- a/mlir/include/mlir/IR/OpAsmInterface.td +++ b/mlir/include/mlir/IR/OpAsmInterface.td @@ -127,6 +127,13 @@ def OpAsmTypeInterface : TypeInterface<"OpAsmTypeInterface"> { "void", "getAsmName", (ins "::mlir::OpAsmSetNameFn":$setNameFn), "", ";" >, + InterfaceMethod<[{ + Get a name to use when generating an alias for this type. + }], + "::mlir::OpAsmDialectInterface::AliasResult", "getAlias", + (ins "::llvm::raw_ostream&":$os), "", + "return ::mlir::OpAsmDialectInterface::AliasResult::NoAlias;" + >, ]; } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index cc578eae3ee36..1f22d4f37a813 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1163,14 +1163,13 @@ void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias, OpAsmDialectInterface::AliasResult symbolInterfaceResult = OpAsmDialectInterface::AliasResult::NoAlias; - if constexpr (std::is_base_of_v) { - if (auto symbolInterface = dyn_cast(symbol)) { - symbolInterfaceResult = symbolInterface.getAlias(aliasOS); - if (symbolInterfaceResult != - OpAsmDialectInterface::AliasResult::NoAlias) { - nameBuffer = std::move(aliasBuffer); - assert(!nameBuffer.empty() && "expected valid alias name"); - } + using InterfaceT = std::conditional_t, + OpAsmAttrInterface, OpAsmTypeInterface>; + if (auto symbolInterface = dyn_cast(symbol)) { + symbolInterfaceResult = symbolInterface.getAlias(aliasOS); + if (symbolInterfaceResult != OpAsmDialectInterface::AliasResult::NoAlias) { + nameBuffer = std::move(aliasBuffer); + assert(!nameBuffer.empty() && "expected valid alias name"); } } diff --git a/mlir/test/IR/op-asm-interface.mlir b/mlir/test/IR/op-asm-interface.mlir index 44a6e7afece03..086dc7da421c2 100644 --- a/mlir/test/IR/op-asm-interface.mlir +++ b/mlir/test/IR/op-asm-interface.mlir @@ -61,6 +61,16 @@ func.func @block_argument_name_from_op_asm_type_interface_asmprinter() { // ----- +// CHECK: !op_asm_type_interface_type = +!type = !test.op_asm_type_interface + +func.func @alias_from_op_asm_type_interface() { + %0 = "test.result_name_from_type"() : () -> !type + return +} + +// ----- + //===----------------------------------------------------------------------===// // Test OpAsmAttrInterface //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index 6335701786ecc..c048f8b654ec2 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -399,7 +399,7 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> { } def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface", - [DeclareTypeInterfaceMethods]> { + [DeclareTypeInterfaceMethods]> { let mnemonic = "op_asm_type_interface"; } diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index 1ae7ac472d989..0c237440834ef 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -537,3 +537,9 @@ void TestTypeOpAsmTypeInterfaceType::getAsmName( OpAsmSetNameFn setNameFn) const { setNameFn("op_asm_type_interface"); } + +::mlir::OpAsmDialectInterface::AliasResult +TestTypeOpAsmTypeInterfaceType::getAlias(::llvm::raw_ostream &os) const { + os << "op_asm_type_interface_type"; + return ::mlir::OpAsmDialectInterface::AliasResult::FinalAlias; +}