Skip to content

Conversation

@joker-eph
Copy link
Collaborator

This allows to define multiple interface methods with the same name but different arguments.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Oct 3, 2025
@llvmbot
Copy link
Member

llvmbot commented Oct 3, 2025

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Mehdi Amini (joker-eph)

Changes

This allows to define multiple interface methods with the same name but different arguments.


Full diff: https://github.com/llvm/llvm-project/pull/161828.diff

5 Files Affected:

  • (modified) mlir/test/lib/Dialect/Test/TestTypes.cpp (+4)
  • (modified) mlir/test/lib/IR/TestInterfaces.cpp (+1)
  • (modified) mlir/test/mlir-tblgen/interfaces.mlir (+1)
  • (modified) mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp (+26-1)
  • (modified) mlir/tools/mlir-tblgen/OpInterfacesGen.cpp (+18-14)
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index bea043f56fe21..9076c7e54d7bf 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -245,6 +245,10 @@ void TestType::printTypeC(Location loc) const {
   emitRemark(loc) << *this << " - TestC";
 }
 
+void TestType::printTypeC(Location loc, int value) const {
+  emitRemark(loc) << *this << " - " << value << " - TestC";
+}
+
 //===----------------------------------------------------------------------===//
 // TestTypeWithLayout
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/IR/TestInterfaces.cpp b/mlir/test/lib/IR/TestInterfaces.cpp
index 2dd3fe245e220..e021f78e1142d 100644
--- a/mlir/test/lib/IR/TestInterfaces.cpp
+++ b/mlir/test/lib/IR/TestInterfaces.cpp
@@ -31,6 +31,7 @@ struct TestTypeInterfaces
           testInterface.printTypeA(op->getLoc());
           testInterface.printTypeB(op->getLoc());
           testInterface.printTypeC(op->getLoc());
+          testInterface.printTypeC(op->getLoc(), 42);
           testInterface.printTypeD(op->getLoc());
           // Just check that we can assign the result to a variable of interface
           // type.
diff --git a/mlir/test/mlir-tblgen/interfaces.mlir b/mlir/test/mlir-tblgen/interfaces.mlir
index 5c1ec613b387a..927cfd728bcd4 100644
--- a/mlir/test/mlir-tblgen/interfaces.mlir
+++ b/mlir/test/mlir-tblgen/interfaces.mlir
@@ -3,6 +3,7 @@
 // expected-remark@below {{'!test.test_type' - TestA}}
 // expected-remark@below {{'!test.test_type' - TestB}}
 // expected-remark@below {{'!test.test_type' - TestC}}
+// expected-remark@below {{'!test.test_type' - 42 - TestC}}
 // expected-remark@below {{'!test.test_type' - TestD}}
 // expected-remark@below {{'!test.test_type' - TestRet}}
 // expected-remark@below {{'!test.test_type' - TestE}}
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 7e8e559baf878..4c6519cd2f7bf 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -789,6 +789,14 @@ class OpEmitter {
   Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
                                bool declaration = true);
 
+  // Generate a `using` declaration for the op interface method to include
+  // the default implementation from the interface trait.
+  // This is needed when the interface defines multiple methods with the same
+  // name, but some have a default implementation and some don't.
+  UsingDeclaration *
+  genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait,
+                                const tblgen::InterfaceMethod &method);
+
   // Generate the side effect interface methods.
   void genSideEffectInterfaceMethods();
 
@@ -815,6 +823,10 @@ class OpEmitter {
 
   // Helper for emitting op code.
   OpOrAdaptorHelper emitHelper;
+
+  // Keep track of the interface using declarations that have been generated to
+  // avoid duplicates.
+  llvm::StringSet<> interfaceUsingNames;
 };
 
 } // namespace
@@ -3672,8 +3684,10 @@ void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
     // Don't declare if the method has a default implementation and the op
     // didn't request that it always be declared.
     if (method.getDefaultImplementation() &&
-        !alwaysDeclaredMethods.count(method.getName()))
+        !alwaysDeclaredMethods.count(method.getName())) {
+      genOpInterfaceMethodUsingDecl(opTrait, method);
       continue;
+    }
     // Interface methods are allowed to overlap with existing methods, so don't
     // check if pruned.
     (void)genOpInterfaceMethod(method);
