Skip to content

Commit 14cde52

Browse files
ftynsekiranchandramohan
authored andcommitted
[mlir] separable registration of attribute and type interfaces
It may be desirable to provide an interface implementation for an attribute or a type without modifying the definition of said attribute or type. Notably, this allows to implement interfaces for attributes and types outside of the dialect that defines them and, in particular, provide interfaces for built-in types. Provide the mechanism to do so. Currently, separable registration requires the attribute or type to have been registered with the context, i.e. for the dialect containing the attribute or type to be loaded. This can be relaxed in the future using a mechanism similar to delayed dialect interface registration. See https://llvm.discourse.group/t/rfc-separable-attribute-type-interfaces/3637 Depends On D104233 Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D104234
1 parent 77cad11 commit 14cde52

File tree

16 files changed

+474
-18
lines changed

16 files changed

+474
-18
lines changed

mlir/docs/Interfaces.md

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,91 @@ if (ExampleOpInterface example = dyn_cast<ExampleOpInterface>(op))
207207
llvm::errs() << "hook returned = " << example.exampleInterfaceHook() << "\n";
208208
```
209209
210+
#### External Models for Attribute/Type Interfaces
211+
212+
It may be desirable to provide an interface implementation for an attribute or a
213+
type without modifying the definition of said attribute or type. Notably, this
214+
allows to implement interfaces for attributes and types outside of the dialect
215+
that defines them and, in particular, provide interfaces for built-in types.
216+
217+
This is achieved by extending the concept-based polymorphism model with two more
218+
classes derived from `Concept` as follows.
219+
220+
```c++
221+
struct ExampleTypeInterfaceTraits {
222+
struct Concept {
223+
virtual unsigned exampleInterfaceHook(Type type) const = 0;
224+
virtual unsigned exampleStaticInterfaceHook() const = 0;
225+
};
226+
227+
template <typename ConcreteType>
228+
struct Model : public Concept { /*...*/ };
229+
230+
/// Unlike `Model`, `FallbackModel` passes the type object through to the
231+
/// hook, making it accessible in the method body even if the method is not
232+
/// defined in the class itself and thus has no `this` access. ODS
233+
/// automatically generates this class for all interfaces.
234+
template <typename ConcreteType>
235+
struct FallbackModel : public Concept {
236+
unsigned exampleInterfaceHook(Type type) const override {
237+
getImpl()->exampleInterfaceHook(type);
238+
}
239+
unsigned exampleStaticInterfaceHook() const override {
240+
ConcreteType::exampleStaticInterfaceHook();
241+
}
242+
};
243+
244+
/// `ExternalModel` provides a place for default implementations of interface
245+
/// methods by explicitly separating the model class, which implements the
246+
/// interface, from the type class, for which the interface is being
247+
/// implemented. Default implementations can be then defined generically
248+
/// making use of `cast<ConcreteType>`. If `ConcreteType` does not provide
249+
/// the APIs required by the default implementation, custom implementations
250+
/// may use `FallbackModel` directly to override the default implementation.
251+
/// Being located in a class template, it never gets instantiated and does not
252+
/// lead to compilation errors. ODS automatically generates this class and
253+
/// places default method implementations in it.
254+
template <typename ConcreteModel, typename ConcreteType>
255+
struct ExternalModel : public FallbackModel<ConcreteModel> {
256+
unsigned exampleInterfaceHook(Type type) const override {
257+
// Default implementation can be provided here.
258+
return type.cast<ConcreteType>().callSomeTypeSpecificMethod();
259+
}
260+
};
261+
};
262+
```
263+
264+
External models can be provided for attribute and type interfaces by deriving
265+
either `FallbackModel` or `ExternalModel` and by registering the model class
266+
with the attribute or type class in a given context. Other contexts will not see
267+
the interface unless registered.
268+
269+
```c++
270+
/// External interface implementation for a concrete class. This does not
271+
/// require modifying the definition of the type class itself.
272+
struct ExternalModelExample
273+
: public ExampleTypeInterface::ExternalModel<ExternalModelExample,
274+
IntegerType> {
275+
static unsigned exampleStaticInterfaceHook() {
276+
// Implementation is provided here.
277+
return IntegerType::someStaticMethod();
278+
}
279+
280+
// No need to define `exampleInterfaceHook` that has a default implementation
281+
// in `ExternalModel`. But it can be overridden if desired.
282+
}
283+
284+
int main() {
285+
MLIRContext context;
286+
/* ... */;
287+
288+
// Register the interface model with the type in the given context before
289+
// using it. The dialect contaiing the type is expected to have been loaded
290+
// at this point.
291+
IntegerType::registerInterface<ExternalModelExample>(context);
292+
}
293+
```
294+
210295
#### Dialect Fallback for OpInterface
211296
212297
Some dialects have an open ecosystem and don't register all of the possible
@@ -215,9 +300,9 @@ implementing an `OpInterface` for these operation. When an operation isn't
215300
registered or does not provide an implementation for an interface, the query
216301
will fallback to the dialect itself.
217302
218-
A second model is used for such cases and automatically generated when
219-
using ODS (see below) with the name `FallbackModel`. This model can be implemented
220-
for a particular dialect:
303+
A second model is used for such cases and automatically generated when using ODS
304+
(see below) with the name `FallbackModel`. This model can be implemented for a
305+
particular dialect:
221306
222307
```c++
223308
// This is the implementation of a dialect fallback for `ExampleOpInterface`.

