@@ -27,8 +27,9 @@ class Type;
27
27
28
28
using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
29
29
using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>;
30
- using InterfaceAllocatorFunction =
30
+ using DialectInterfaceAllocatorFunction =
31
31
std::function<std::unique_ptr<DialectInterface>(Dialect *)>;
32
+ using ObjectInterfaceAllocatorFunction = std::function<void (MLIRContext *)>;
32
33
33
34
// / Dialects are groups of MLIR operations, types and attributes, as well as
34
35
// / behavior associated with the entire group. For example, hooks into other
@@ -271,11 +272,19 @@ class Dialect {
271
272
// / dialects loaded in the Context. The parser in particular will lazily load
272
273
// / dialects in the Context as operations are encountered.
273
274
class DialectRegistry {
275
+ // / Lists of interfaces that need to be registered when the dialect is loaded.
276
+ struct DelayedInterfaces {
277
+ // / Dialect interfaces.
278
+ SmallVector<std::pair<TypeID, DialectInterfaceAllocatorFunction>, 2 >
279
+ dialectInterfaces;
280
+ // / Attribute/Operation/Type interfaces.
281
+ SmallVector<std::pair<TypeID, ObjectInterfaceAllocatorFunction>, 2 >
282
+ objectInterfaces;
283
+ };
284
+
274
285
using MapTy =
275
286
std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
276
- using InterfaceMapTy =
277
- DenseMap<TypeID,
278
- SmallVector<std::pair<TypeID, InterfaceAllocatorFunction>, 2 >>;
287
+ using InterfaceMapTy = DenseMap<TypeID, DelayedInterfaces>;
279
288
280
289
public:
281
290
explicit DialectRegistry () {}
@@ -329,7 +338,7 @@ class DialectRegistry {
329
338
// / the registry.
330
339
template <typename DialectTy>
331
340
void addDialectInterface (TypeID interfaceTypeID,
332
- InterfaceAllocatorFunction allocator) {
341
+ DialectInterfaceAllocatorFunction allocator) {
333
342
addDialectInterface (DialectTy::getDialectNamespace (), interfaceTypeID,
334
343
allocator);
335
344
}
@@ -344,6 +353,36 @@ class DialectRegistry {
344
353
});
345
354
}
346
355
356
+ // / Add an external op interface model for an op that belongs to a dialect,
357
+ // / both provided as template parameters. The dialect must be present in the
358
+ // / registry.
359
+ template <typename OpTy, typename ModelTy>
360
+ void addOpInterface () {
361
+ StringRef opName = OpTy::getOperationName ();
362
+ StringRef dialectName = opName.split (' .' ).first ;
363
+ addObjectInterface (dialectName == opName ? " " : dialectName,
364
+ ModelTy::Interface::getInterfaceID (),
365
+ [](MLIRContext *context) {
366
+ OpTy::template attachInterface<ModelTy>(*context);
367
+ });
368
+ }
369
+
370
+ // / Add an external attribute interface model for an attribute type `AttrTy`
371
+ // / that is going to belong to `DialectTy`. The dialect must be present in the
372
+ // / registry.
373
+ template <typename DialectTy, typename AttrTy, typename ModelTy>
374
+ void addAttrInterface () {
375
+ addStorageUserInterface<AttrTy, ModelTy>(DialectTy::getDialectNamespace ());
376
+ }
377
+
378
+ // / Add an external type interface model for an type class `TypeTy` that is
379
+ // / going to belong to `DialectTy`. The dialect must be present in the
380
+ // / registry.
381
+ template <typename DialectTy, typename TypeTy, typename ModelTy>
382
+ void addTypeInterface () {
383
+ addStorageUserInterface<TypeTy, ModelTy>(DialectTy::getDialectNamespace ());
384
+ }
385
+
347
386
// / Register any interfaces required for the given dialect (based on its
348
387
// / TypeID). Users are not expected to call this directly.
349
388
void registerDelayedInterfaces (Dialect *dialect) const ;
@@ -352,7 +391,22 @@ class DialectRegistry {
352
391
// / Add an interface constructed with the given allocation function to the
353
392
// / dialect identified by its namespace.
354
393
void addDialectInterface (StringRef dialectName, TypeID interfaceTypeID,
355
- InterfaceAllocatorFunction allocator);
394
+ DialectInterfaceAllocatorFunction allocator);
395
+
396
+ // / Add an attribute/operation/type interface constructible with the given
397
+ // / allocation function to the dialect identified by its namespace.
398
+ void addObjectInterface (StringRef dialectName, TypeID interfaceTypeID,
399
+ ObjectInterfaceAllocatorFunction allocator);
400
+
401
+ // / Add an external model for an attribute/type interface to the dialect
402
+ // / identified by its namespace.
403
+ template <typename ObjectTy, typename ModelTy>
404
+ void addStorageUserInterface (StringRef dialectName) {
405
+ addObjectInterface (dialectName, ModelTy::Interface::getInterfaceID (),
406
+ [](MLIRContext *context) {
407
+ ObjectTy::template attachInterface<ModelTy>(*context);
408
+ });
409
+ }
356
410
357
411
MapTy registry;
358
412
InterfaceMapTy interfaces;
0 commit comments