Skip to content

Commit a98cfd3

Browse files
[mlir][TableGen][NFC] Emit interface traits after all interfaces
Interface traits may provide default implementation of methods. When this happens, the implementation may rely on another interface that is not yet defined meaning that one gets "incomplete type" error during C++ compilation. In pseudo-code, the problem is the following: ``` InterfaceA has methodB() { return InterfaceB(); } InterfaceB defined later // What's generated is: class InterfaceA { ... } class InterfaceATrait { // error: InterfaceB is an incomplete type InterfaceB methodB() { return InterfaceB(); } } class InterfaceB { ... } // defined here ``` The two more "advanced" cases are: * Cyclic dependency (A requires B and B requires A) * Type-traited usage of an incomplete type (e.g. `FailureOr<InterfaceB>`) It seems reasonable to emit interface traits *after* all of the interfaces have been defined to avoid the problem altogether. As a drive by, make forward declarations of the interfaces early so that user code does not need to forward declare.
1 parent e0a33cb commit a98cfd3

File tree

2 files changed

+62
-10
lines changed

2 files changed

+62
-10
lines changed

mlir/test/lib/Dialect/Test/TestInterfaces.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,4 +174,32 @@ def TestOptionallyImplementedTypeInterface
174174
}];
175175
}
176176

177+
// Dummy type interface "A" that requires type interface "B" to be complete.
178+
def TestCyclicTypeInterfaceA : TypeInterface<"TestCyclicTypeInterfaceA"> {
179+
let cppNamespace = "::mlir";
180+
let methods = [
181+
InterfaceMethod<"",
182+
"::mlir::FailureOr<::mlir::TestCyclicTypeInterfaceB>",
183+
/*methodName=*/"returnB",
184+
(ins),
185+
/*methodBody=*/"",
186+
/*defaultImpl=*/"return mlir::failure();"
187+
>,
188+
];
189+
}
190+
191+
// Dummy type interface "B" that requires type interface "A" to be complete.
192+
def TestCyclicTypeInterfaceB : TypeInterface<"TestCyclicTypeInterfaceB"> {
193+
let cppNamespace = "::mlir";
194+
let methods = [
195+
InterfaceMethod<"",
196+
"::mlir::FailureOr<::mlir::TestCyclicTypeInterfaceA>",
197+
/*methodName=*/"returnA",
198+
(ins),
199+
/*methodBody=*/"",
200+
/*defaultImpl=*/"return mlir::failure();"
201+
>,
202+
];
203+
}
204+
177205
#endif // MLIR_TEST_DIALECT_TEST_INTERFACES

mlir/tools/mlir-tblgen/OpInterfacesGen.cpp

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@ class InterfaceGenerator {
9696
void emitConceptDecl(const Interface &interface);
9797
void emitModelDecl(const Interface &interface);
9898
void emitModelMethodsDef(const Interface &interface);
99-
void emitTraitDecl(const Interface &interface, StringRef interfaceName,
100-
StringRef interfaceTraitsName);
99+
void forwardDeclareInterface(const Interface &interface);
101100
void emitInterfaceDecl(const Interface &interface);
101+
void emitInterfaceTraitDecl(const Interface &interface);
102102

103103
/// The set of interface records to emit.
104104
std::vector<const Record *> defs;
@@ -445,9 +445,16 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
445445
os << "} // namespace " << ns << "\n";
446446
}
447447

448-
void InterfaceGenerator::emitTraitDecl(const Interface &interface,
449-
StringRef interfaceName,
450-
StringRef interfaceTraitsName) {
448+
void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) {
449+
llvm::SmallVector<StringRef, 2> namespaces;
450+
llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
451+
for (StringRef ns : namespaces)
452+
os << "namespace " << ns << " {\n";
453+
454+
os << "namespace detail {\n";
455+
456+
StringRef interfaceName = interface.getName();
457+
auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
451458
os << llvm::formatv(" template <typename {3}>\n"
452459
" struct {0}Trait : public ::mlir::{2}<{0},"
453460
" detail::{1}>::Trait<{3}> {{\n",
@@ -494,6 +501,10 @@ void InterfaceGenerator::emitTraitDecl(const Interface &interface,
494501
os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n";
495502

496503
os << " };\n";
504+
os << "}// namespace detail\n";
505+
506+
for (StringRef ns : llvm::reverse(namespaces))
507+
os << "} // namespace " << ns << "\n";
497508
}
498509

499510
static void emitInterfaceDeclMethods(const Interface &interface,
@@ -517,6 +528,19 @@ static void emitInterfaceDeclMethods(const Interface &interface,
517528
os << tblgen::tgfmt(extraDecls->rtrim(), &extraDeclsFmt) << "\n";
518529
}
519530

531+
void InterfaceGenerator::forwardDeclareInterface(const Interface &interface) {
532+
llvm::SmallVector<StringRef, 2> namespaces;
533+
llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
534+
for (StringRef ns : namespaces)
535+
os << "namespace " << ns << " {\n";
536+
537+
StringRef interfaceName = interface.getName();
538+
os << "class " << interfaceName << ";\n";
539+
540+
for (StringRef ns : llvm::reverse(namespaces))
541+
os << "} // namespace " << ns << "\n";
542+
}
543+
520544
void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
521545
llvm::SmallVector<StringRef, 2> namespaces;
522546
llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
@@ -533,7 +557,6 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
533557
if (!comments.empty()) {
534558
os << comments << "\n";
535559
}
536-
os << "class " << interfaceName << ";\n";
537560

538561
// Emit the traits struct containing the concept and model declarations.
539562
os << "namespace detail {\n"
@@ -603,10 +626,6 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
603626

604627
os << "};\n";
605628

606-
os << "namespace detail {\n";
607-
emitTraitDecl(interface, interfaceName, interfaceTraitsName);
608-
os << "}// namespace detail\n";
609-
610629
for (StringRef ns : llvm::reverse(namespaces))
611630
os << "} // namespace " << ns << "\n";
612631
}
@@ -619,10 +638,15 @@ bool InterfaceGenerator::emitInterfaceDecls() {
619638
llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) {
620639
return lhs->getID() < rhs->getID();
621640
});
641+
for (const Record *def : sortedDefs)
642+
forwardDeclareInterface(Interface(def));
622643
for (const Record *def : sortedDefs)
623644
emitInterfaceDecl(Interface(def));
624645
for (const Record *def : sortedDefs)
625646
emitModelMethodsDef(Interface(def));
647+
for (const Record *def : sortedDefs)
648+
emitInterfaceTraitDecl(Interface(def));
649+
626650
return false;
627651
}
628652

0 commit comments

Comments
 (0)