mlir/include/mlir/IR/AttributeSupport.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,24 @@ class AbstractAttribute {
5959
: dialect(dialect), interfaceMap(std::move(interfaceMap)),
6060
typeID(typeID) {}
6161

62+
/// Give StorageUserBase access to the mutable lookup.
63+
template <typename ConcreteT, typename BaseT, typename StorageT,
64+
typename UniquerT, template <typename T> class... Traits>
65+
friend class detail::StorageUserBase;
66+
67+
/// Look up the specified abstract attribute in the MLIRContext and return a
68+
/// (mutable) pointer to it. Return a null pointer if the attribute could not
69+
/// be found in the context.
70+
static AbstractAttribute *lookupMutable(TypeID typeID, MLIRContext *context);
71+
6272
/// This is the dialect that this attribute was registered to.
63-
Dialect &dialect;
73+
const Dialect &dialect;
6474

6575
/// This is a collection of the interfaces registered to this attribute.
6676
detail::InterfaceMap interfaceMap;
6777

6878
/// The unique identifier of the derived Attribute class.
69-
TypeID typeID;
79+
const TypeID typeID;
7080
};
7181

7282
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/Attributes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class Attribute {
3131

3232
using ImplType = AttributeStorage;
3333
using ValueType = void;
34+
using AbstractType = AbstractAttribute;
3435

3536
constexpr Attribute() : impl(nullptr) {}
3637
/* implicit */ Attribute(const ImplType *impl)

mlir/include/mlir/IR/StorageUniquerSupport.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,21 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
8787
return detail::InterfaceMap::template get<Traits<ConcreteT>...>();
8888
}
8989

90+
/// Attach the given models as implementations of the corresponding interfaces
91+
/// for the concrete storage user class. The type must be registered with the
92+
/// context, i.e. the dialect to which the type belongs must be loaded. The
93+
/// call will abort otherwise.
94+
template <typename... IfaceModels>
95+
static void attachInterface(MLIRContext &context) {
96+
typename ConcreteT::AbstractType *abstract =
97+
ConcreteT::AbstractType::lookupMutable(TypeID::get<ConcreteT>(),
98+
&context);
99+
if (!abstract)
100+
llvm::report_fatal_error("Registering an interface for an attribute/type "
101+
"that is not itself registered.");
102+
abstract->interfaceMap.template insert<IfaceModels...>();
103+
}
104+
90105
/// Get or create a new ConcreteT instance within the ctx. This
91106
/// function is guaranteed to return a non null object and will assert if
92107
/// the arguments provided are invalid.