@@ -3692,6 +3706,17 @@ Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
                            std::move(paramList));
 }
 
+UsingDeclaration *
+OpEmitter::genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait,
+                                         const InterfaceMethod &method) {
+  std::string name = (llvm::Twine(opTrait->getFullyQualifiedTraitName()) + "<" +
+                      op.getQualCppClassName() + ">::" + method.getName())
+                         .str();
+  if (interfaceUsingNames.insert(name).second)
+    return opClass.declare<UsingDeclaration>(std::move(name));
+  return nullptr;
+}
+
 void OpEmitter::genOpInterfaceMethods() {
   for (const auto &trait : op.getTraits()) {
     if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 3cc1636ac3317..9dedd55005f87 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -42,10 +42,10 @@ static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
 /// Emit the method name and argument list for the given method. If 'addThisArg'
 /// is true, then an argument is added to the beginning of the argument list for
 /// the concrete value.
-static void emitMethodNameAndArgs(const InterfaceMethod &method,
+static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name,
                                   raw_ostream &os, StringRef valueType,
                                   bool addThisArg, bool addConst) {
-  os << method.getName() << '(';
+  os << name << '(';
   if (addThisArg) {
     if (addConst)
       os << "const ";
@@ -183,11 +183,13 @@ static void emitInterfaceDefMethods(StringRef interfaceQualName,
     emitInterfaceMethodDoc(method, os);
     emitCPPType(method.getReturnType(), os);
     os << interfaceQualName << "::";
-    emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
+    emitMethodNameAndArgs(method, method.getName(), os, valueType,
+                          /*addThisArg=*/false,
                           /*addConst=*/!isOpInterface);
 
     // Forward to the method on the concrete operation type.
-    os << " {\n      return " << implValue << "->" << method.getName() << '(';
+    os << " {\n      return " << implValue << "->" << method.getDedupName()
+       << '(';
     if (!method.isStatic()) {
       os << implValue << ", ";
       os << (isOpInterface ? "getOperation()" : "*this");
@@ -239,7 +241,7 @@ void InterfaceGenerator::emitConceptDecl(const Interface &interface) {
   for (auto &method : interface.getMethods()) {
     os << "    ";
     emitCPPType(method.getReturnType(), os);
-    os << "(*" << method.getName() << ")(";
+    os << "(*" << method.getDedupName() << ")(";
     if (!method.isStatic()) {
       os << "const Concept *impl, ";
       emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", ");
@@ -289,13 +291,13 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) {
     os << "    " << modelClass << "() : Concept{";
     llvm::interleaveComma(
         interface.getMethods(), os,
-        [&](const InterfaceMethod &method) { os << method.getName(); });
+        [&](const InterfaceMethod &method) { os << method.getDedupName(); });
     os << "} {}\n\n";
 
     // Insert each of the virtual method overrides.
     for (auto &method : interface.getMethods()) {
       emitCPPType(method.getReturnType(), os << "    static inline ");
-      emitMethodNameAndArgs(method, os, valueType,
+      emitMethodNameAndArgs(method, method.getDedupName(), os, valueType,
                             /*addThisArg=*/!method.isStatic(),
                             /*addConst=*/false);
       os << ";\n";
@@ -319,7 +321,7 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) {
     if (method.isStatic())
       os << "static ";
     emitCPPType(method.getReturnType(), os);
-    os << method.getName() << "(";
+    os << method.getDedupName() << "(";
     if (!method.isStatic()) {
       emitCPPType(valueType, os);
       os << "tablegen_opaque_val";
@@ -350,7 +352,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
     emitCPPType(method.getReturnType(), os);
     os << "detail::" << interface.getName() << "InterfaceTraits::Model<"
        << valueTemplate << ">::";
-    emitMethodNameAndArgs(method, os, valueType,
+    emitMethodNameAndArgs(method, method.getDedupName(), os, valueType,
                           /*addThisArg=*/!method.isStatic(),
                           /*addConst=*/false);
     os << " {\n  ";
@@ -384,7 +386,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
     emitCPPType(method.getReturnType(), os);
     os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<"
        << valueTemplate << ">::";
-    emitMethodNameAndArgs(method, os, valueType,
+    emitMethodNameAndArgs(method, method.getDedupName(), os, valueType,
                           /*addThisArg=*/!method.isStatic(),
                           /*addConst=*/false);
     os << " {\n  ";
@@ -396,7 +398,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
       os << "return static_cast<const " << valueTemplate << " *>(impl)->";
 
     // Add the arguments to the call.
-    os << method.getName() << '(';
+    os << method.getDedupName() << '(';
     if (!method.isStatic())
       os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
     llvm::interleaveComma(
@@ -416,7 +418,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
        << "InterfaceTraits::ExternalModel<ConcreteModel, " << valueTemplate
        << ">::";
 
-    os << method.getName() << "(";
+    os << method.getDedupName() << "(";
     if (!method.isStatic()) {
       emitCPPType(valueType, os);
       os << "tablegen_opaque_val";
@@ -477,7 +479,8 @@ void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) {
     emitInterfaceMethodDoc(method, os, "    ");
     os << "    " << (method.isStatic() ? "static " : "");
     emitCPPType(method.getReturnType(), os);
-    emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
+    emitMethodNameAndArgs(method, method.getName(), os, valueType,
+                          /*addThisArg=*/false,
                           /*addConst=*/!isOpInterface && !method.isStatic());
     os << " {\n      " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt)
        << "\n    }\n";
@@ -514,7 +517,8 @@ static void emitInterfaceDeclMethods(const Interface &interface,
   for (auto &method : interface.getMethods()) {
     emitInterfaceMethodDoc(method, os, "  ");
     emitCPPType(method.getReturnType(), os << "  ");
-    emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
+    emitMethodNameAndArgs(method, method.getName(), os, valueType,
+                          /*addThisArg=*/false,
                           /*addConst=*/!isOpInterface);
     os << ";\n";
   }

@joker-eph joker-eph force-pushed the interface_overloading branch 4 times, most recently from 9ddecbc to d603c93 Compare October 3, 2025 14:48
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Oct 3, 2025
std::string name =
cast<DefInit>(init)->getDef()->getValueAsString("name").str();
while (!dedupNames.insert(name).second) {
name = name + "_" + std::to_string(dedupNames.size());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How come a single counter is sufficient here? I expected name mangling scheme that makes the argument types part of the function name. When the interface method is called, how do you know which dedup method to call?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the interface method is called, how do you know which dedup method to call?

The actual deduce method name could be a randomly generated name, as long as the mapping is consistent.
So we can initialize InterfaceMethod with any dedup name as long as it is uniqued.
This is all just internal to the "vtables" we generate. The public interface/trait will expose the overloaded methods with their public name, and try to call the method on the Op with the public name as well.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added some comments about this, PTAL.

This allows to define multiple interface methods with the same name
but different arguments.
@joker-eph joker-eph force-pushed the interface_overloading branch from d603c93 to c5f6b0f Compare October 6, 2025 15:03
Copy link
Member

@matthias-springer matthias-springer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like a useful feature. I had to work around this limitation in the past: there is BufferizableOpInterface::bufferizesToMemoryWrite and BufferizableOpInterface::resultBufferizesToMemoryWrite.

@joker-eph joker-eph merged commit 842622b into llvm:main Oct 6, 2025
9 checks passed
@joker-eph joker-eph deleted the interface_overloading branch October 6, 2025 19:21
@qedawkins
Copy link
Contributor

This change is causing stale .h.inc build failures when adding an interface method with default impl, rebuilding, then removing the interface method. For example, if I apply

diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.td
index ce14b80d83..410043752b 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.td
@@ -93,6 +93,7 @@ def IREECodegen_UKernelGenericOp :
     //     the `strided_dims` attribute is null.
     //   - Returns the corresponding dim list in `strided_dims` for ShapedType
     //     operands if `strided_dims` is not null.
+    //
     SmallVector<int64_t> getOperandStridedDims(int64_t operandIdx);
   }];
 }
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/UKernelOpInterface.td b/compiler/src/iree/compiler/Codegen/Interfaces/UKernelOpInterface.td
index f5667c996e..eb4d218487 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/UKernelOpInterface.td
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/UKernelOpInterface.td
@@ -24,7 +24,17 @@ def UKernelOpInterface : OpInterface<"UKernelOpInterface"> {
       /*methodName=*/"lowerToFunctionCall",
       /*args=*/(ins "RewriterBase &":$rewriter),
       /*methodBody=*/"",
       /*defautImplementation=*/"return failure();"
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Copy of the above.
+      }],
+      /*retType=*/"FailureOr<mlir::CallOpInterface>",
+      /*methodName=*/"lowerToFunctionCall2",
+      /*args=*/(ins "RewriterBase &":$rewriter),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/"return failure();"
     >,
   ];
 }

downstream in IREE, rebuild, then remove the new interface method and rebuild again I get something like

/home/quinn/iree-build/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h.inc:219:92: error: no member named 'lowerToFunctionCall2' in 'mlir::iree_compiler::IREE::Codegen::UKernelOpInterface::Trait<mlir::iree_compiler::IREE::Codegen::UKernelGenericOp>';

due to the using lines added here https://github.com/llvm/llvm-project/pull/161828/files#diff-9185f7f3f3628467e62e1712b01efc482eecc3643b67746c1eae73f5fd272138R3688.

I was able to repro this with TilingInterface + Linalg as well. The problem is that when a new interface method is added to the interface but NOT added to Declare*InterfaceMethods, these using lines are generated only when the interface implementer's tablegen changes because deps between .h.inc files appear to be order only.

@joker-eph
Copy link
Collaborator Author

Thanks for the report!

The problem is that when a new interface method is added to the interface but NOT added to Declare*InterfaceMethods, these using lines are generated only when the interface implementer's tablegen changes because deps between .h.inc files appear to be order only.

I don't quite follow this explanation right now, I would think that the operation file where Declare*InterfaceMethods is used depends on including the interface file and so anything processing the operation file should be rebuilt when any TableGen-included file is changed?
What do you mean by "order only" here?

@qedawkins
Copy link
Contributor

qedawkins commented Nov 5, 2025

Sorry about the ping, it look like this was an IREE specific issue with our logic for generating CMake files dropping .td deps. No problems here, you can ignore me.

This just happened to be the first time a .td file updating ever caused a problem like this. It was probably broken on our end for ages!

@joker-eph
Copy link
Collaborator Author

Sad, I was hopeful you'd have found a tricky dependency issues. I actually saw a windows bot running ninja having incorrect incremental build with a tablegen change recently (another PR), and I can't figure out what's wrong with our dependencies :(

@qedawkins
Copy link
Contributor

This was apparently the problem: https://github.com/iree-org/iree/pull/22554/files#diff-1e7de1ae2d059d21e1dd75d5812d5a34b0222cef273b7c3a2af62eb747f9d20aL9-L16

  ★ Insight ─────────────────────────────────────
  1. Root cause identified: IREE sets CMAKE_POLICY_DEFAULT_CMP0116 OLD, disabling LLVM's automatic dep tracking
  2. The comment is misleading: It says "LLVM requires CMP0116 for tblgen" but then sets it to OLD
  3. Fallback behavior: With CMP0116=OLD, tablegen() globs *.td files in include dirs instead of using depfiles
  ─────────────────────────────────────────────────

It looks like this was changed project wide earlier this year: #90385 + 2f08927

And I see a guard for WIN32 in that change too: 2f08927#diff-af6c3a2cba6f2d0e48d7e6f6c6ab7084527c8083d5c3ba64f02babec58e0a7eaR34

so maybe it is still a problem on windows 🤔 . I don't have a setup to test it though.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

flang:fir-hlfir flang Flang issues not falling into any other category mlir:core MLIR Core Infrastructure mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants