Skip to content

Commit 9621722

Browse files
ftynsekiranchandramohan
authored andcommitted
[mlir] Enable delayed registration of attribute/operation/type interfaces
This functionality is similar to delayed registration of dialect interfaces. It allows external interface models to be registered before the dialect containing the attribute/operation/type interface is loaded, or even before the context is created. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D104397
1 parent ecd06db commit 9621722

File tree

7 files changed

+273
-19
lines changed

7 files changed

+273
-19
lines changed

mlir/include/mlir/IR/AttributeSupport.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ class AbstractAttribute {
5050
return interfaceMap.lookup<T>();
5151
}
5252

53+
/// Returns true if the attribute has the interface with the given ID
54+
/// registered.
55+
bool hasInterface(TypeID interfaceID) const {
56+
return interfaceMap.contains(interfaceID);
57+
}
58+
5359
/// Return the unique identifier representing the concrete attribute class.
5460
TypeID getTypeID() const { return typeID; }
5561

mlir/include/mlir/IR/Dialect.h

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ class Type;
2727

2828
using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
2929
using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>;
30-
using InterfaceAllocatorFunction =
30+
using DialectInterfaceAllocatorFunction =
3131
std::function<std::unique_ptr<DialectInterface>(Dialect *)>;
32+
using ObjectInterfaceAllocatorFunction = std::function<void(MLIRContext *)>;
3233