mlir/include/mlir/IR/TypeSupport.h

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,24 @@ class AbstractType {
6767
: dialect(dialect), interfaceMap(std::move(interfaceMap)),
6868
typeID(typeID) {}
6969

70+
/// Give StorageUserBase access to the mutable lookup.
71+
template <typename ConcreteT, typename BaseT, typename StorageT,
72+
typename UniquerT, template <typename T> class... Traits>
73+
friend class detail::StorageUserBase;
74+
75+
/// Look up the specified abstract type in the MLIRContext and return a
76+
/// (mutable) pointer to it. Return a null pointer if the type could not
77+
/// be found in the context.
78+
static AbstractType *lookupMutable(TypeID typeID, MLIRContext *context);
79+
7080
/// This is the dialect that this type was registered to.
71-
Dialect &dialect;
81+
const Dialect &dialect;
7282

7383
/// This is a collection of the interfaces registered to this type.
7484
detail::InterfaceMap interfaceMap;
7585

7686
/// The unique identifier of the derived Type class.
77-
TypeID typeID;
87+
const TypeID typeID;
7888
};
7989

8090
//===----------------------------------------------------------------------===//
@@ -105,11 +115,11 @@ class TypeStorage : public StorageUniquer::BaseStorage {
105115
/// Set the abstract type for this storage instance. This is used by the
106116
/// TypeUniquer when initializing a newly constructed type storage object.
107117
void initialize(const AbstractType &abstractTy) {
108-
abstractType = &abstractTy;
118+
abstractType = const_cast<AbstractType *>(&abstractTy);
109119
}
110120

111121
/// The abstract description for this type.
112-
const AbstractType *abstractType;
122+
AbstractType *abstractType;
113123
};
114124

115125
/// Default storage type for types that require no additional initialization or

mlir/include/mlir/IR/Types.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ class Type {
7979

8080
using ImplType = TypeStorage;
8181

82+
using AbstractType = AbstractType;
83+
8284
constexpr Type() : impl(nullptr) {}
8385
/* implicit */ Type(const ImplType *impl)
8486
: impl(const_cast<ImplType *>(impl)) {}

mlir/include/mlir/Support/InterfaceSupport.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ class Interface : public BaseType {
7575
using FallbackModel = typename Traits::template FallbackModel<T>;
7676
using InterfaceBase =
7777
Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait>;
78+
template <typename T, typename U>
79+
using ExternalModel = typename Traits::template ExternalModel<T, U>;
7880

7981
/// This is a special trait that registers a given interface with an object.
8082
template <typename ConcreteT>
@@ -139,6 +141,25 @@ struct FilterTypes {
139141
typename FilterTypeT<Pred<Es>::value>::template type<Es>>()...));
140142
};
141143

144+
namespace {
145+
/// Type trait indicating whether all template arguments are
146+
/// trivially-destructible.
147+
template <typename... Args>
148+
struct all_trivially_destructible;
149+
150+
template <typename Arg, typename... Args>
151+
struct all_trivially_destructible<Arg, Args...> {
152+
static constexpr const bool value =
153+
std::is_trivially_destructible<Arg>::value &&
154+
all_trivially_destructible<Args...>::value;
155+
};
156+
157+
template <>
158+
struct all_trivially_destructible<> {
159+
static constexpr const bool value = true;
160+
};
161+
} // namespace
162+
142163
/// This class provides an efficient mapping between a given `Interface` type,
143164
/// and a particular implementation of its concept.
144165
class InterfaceMap {
@@ -198,6 +219,28 @@ class InterfaceMap {
198219
});
199220
}
200221

