Skip to content

Commit 842622b

Browse files
authored
[MLIR][ODS] Add support for overloading interface methods (#161828)
This allows to define multiple interface methods with the same name but different arguments.
1 parent 199811d commit 842622b

File tree

10 files changed

+114
-21
lines changed

10 files changed

+114
-21
lines changed

flang/include/flang/Optimizer/HLFIR/HLFIROps.td

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,9 @@ def hlfir_DesignateOp : hlfir_Op<"designate", [AttrSizedOperandSegments,
269269
using Triplet = std::tuple<mlir::Value, mlir::Value, mlir::Value>;
270270
using Subscript = std::variant<mlir::Value, Triplet>;
271271
using Subscripts = llvm::SmallVector<Subscript, 8>;
272+
void setFortranAttrs(fir::FortranVariableFlagsEnum flags) {
273+
this->setFortranAttrs(std::optional<fir::FortranVariableFlagsEnum>(flags));
274+
}
272275
}];
273276

274277
let builders = [
@@ -319,7 +322,7 @@ def hlfir_ParentComponentOp : hlfir_Op<"parent_comp", [AttrSizedOperandSegments,
319322
// Implement FortranVariableInterface interface. Parent components have
320323
// no attributes (pointer, allocatable or contiguous can only be added
321324
// to regular components).
322-
std::optional<fir::FortranVariableFlagsEnum> getFortranAttrs() const {
325+
std::optional<fir::FortranVariableFlagsEnum> getFortranAttrs() {
323326
return std::nullopt;
324327
}
325328
void setFortranAttrs(fir::FortranVariableFlagsEnum flags) {}
@@ -882,6 +885,10 @@ def hlfir_AssociateOp : hlfir_Op<"associate", [AttrSizedOperandSegments,
882885
CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attributes)>];
883886

884887
let extraClassDeclaration = [{
888+
void setFortranAttrs(fir::FortranVariableFlagsEnum flags) {
889+
this->setFortranAttrs(std::optional<fir::FortranVariableFlagsEnum>(flags));
890+
}
891+
885892
/// Override FortranVariableInterface default implementation
886893
mlir::Value getBase() {
887894
return getResult(0);

mlir/include/mlir/TableGen/Interfaces.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,17 @@ class InterfaceMethod {
3232
StringRef name;
3333
};
3434

35-
explicit InterfaceMethod(const llvm::Record *def);
35+
explicit InterfaceMethod(const llvm::Record *def, std::string uniqueName);
3636

3737
// Return the return type of this method.
3838
StringRef getReturnType() const;
3939

4040
// Return the name of this method.
4141
StringRef getName() const;
4242

43+
// Return the dedup name of this method.
44+
StringRef getUniqueName() const;
45+
4346
// Return if this method is static.
4447
bool isStatic() const;
4548

@@ -62,6 +65,10 @@ class InterfaceMethod {
6265

6366
// The arguments of this method.
6467
SmallVector<Argument, 2> arguments;
68+
69+
// The unique name of this method, to distinguish it from other methods with
70+
// the same name (overloaded methods)
71+
std::string uniqueName;
6572
};
6673

6774
//===----------------------------------------------------------------------===//

mlir/lib/TableGen/Interfaces.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ using llvm::StringInit;
2525
// InterfaceMethod
2626
//===----------------------------------------------------------------------===//
2727

28-
InterfaceMethod::InterfaceMethod(const Record *def) : def(def) {
28+
InterfaceMethod::InterfaceMethod(const Record *def, std::string uniqueName)
29+
: def(def), uniqueName(uniqueName) {
2930
const DagInit *args = def->getValueAsDag("arguments");
3031
for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
3132
arguments.push_back({cast<StringInit>(args->getArg(i))->getValue(),
@@ -42,6 +43,9 @@ StringRef InterfaceMethod::getName() const {
4243
return def->getValueAsString("name");
4344
}
4445

46+
// Return the name of this method.
47+
StringRef InterfaceMethod::getUniqueName() const { return uniqueName; }
48+
4549
// Return if this method is static.
4650
bool InterfaceMethod::isStatic() const {
4751
return def->isSubClassOf("StaticInterfaceMethod");
@@ -83,8 +87,19 @@ Interface::Interface(const Record *def) : def(def) {
8387

8488
// Initialize the interface methods.
8589
auto *listInit = dyn_cast<ListInit>(def->getValueInit("methods"));
86-
for (const Init *init : listInit->getElements())
87-
methods.emplace_back(cast<DefInit>(init)->getDef());
90+
// In case of overloaded methods, we need to find a unique name for each for
91+
// the internal function pointer in the "vtable" we generate. This is an
92+
// internal name, we could use a randomly generated name as long as there are
93+
// no collisions.
94+
StringSet<> uniqueNames;
95+
for (const Init *init : listInit->getElements()) {
96+
std::string name =
97+
cast<DefInit>(init)->getDef()->getValueAsString("name").str();
98+
while (!uniqueNames.insert(name).second) {
99+
name = name + "_" + std::to_string(uniqueNames.size());
100+
}
101+
methods.emplace_back(cast<DefInit>(init)->getDef(), name);
102+
}
88103

89104
// Initialize the interface base classes.
90105
auto *basesInit = dyn_cast<ListInit>(def->getValueInit("baseInterfaces"));

mlir/test/lib/Dialect/Test/TestInterfaces.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,16 @@ def TestTypeInterface
4444
InterfaceMethod<"Prints the type name.",
4545
"void", "printTypeC", (ins "::mlir::Location":$loc)
4646
>,
47+
// Check that we can have multiple method with the same name.
48+
InterfaceMethod<"Prints the type name, with a value prefixed.",
49+
"void", "printTypeC", (ins "::mlir::Location":$loc, "int":$value)
50+
>,
51+
InterfaceMethod<"Prints the type name, with a value prefixed.",
52+
"void", "printTypeC", (ins "::mlir::Location":$loc, "float":$value),
53+
[{}], /*defaultImplementation=*/[{
54+
emitRemark(loc) << $_type << " - " << value << " - Float TestC";
55+
}]
56+
>,
4757
// It should be possible to use the interface type name as result type
4858
// as well as in the implementation.
4959
InterfaceMethod<"Prints the type name and returns the type as interface.",

mlir/test/lib/Dialect/Test/TestTypes.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,10 @@ void TestType::printTypeC(Location loc) const {
245245
emitRemark(loc) << *this << " - TestC";
246246
}
247247

248+
void TestType::printTypeC(Location loc, int value) const {
249+
emitRemark(loc) << *this << " - " << value << " - Int TestC";
250+
}
251+
248252
//===----------------------------------------------------------------------===//
249253
// TestTypeWithLayout
250254
//===----------------------------------------------------------------------===//

mlir/test/lib/IR/TestInterfaces.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ struct TestTypeInterfaces
3131
testInterface.printTypeA(op->getLoc());
3232
testInterface.printTypeB(op->getLoc());
3333
testInterface.printTypeC(op->getLoc());
34+
testInterface.printTypeC(op->getLoc(), 42);
35+
testInterface.printTypeC(op->getLoc(), 3.14f);
3436
testInterface.printTypeD(op->getLoc());
3537
// Just check that we can assign the result to a variable of interface
3638
// type.

mlir/test/mlir-tblgen/interfaces.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
// expected-remark@below {{'!test.test_type' - TestA}}
44
// expected-remark@below {{'!test.test_type' - TestB}}
55
// expected-remark@below {{'!test.test_type' - TestC}}
6+
// expected-remark@below {{'!test.test_type' - 42 - Int TestC}}
7+
// expected-remark@below {{'!test.test_type' - 3.140000e+00 - Float TestC}}
68
// expected-remark@below {{'!test.test_type' - TestD}}
79
// expected-remark@below {{'!test.test_type' - TestRet}}
810
// expected-remark@below {{'!test.test_type' - TestE}}

mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ class DefGen {
130130
void emitTraitMethods(const InterfaceTrait &trait);
131131
/// Emit a trait method.
132132
void emitTraitMethod(const InterfaceMethod &method);
133+
/// Generate a using declaration for a trait method.
134+
void genTraitMethodUsingDecl(const InterfaceTrait &trait,
135+
const InterfaceMethod &method);
133136

134137
//===--------------------------------------------------------------------===//
135138
// OpAsm{Type,Attr}Interface Default Method Emission
@@ -176,6 +179,9 @@ class DefGen {
176179
StringRef valueType;
177180
/// The prefix/suffix of the TableGen def name, either "Attr" or "Type".
178181
StringRef defType;
182+
183+
/// The set of using declarations for trait methods.
184+
llvm::StringSet<> interfaceUsingNames;
179185
};
180186
} // namespace
181187

@@ -632,8 +638,10 @@ void DefGen::emitTraitMethods(const InterfaceTrait &trait) {
632638
// Don't declare if the method has a body. Or if the method has a default
633639
// implementation and the def didn't request that it always be declared.
634640
if (method.getBody() || (method.getDefaultImplementation() &&
635-
!alwaysDeclared.count(method.getName())))
641+
!alwaysDeclared.count(method.getName()))) {
642+
genTraitMethodUsingDecl(trait, method);
636643
continue;
644+
}
637645
emitTraitMethod(method);
638646
}
639647
}
@@ -649,6 +657,15 @@ void DefGen::emitTraitMethod(const InterfaceMethod &method) {
649657
std::move(params));
650658
}
651659

660+
void DefGen::genTraitMethodUsingDecl(const InterfaceTrait &trait,
661+
const InterfaceMethod &method) {
662+
std::string name = (llvm::Twine(trait.getFullyQualifiedTraitName()) + "<" +
663+
def.getCppClassName() + ">::" + method.getName())
664+
.str();
665+
if (interfaceUsingNames.insert(name).second)
666+
defCls.declare<UsingDeclaration>(std::move(name));
667+
}
668+
652669
//===----------------------------------------------------------------------===//
653670
// OpAsm{Type,Attr}Interface Default Method Emission
654671

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,14 @@ class OpEmitter {
790790
Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
791791
bool declaration = true);
792792

793+
// Generate a `using` declaration for the op interface method to include
794+
// the default implementation from the interface trait.
795+
// This is needed when the interface defines multiple methods with the same
796+
// name, but some have a default implementation and some don't.
797+
UsingDeclaration *
798+
genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait,
799+
const tblgen::InterfaceMethod &method);
800+
793801
// Generate the side effect interface methods.
794802
void genSideEffectInterfaceMethods();
795803

@@ -816,6 +824,10 @@ class OpEmitter {
816824

817825
// Helper for emitting op code.
818826
OpOrAdaptorHelper emitHelper;
827+
828+
// Keep track of the interface using declarations that have been generated to
829+
// avoid duplicates.
830+
llvm::StringSet<> interfaceUsingNames;
819831
};
820832

821833
} // namespace
@@ -3673,8 +3685,10 @@ void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
36733685
// Don't declare if the method has a default implementation and the op
36743686
// didn't request that it always be declared.
36753687
if (method.getDefaultImplementation() &&
3676-
!alwaysDeclaredMethods.count(method.getName()))
3688+
!alwaysDeclaredMethods.count(method.getName())) {
3689+
genOpInterfaceMethodUsingDecl(opTrait, method);
36773690
continue;
3691+
}
36783692
// Interface methods are allowed to overlap with existing methods, so don't
36793693
// check if pruned.
36803694
(void)genOpInterfaceMethod(method);
@@ -3693,6 +3707,17 @@ Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
36933707
std::move(paramList));
36943708
}
36953709

3710+
UsingDeclaration *
3711+
OpEmitter::genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait,
3712+
const InterfaceMethod &method) {
3713+
std::string name = (llvm::Twine(opTrait->getFullyQualifiedTraitName()) + "<" +
3714+
op.getCppClassName() + ">::" + method.getName())
3715+
.str();
3716+
if (interfaceUsingNames.insert(name).second)
3717+
return opClass.declare<UsingDeclaration>(std::move(name));
3718+
return nullptr;
3719+
}
3720+
36963721
void OpEmitter::genOpInterfaceMethods() {
36973722
for (const auto &trait : op.getTraits()) {
36983723
if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))

mlir/tools/mlir-tblgen/OpInterfacesGen.cpp

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
4242
/// Emit the method name and argument list for the given method. If 'addThisArg'
4343
/// is true, then an argument is added to the beginning of the argument list for
4444
/// the concrete value.
45-
static void emitMethodNameAndArgs(const InterfaceMethod &method,
45+
static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name,
4646
raw_ostream &os, StringRef valueType,
4747
bool addThisArg, bool addConst) {
48-
os << method.getName() << '(';
48+
os << name << '(';
4949
if (addThisArg) {
5050
if (addConst)
5151
os << "const ";
@@ -183,11 +183,13 @@ static void emitInterfaceDefMethods(StringRef interfaceQualName,
183183
emitInterfaceMethodDoc(method, os);
184184
emitCPPType(method.getReturnType(), os);
185185
os << interfaceQualName << "::";
186-
emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
186+
emitMethodNameAndArgs(method, method.getName(), os, valueType,
187+
/*addThisArg=*/false,
187188
/*addConst=*/!isOpInterface);
188189

189190
// Forward to the method on the concrete operation type.
190-
os << " {\n return " << implValue << "->" << method.getName() << '(';
191+
os << " {\n return " << implValue << "->" << method.getUniqueName()
192+
<< '(';
191193
if (!method.isStatic()) {
192194
os << implValue << ", ";
193195
os << (isOpInterface ? "getOperation()" : "*this");
@@ -239,7 +241,7 @@ void InterfaceGenerator::emitConceptDecl(const Interface &interface) {
239241
for (auto &method : interface.getMethods()) {
240242
os << " ";
241243
emitCPPType(method.getReturnType(), os);
242-
os << "(*" << method.getName() << ")(";
244+
os << "(*" << method.getUniqueName() << ")(";
243245
if (!method.isStatic()) {
244246
os << "const Concept *impl, ";
245247
emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", ");
@@ -289,13 +291,13 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) {
289291
os << " " << modelClass << "() : Concept{";
290292
llvm::interleaveComma(
291293
interface.getMethods(), os,
292-
[&](const InterfaceMethod &method) { os << method.getName(); });
294+
[&](const InterfaceMethod &method) { os << method.getUniqueName(); });
293295
os << "} {}\n\n";
294296

295297
// Insert each of the virtual method overrides.
296298
for (auto &method : interface.getMethods()) {
297299
emitCPPType(method.getReturnType(), os << " static inline ");
298-
emitMethodNameAndArgs(method, os, valueType,
300+
emitMethodNameAndArgs(method, method.getUniqueName(), os, valueType,
299301
/*addThisArg=*/!method.isStatic(),
300302
/*addConst=*/false);
301303
os << ";\n";
@@ -319,7 +321,7 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) {
319321
if (method.isStatic())
320322
os << "static ";
321323
emitCPPType(method.getReturnType(), os);
322-
os << method.getName() << "(";
324+
os << method.getUniqueName() << "(";
323325
if (!method.isStatic()) {
324326
emitCPPType(valueType, os);
325327
os << "tablegen_opaque_val";
@@ -350,7 +352,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
350352
emitCPPType(method.getReturnType(), os);
351353
os << "detail::" << interface.getName() << "InterfaceTraits::Model<"
352354
<< valueTemplate << ">::";
353-
emitMethodNameAndArgs(method, os, valueType,
355+
emitMethodNameAndArgs(method, method.getUniqueName(), os, valueType,
354356
/*addThisArg=*/!method.isStatic(),
355357
/*addConst=*/false);
356358
os << " {\n ";
@@ -384,7 +386,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
384386
emitCPPType(method.getReturnType(), os);
385387
os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<"
386388
<< valueTemplate << ">::";
387-
emitMethodNameAndArgs(method, os, valueType,
389+
emitMethodNameAndArgs(method, method.getUniqueName(), os, valueType,
388390
/*addThisArg=*/!method.isStatic(),
389391
/*addConst=*/false);
390392
os << " {\n ";
@@ -396,7 +398,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
396398
os << "return static_cast<const " << valueTemplate << " *>(impl)->";
397399

398400
// Add the arguments to the call.
399-
os << method.getName() << '(';
401+
os << method.getUniqueName() << '(';
400402
if (!method.isStatic())
401403
os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
402404
llvm::interleaveComma(
@@ -416,7 +418,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
416418
<< "InterfaceTraits::ExternalModel<ConcreteModel, " << valueTemplate
417419
<< ">::";
418420

419-
os << method.getName() << "(";
421+
os << method.getUniqueName() << "(";
420422
if (!method.isStatic()) {
421423
emitCPPType(valueType, os);
422424
os << "tablegen_opaque_val";
@@ -477,7 +479,8 @@ void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) {
477479
emitInterfaceMethodDoc(method, os, " ");
478480
os << " " << (method.isStatic() ? "static " : "");
479481
emitCPPType(method.getReturnType(), os);
480-
emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
482+
emitMethodNameAndArgs(method, method.getName(), os, valueType,
483+
/*addThisArg=*/false,
481484
/*addConst=*/!isOpInterface && !method.isStatic());
482485
os << " {\n " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt)
483486
<< "\n }\n";
@@ -514,7 +517,8 @@ static void emitInterfaceDeclMethods(const Interface &interface,
514517
for (auto &method : interface.getMethods()) {
515518
emitInterfaceMethodDoc(method, os, " ");
516519
emitCPPType(method.getReturnType(), os << " ");
517-
emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
520+
emitMethodNameAndArgs(method, method.getName(), os, valueType,
521+
/*addThisArg=*/false,
518522
/*addConst=*/!isOpInterface);
519523
os << ";\n";
520524
}

0 commit comments

Comments
 (0)