diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td index a8b04d0453110..bbfa30815bd4a 100644 --- a/mlir/include/mlir/IR/SymbolInterfaces.td +++ b/mlir/include/mlir/IR/SymbolInterfaces.td @@ -55,19 +55,19 @@ def Symbol : OpInterface<"SymbolOpInterface"> { InterfaceMethod<"Returns true if this symbol has nested visibility.", "bool", "isNested", (ins), [{}], /*defaultImplementation=*/[{ - return getVisibility() == mlir::SymbolTable::Visibility::Nested; + return $_op.getVisibility() == mlir::SymbolTable::Visibility::Nested; }] >, InterfaceMethod<"Returns true if this symbol has private visibility.", "bool", "isPrivate", (ins), [{}], /*defaultImplementation=*/[{ - return getVisibility() == mlir::SymbolTable::Visibility::Private; + return $_op.getVisibility() == mlir::SymbolTable::Visibility::Private; }] >, InterfaceMethod<"Returns true if this symbol has public visibility.", "bool", "isPublic", (ins), [{}], /*defaultImplementation=*/[{ - return getVisibility() == mlir::SymbolTable::Visibility::Public; + return $_op.getVisibility() == mlir::SymbolTable::Visibility::Public; }] >, InterfaceMethod<"Sets the visibility of this symbol.", @@ -79,19 +79,19 @@ def Symbol : OpInterface<"SymbolOpInterface"> { InterfaceMethod<"Sets the visibility of this symbol to be nested.", "void", "setNested", (ins), [{}], /*defaultImplementation=*/[{ - setVisibility(mlir::SymbolTable::Visibility::Nested); + $_op.setVisibility(mlir::SymbolTable::Visibility::Nested); }] >, InterfaceMethod<"Sets the visibility of this symbol to be private.", "void", "setPrivate", (ins), [{}], /*defaultImplementation=*/[{ - setVisibility(mlir::SymbolTable::Visibility::Private); + $_op.setVisibility(mlir::SymbolTable::Visibility::Private); }] >, InterfaceMethod<"Sets the visibility of this symbol to be public.", "void", "setPublic", (ins), [{}], /*defaultImplementation=*/[{ - setVisibility(mlir::SymbolTable::Visibility::Public); + $_op.setVisibility(mlir::SymbolTable::Visibility::Public); }] >, InterfaceMethod<[{ @@ -144,7 +144,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> { // By default, base this on the visibility alone. A symbol can be // discarded as long as it is not public. Only public symbols may be // visible from outside of the IR. - return getVisibility() != ::mlir::SymbolTable::Visibility::Public; + return $_op.getVisibility() != ::mlir::SymbolTable::Visibility::Public; }] >, InterfaceMethod<[{ diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index f79e2cfbcb259..53055fea215b7 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -17,6 +17,32 @@ using namespace mlir; using namespace test; +//===----------------------------------------------------------------------===// +// OverridenSymbolVisibilityOp +//===----------------------------------------------------------------------===// + +SymbolTable::Visibility OverriddenSymbolVisibilityOp::getVisibility() { + return SymbolTable::Visibility::Private; +} + +static StringLiteral getVisibilityString(SymbolTable::Visibility visibility) { + switch (visibility) { + case SymbolTable::Visibility::Private: + return "private"; + case SymbolTable::Visibility::Nested: + return "nested"; + case SymbolTable::Visibility::Public: + return "public"; + } +} + +void OverriddenSymbolVisibilityOp::setVisibility( + SymbolTable::Visibility visibility) { + + emitOpError("cannot change visibility of symbol to ") + << getVisibilityString(visibility); +} + //===----------------------------------------------------------------------===// // TestBranchOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index a7c6cd60a0ee4..927c98225bf2f 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -119,6 +119,13 @@ def SymbolOp : TEST_Op<"symbol", [NoMemoryEffect, Symbol]> { OptionalAttr:$sym_visibility); } +def OverriddenSymbolVisibilityOp : TEST_Op<"overridden_symbol_visibility", [ + DeclareOpInterfaceMethods, +]> { + let summary = "operation overridden symbol visibility accessors"; + let arguments = (ins StrAttr:$sym_name); +} + def SymbolScopeOp : TEST_Op<"symbol_scope", [SymbolTable, SingleBlockImplicitTerminator<"TerminatorOp">]> { let summary = "operation which defines a new symbol table"; diff --git a/mlir/unittests/IR/SymbolTableTest.cpp b/mlir/unittests/IR/SymbolTableTest.cpp index cfc3fe0cb1c5b..4b3545bce1952 100644 --- a/mlir/unittests/IR/SymbolTableTest.cpp +++ b/mlir/unittests/IR/SymbolTableTest.cpp @@ -132,4 +132,38 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) { }); } +TEST(SymbolOpInterface, Visibility) { + DialectRegistry registry; + ::test::registerTestDialect(registry); + MLIRContext context(registry); + + constexpr static StringLiteral kInput = R"MLIR( + "test.overridden_symbol_visibility"() {sym_name = "symbol_name"} : () -> () + )MLIR"; + OwningOpRef module = parseSourceString(kInput, &context); + auto symOp = cast(module->getBody()->front()); + + ASSERT_TRUE(symOp.isPrivate()); + ASSERT_FALSE(symOp.isPublic()); + ASSERT_FALSE(symOp.isNested()); + ASSERT_TRUE(symOp.canDiscardOnUseEmpty()); + + std::string diagStr; + context.getDiagEngine().registerHandler( + [&](Diagnostic &diag) { diagStr += diag.str(); }); + + std::string expectedDiag; + symOp.setPublic(); + expectedDiag += "'test.overridden_symbol_visibility' op cannot change " + "visibility of symbol to public"; + symOp.setNested(); + expectedDiag += "'test.overridden_symbol_visibility' op cannot change " + "visibility of symbol to nested"; + symOp.setPrivate(); + expectedDiag += "'test.overridden_symbol_visibility' op cannot change " + "visibility of symbol to private"; + + ASSERT_EQ(diagStr, expectedDiag); +} + } // namespace