Skip to content

Commit f5d3cf4

Browse files
[mlir][TableGen] Emit interface traits after all interfaces (#147699)
Interface traits may provide default implementation of methods. When this happens, the implementation may rely on another interface that is not yet defined meaning that one gets "incomplete type" error during C++ compilation. In pseudo-code, the problem is the following: ``` InterfaceA has methodB() { return InterfaceB(); } InterfaceB defined later // What's generated is: class InterfaceA { ... } class InterfaceATrait { // error: InterfaceB is an incomplete type InterfaceB methodB() { return InterfaceB(); } } class InterfaceB { ... } // defined here ``` The two more "advanced" cases are: * Cyclic dependency (A requires B and B requires A) * Type-traited usage of an incomplete type (e.g. `FailureOr<InterfaceB>`) It seems reasonable to emit interface traits *after* all of the interfaces have been defined to avoid the problem altogether. As a drive by, make forward declarations of the interfaces early so that user code does not need to forward declare.
1 parent 5365f8b commit f5d3cf4

File tree

3 files changed

+79
-18
lines changed

3 files changed

+79
-18
lines changed

mlir/test/lib/Dialect/Test/TestInterfaces.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,4 +174,32 @@ def TestOptionallyImplementedTypeInterface
174174
}];
175175
}
176176

177+
// Dummy type interface "A" that requires type interface "B" to be complete.
178+
def TestCyclicTypeInterfaceA : TypeInterface<"TestCyclicTypeInterfaceA"> {
179+
let cppNamespace = "::mlir";
180+
let methods = [
181+
InterfaceMethod<"",
182+
"::mlir::FailureOr<::mlir::TestCyclicTypeInterfaceB>",
183+
/*methodName=*/"returnB",
184+
(ins),
185+
/*methodBody=*/"",
186+
/*defaultImpl=*/"return mlir::failure();"
187+
>,
188+
];
189+
}
190+
191+
// Dummy type interface "B" that requires type interface "A" to be complete.
192+
def TestCyclicTypeInterfaceB : TypeInterface<"TestCyclicTypeInterfaceB"> {
193+
let cppNamespace = "::mlir";
194+
let methods = [
195+
InterfaceMethod<"",
196+
"::mlir::FailureOr<::mlir::TestCyclicTypeInterfaceA>",
197+
/*methodName=*/"returnA",
198+
(ins),
199+
/*methodBody=*/"",
200+
/*defaultImpl=*/"return mlir::failure();"
201+
>,
202+
];
203+
}
204+
177205
#endif // MLIR_TEST_DIALECT_TEST_INTERFACES