3334
/// Dialects are groups of MLIR operations, types and attributes, as well as
3435
/// behavior associated with the entire group. For example, hooks into other
@@ -271,11 +272,19 @@ class Dialect {
271272
/// dialects loaded in the Context. The parser in particular will lazily load
272273
/// dialects in the Context as operations are encountered.
273274
class DialectRegistry {
275+
/// Lists of interfaces that need to be registered when the dialect is loaded.
276+
struct DelayedInterfaces {
277+
/// Dialect interfaces.
278+
SmallVector<std::pair<TypeID, DialectInterfaceAllocatorFunction>, 2>
279+
dialectInterfaces;
280+
/// Attribute/Operation/Type interfaces.
281+
SmallVector<std::pair<TypeID, ObjectInterfaceAllocatorFunction>, 2>
282+
objectInterfaces;
283+
};
284+
274285
using MapTy =
275286
std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
276-
using InterfaceMapTy =
277-
DenseMap<TypeID,
278-
SmallVector<std::pair<TypeID, InterfaceAllocatorFunction>, 2>>;
287+
using InterfaceMapTy = DenseMap<TypeID, DelayedInterfaces>;
279288

280289
public:
281290
explicit DialectRegistry() {}
@@ -329,7 +338,7 @@ class DialectRegistry {
329338
/// the registry.
330339
template <typename DialectTy>
331340
void addDialectInterface(TypeID interfaceTypeID,
332-
InterfaceAllocatorFunction allocator) {
341+
DialectInterfaceAllocatorFunction allocator) {
333342
addDialectInterface(DialectTy::getDialectNamespace(), interfaceTypeID,
334343
allocator);
335344
}
@@ -344,6 +353,36 @@ class DialectRegistry {
344353
});
345354
}
346355

356+
/// Add an external op interface model for an op that belongs to a dialect,
357+
/// both provided as template parameters. The dialect must be present in the
358+
/// registry.
359+
template <typename OpTy, typename ModelTy>
360+
void addOpInterface() {
361+
StringRef opName = OpTy::getOperationName();
362+
StringRef dialectName = opName.split('.').first;
363+
addObjectInterface(dialectName == opName ? "" : dialectName,
364+
ModelTy::Interface::getInterfaceID(),
365+
[](MLIRContext *context) {
366+
OpTy::template attachInterface<ModelTy>(*context);
367+
});
368+
}
369+
370+
/// Add an external attribute interface model for an attribute type `AttrTy`
371+
/// that is going to belong to `DialectTy`. The dialect must be present in the
372+
/// registry.
373+
template <typename DialectTy, typename AttrTy, typename ModelTy>
374+
void addAttrInterface() {
375+
addStorageUserInterface<AttrTy, ModelTy>(DialectTy::getDialectNamespace());
376+
}
377+
378+
/// Add an external type interface model for an type class `TypeTy` that is
379+
/// going to belong to `DialectTy`. The dialect must be present in the
380+
/// registry.
381+
template <typename DialectTy, typename TypeTy, typename ModelTy>
382+
void addTypeInterface() {
383+
addStorageUserInterface<TypeTy, ModelTy>(DialectTy::getDialectNamespace());
384+
}
385+
347386
/// Register any interfaces required for the given dialect (based on its
348387
/// TypeID). Users are not expected to call this directly.
349388
void registerDelayedInterfaces(Dialect *dialect) const;
@@ -352,7 +391,22 @@ class DialectRegistry {
352391
/// Add an interface constructed with the given allocation function to the
353392
/// dialect identified by its namespace.
354393
void addDialectInterface(StringRef dialectName, TypeID interfaceTypeID,
355-
InterfaceAllocatorFunction allocator);
394+
DialectInterfaceAllocatorFunction allocator);
395+
396+
/// Add an attribute/operation/type interface constructible with the given
397+
/// allocation function to the dialect identified by its namespace.
398+
void addObjectInterface(StringRef dialectName, TypeID interfaceTypeID,
399+
ObjectInterfaceAllocatorFunction allocator);
400+
401+
/// Add an external model for an attribute/type interface to the dialect
402+
/// identified by its namespace.
403+
template <typename ObjectTy, typename ModelTy>
404+
void addStorageUserInterface(StringRef dialectName) {
405+
addObjectInterface(dialectName, ModelTy::Interface::getInterfaceID(),
406+
[](MLIRContext *context) {
407+
ObjectTy::template attachInterface<ModelTy>(*context);
408+
});
409+
}
356410

357411
MapTy registry;
358412
InterfaceMapTy interfaces;

mlir/include/mlir/IR/TypeSupport.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ class AbstractType {
5858
return interfaceMap.lookup<T>();
5959
}
6060

61+
/// Returns true if the type has the interface with the given ID.
62+
bool hasInterface(TypeID interfaceID) const {
63+
return interfaceMap.contains(interfaceID);
64+
}
65+
6166
/// Return the unique identifier representing the concrete type class.
6267
TypeID getTypeID() const { return typeID; }
6368

mlir/include/mlir/Support/InterfaceSupport.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "mlir/Support/TypeID.h"
1717
#include "llvm/ADT/DenseMap.h"
18+
#include "llvm/Support/Debug.h"
1819
#include "llvm/Support/TypeName.h"
1920

2021
namespace mlir {
@@ -235,8 +236,10 @@ class InterfaceMap {
235236
llvm::lower_bound(interfaces, id, [](const auto &it, TypeID id) {
236237
return compare(it.first, id);
237238
});
238-
if (it != interfaces.end() && it->first == id)
239-
llvm::report_fatal_error("Interface already registered");
239+
if (it != interfaces.end() && it->first == id) {
240+
LLVM_DEBUG(llvm::dbgs() << "Ignoring repeated interface registration");
241+
continue;
242+
}
240243
interfaces.insert(it, element);
241244
}
242245
}

mlir/lib/IR/Dialect.cpp

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/IR/Dialect.h"
10+
#include "mlir/IR/BuiltinDialect.h"
1011
#include "mlir/IR/Diagnostics.h"
1112
#include "mlir/IR/DialectImplementation.h"
1213
#include "mlir/IR/DialectInterface.h"
@@ -31,7 +32,7 @@ DialectAsmParser::~DialectAsmParser() {}
3132

3233
void DialectRegistry::addDialectInterface(
3334
StringRef dialectName, TypeID interfaceTypeID,
34-
InterfaceAllocatorFunction allocator) {
35+
DialectInterfaceAllocatorFunction allocator) {
3536
assert(allocator && "unexpected null interface allocation function");
3637
auto it = registry.find(dialectName.str());
3738
assert(it != registry.end() &&
@@ -40,8 +41,8 @@ void DialectRegistry::addDialectInterface(
4041
// Bail out if the interface with the given ID is already in the registry for
4142
// the given dialect. We expect a small number (dozens) of interfaces so a
4243
// linear search is fine here.
43-
auto &dialectInterfaces = interfaces[it->second.first];
44-
for (const auto &kvp : dialectInterfaces) {
44+
auto &ifaces = interfaces[it->second.first];
45+
for (const auto &kvp : ifaces.dialectInterfaces) {
4546
if (kvp.first == interfaceTypeID) {
4647
LLVM_DEBUG(llvm::dbgs()
4748
<< "[" DEBUG_TYPE
@@ -51,7 +52,36 @@ void DialectRegistry::addDialectInterface(
5152
}
5253
}
5354

54-
dialectInterfaces.emplace_back(interfaceTypeID, allocator);
55+
ifaces.dialectInterfaces.emplace_back(interfaceTypeID, allocator);
56+
}
57+
58+
void DialectRegistry::addObjectInterface(
59+
StringRef dialectName, TypeID interfaceTypeID,
60+
ObjectInterfaceAllocatorFunction allocator) {
61+
assert(allocator && "unexpected null interface allocation function");
62+
63+
// Builtin dialect has an empty prefix and is always registered.
64+
TypeID dialectTypeID;
65+
if (!dialectName.empty()) {
66+
auto it = registry.find(dialectName.str());
67+
assert(it != registry.end() &&
68+
"adding an interface for an op from an unregistered dialect");
69+
dialectTypeID = it->second.first;
70+
} else {
71+
dialectTypeID = TypeID::get<BuiltinDialect>();
72+
}
73+
74+
auto &ifaces = interfaces[dialectTypeID];
75+
for (const auto &kvp : ifaces.objectInterfaces) {
76+
if (kvp.first == interfaceTypeID) {
77+
LLVM_DEBUG(llvm::dbgs()
78+
<< "[" DEBUG_TYPE
79+
"] repeated interface object interface registration");
80+
return;
81+
}
82+
}
83+
84+
ifaces.objectInterfaces.emplace_back(interfaceTypeID, allocator);
5585
}
5686

5787
DialectAllocatorFunctionRef
@@ -79,11 +109,15 @@ void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const {
79109
return;
80110

81111
// Add an interface if it is not already present.
82-
for (const auto &kvp : it->second) {
112+
for (const auto &kvp : it->getSecond().dialectInterfaces) {
83113
if (dialect->getRegisteredInterface(kvp.first))
84114
continue;
85115
dialect->addInterface(kvp.second(dialect));
86116
}
117+
118+
// Add attribute, operation and type interfaces.
119+
for (const auto &kvp : it->getSecond().objectInterfaces)
120+
kvp.second(dialect->getContext());
87121
}
88122

89123
//===----------------------------------------------------------------------===//

mlir/lib/IR/MLIRContext.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -355,12 +355,12 @@ MLIRContext::MLIRContext(const DialectRegistry &registry)
355355
printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
356356
}
357357

358-
// Ensure the builtin dialect is always pre-loaded.
359-
getOrLoadDialect<BuiltinDialect>();
360-
361358
// Pre-populate the registry.
362359
registry.appendTo(impl->dialectsRegistry);
363360

361+
// Ensure the builtin dialect is always pre-loaded.
362+
getOrLoadDialect<BuiltinDialect>();
363+
364364
// Initialize several common attributes and types to avoid the need to lock
365365
// the context when accessing them.
366366

0 commit comments

Comments
 (0)