222+
/// Insert the given models as implementations of the corresponding interfaces
223+
/// for the concrete attribute class.
224+
template <typename... IfaceModels>
225+
void insert() {
226+
static_assert(all_trivially_destructible<IfaceModels...>::value,
227+
"interface models must be trivially destructible");
228+
std::pair<TypeID, void *> elements[] = {
229+
std::make_pair(IfaceModels::Interface::getInterfaceID(),
230+
new (malloc(sizeof(IfaceModels))) IfaceModels())...};
231+
// Insert directly into the right position to keep the interfaces sorted.
232+
for (auto &element : elements) {
233+
TypeID id = element.first;
234+
auto it =
235+
llvm::lower_bound(interfaces, id, [](const auto &it, TypeID id) {
236+
return compare(it.first, id);
237+
});
238+
if (it != interfaces.end() && it->first == id)
239+
llvm::report_fatal_error("Interface already registered");
240+
interfaces.insert(it, element);
241+
}
242+
}
243+
201244
private:
202245
/// Compare two TypeID instances by comparing the underlying pointer.
203246
static bool compare(TypeID lhs, TypeID rhs) {

mlir/lib/IR/MLIRContext.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ class MLIRContextImpl {
306306
// Type uniquing
307307
//===--------------------------------------------------------------------===//
308308

309-
DenseMap<TypeID, const AbstractType *> registeredTypes;
309+
DenseMap<TypeID, AbstractType *> registeredTypes;
310310
StorageUniquer typeUniquer;
311311

312312
/// Cached Type Instances.
@@ -324,7 +324,7 @@ class MLIRContextImpl {
324324
// Attribute uniquing
325325
//===--------------------------------------------------------------------===//
326326

327-
DenseMap<TypeID, const AbstractAttribute *> registeredAttributes;
327+
DenseMap<TypeID, AbstractAttribute *> registeredAttributes;
328328
StorageUniquer attributeUniquer;
329329

330330
/// Cached Attribute Instances.
@@ -666,12 +666,20 @@ void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
666666
/// Get the dialect that registered the attribute with the provided typeid.
667667
const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
668668
MLIRContext *context) {
669+
const AbstractAttribute *abstract = lookupMutable(typeID, context);
670+
if (!abstract)
671+
llvm::report_fatal_error("Trying to create an Attribute that was not "
672+
"registered in this MLIRContext.");
673+
return *abstract;
674+
}
675+
676+
AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID,
677+
MLIRContext *context) {
669678
auto &impl = context->getImpl();
670679
auto it = impl.registeredAttributes.find(typeID);
671680
if (it == impl.registeredAttributes.end())
672-
llvm::report_fatal_error("Trying to create an Attribute that was not "
673-
"registered in this MLIRContext.");
674-
return *it->second;
681+
return nullptr;
682+
return it->second;
675683
}
676684

677685
//===----------------------------------------------------------------------===//
@@ -733,12 +741,19 @@ AbstractOperation::AbstractOperation(
733741
//===----------------------------------------------------------------------===//
734742

735743
const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) {
744+
const AbstractType *type = lookupMutable(typeID, context);
745+
if (!type)
746+
llvm::report_fatal_error(
747+
"Trying to create a Type that was not registered in this MLIRContext.");
748+
return *type;
749+
}
750+
751+
AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) {
736752
auto &impl = context->getImpl();
737753
auto it = impl.registeredTypes.find(typeID);
738754
if (it == impl.registeredTypes.end())
739-
llvm::report_fatal_error(
740-
"Trying to create a Type that was not registered in this MLIRContext.");
741-
return *it->second;
755+
return nullptr;
756+
return it->second;
742757
}
743758

744759
//===----------------------------------------------------------------------===//

mlir/test/lib/Dialect/Test/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ set(LLVM_OPTIONAL_SOURCES
55
)
66

77
set(LLVM_TARGET_DEFINITIONS TestInterfaces.td)
8+
mlir_tablegen(TestAttrInterfaces.h.inc -gen-attr-interface-decls)
9+
mlir_tablegen(TestAttrInterfaces.cpp.inc -gen-attr-interface-defs)
810
mlir_tablegen(TestTypeInterfaces.h.inc -gen-type-interface-decls)
911
mlir_tablegen(TestTypeInterfaces.cpp.inc -gen-type-interface-defs)
1012
mlir_tablegen(TestOpInterfaces.h.inc -gen-op-interface-decls)

mlir/test/lib/Dialect/Test/TestAttributes.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ void CompoundAAttr::print(DialectAsmPrinter &printer) const {
9393
// Tablegen Generated Definitions
9494
//===----------------------------------------------------------------------===//
9595

96+
#include "TestAttrInterfaces.cpp.inc"
97+
9698
#define GET_ATTRDEF_CLASSES
9799
#include "TestAttrDefs.cpp.inc"
98100

0 commit comments

Comments
 (0)