mlir/test/mlir-tblgen/op-interface.td

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,6 @@ def ExtraShardDeclsInterface : OpInterface<"ExtraShardDeclsInterface"> {
3131
// DECL-NEXT: return (*this).someOtherMethod();
3232
// DECL-NEXT: }
3333

34-
// DECL: struct ExtraShardDeclsInterfaceTrait
35-
// DECL: bool sharedMethodDeclaration() {
36-
// DECL-NEXT: return (*static_cast<ConcreteOp *>(this)).someOtherMethod();
37-
// DECL-NEXT: }
38-
3934
def TestInheritanceMultiBaseInterface : OpInterface<"TestInheritanceMultiBaseInterface"> {
4035
let methods = [
4136
InterfaceMethod<
@@ -71,7 +66,7 @@ def TestInheritanceMiddleBaseInterface
7166
def TestInheritanceZDerivedInterface
7267
: OpInterface<"TestInheritanceZDerivedInterface", [TestInheritanceMiddleBaseInterface]>;
7368

74-
// DECL: class TestInheritanceZDerivedInterface
69+
// DECL: struct TestInheritanceZDerivedInterfaceInterfaceTraits
7570
// DECL: struct Concept {
7671
// DECL: const TestInheritanceMultiBaseInterface::Concept *implTestInheritanceMultiBaseInterface = nullptr;
7772
// DECL-NOT: const TestInheritanceMultiBaseInterface::Concept
@@ -173,10 +168,16 @@ def DeclareMethodsWithDefaultOp : Op<TestDialect, "declare_methods_op",
173168
// DECL: /// some function comment
174169
// DECL: int foo(int input);
175170

176-
// DECL-LABEL: struct TestOpInterfaceVerifyTrait
171+
// Trait declarations / definitions come after interface definitions.
172+
// DECL: struct ExtraShardDeclsInterfaceTrait : public
173+
// DECL: bool sharedMethodDeclaration() {
174+
// DECL-NEXT: return (*static_cast<ConcreteOp *>(this)).someOtherMethod();
175+
// DECL-NEXT: }
176+
177+
// DECL-LABEL: struct TestOpInterfaceVerifyTrait : public
177178
// DECL: verifyTrait
178179

179-
// DECL-LABEL: struct TestOpInterfaceVerifyRegionTrait
180+
// DECL-LABEL: struct TestOpInterfaceVerifyRegionTrait : public
180181
// DECL: verifyRegionTrait
181182

182183
// Method implementations come last, after all class definitions.

mlir/tools/mlir-tblgen/OpInterfacesGen.cpp

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@ class InterfaceGenerator {
9696
void emitConceptDecl(const Interface &interface);
9797
void emitModelDecl(const Interface &interface);
9898
void emitModelMethodsDef(const Interface &interface);
99-
void emitTraitDecl(const Interface &interface, StringRef interfaceName,
100-
StringRef interfaceTraitsName);
99+
void forwardDeclareInterface(const Interface &interface);
101100
void emitInterfaceDecl(const Interface &interface);
101+
void emitInterfaceTraitDecl(const Interface &interface);
102102

103103
/// The set of interface records to emit.
104104
std::vector<const Record *> defs;
@@ -445,9 +445,16 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
445445
os << "} // namespace " << ns << "\n";
446446
}
447447

448-
void InterfaceGenerator::emitTraitDecl(const Interface &interface,
449-
StringRef interfaceName,
450-
StringRef interfaceTraitsName) {
448+
void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) {
449+
llvm::SmallVector<StringRef, 2> namespaces;
450+
llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
451+
for (StringRef ns : namespaces)
452+
os << "namespace " << ns << " {\n";
453+
454+
os << "namespace detail {\n";
455+
456+
StringRef interfaceName = interface.getName();
457+
auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
451458
os << llvm::formatv(" template <typename {3}>\n"
452459
" struct {0}Trait : public ::mlir::{2}<{0},"
453460
" detail::{1}>::Trait<{3}> {{\n",
@@ -494,6 +501,10 @@ void InterfaceGenerator::emitTraitDecl(const Interface &interface,
494501
os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n";
495502

496503
os << " };\n";
504+
os << "}// namespace detail\n";
505+
506+
for (StringRef ns : llvm::reverse(namespaces))
507+
os << "} // namespace " << ns << "\n";
497508
}
498509

499510
static void emitInterfaceDeclMethods(const Interface &interface,
@@ -517,6 +528,27 @@ static void emitInterfaceDeclMethods(const Interface &interface,
517528
os << tblgen::tgfmt(extraDecls->rtrim(), &extraDeclsFmt) << "\n";
518529
}
519530

531+
void InterfaceGenerator::forwardDeclareInterface(const Interface &interface) {
532+
llvm::SmallVector<StringRef, 2> namespaces;
533+
llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
534+
for (StringRef ns : namespaces)
535+
os << "namespace " << ns << " {\n";
536+
537+
// Emit a forward declaration of the interface class so that it becomes usable
538+
// in the signature of its methods.
539+
std::string comments = tblgen::emitSummaryAndDescComments(
540+
"", interface.getDescription().value_or(""));
541+
if (!comments.empty()) {
542+
os << comments << "\n";
543+
}
544+
545+
StringRef interfaceName = interface.getName();
546+
os << "class " << interfaceName << ";\n";
547+
548+
for (StringRef ns : llvm::reverse(namespaces))
549+
os << "} // namespace " << ns << "\n";
550+
}
551+
520552
void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
521553
llvm::SmallVector<StringRef, 2> namespaces;
522554
llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
@@ -533,7 +565,6 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
533565
if (!comments.empty()) {
534566
os << comments << "\n";
535567
}
536-
os << "class " << interfaceName << ";\n";
537568

538569
// Emit the traits struct containing the concept and model declarations.
539570
os << "namespace detail {\n"
@@ -603,10 +634,6 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
603634

604635
os << "};\n";
605636

606-
os << "namespace detail {\n";
607-
emitTraitDecl(interface, interfaceName, interfaceTraitsName);
608-
os << "}// namespace detail\n";
609-
610637
for (StringRef ns : llvm::reverse(namespaces))
611638
os << "} // namespace " << ns << "\n";
612639
}
@@ -619,10 +646,15 @@ bool InterfaceGenerator::emitInterfaceDecls() {
619646
llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) {
620647
return lhs->getID() < rhs->getID();
621648
});
649+
for (const Record *def : sortedDefs)
650+
forwardDeclareInterface(Interface(def));
622651
for (const Record *def : sortedDefs)
623652
emitInterfaceDecl(Interface(def));
653+
for (const Record *def : sortedDefs)
654+
emitInterfaceTraitDecl(Interface(def));
624655
for (const Record *def : sortedDefs)
625656
emitModelMethodsDef(Interface(def));
657+
626658
return false;
627659
}
628660

0 commit comments

Comments
 (0)