@@ -41,12 +41,12 @@ static mlir::Type batchType(mlir::Type type, int64_t width) {
4141 return RankedTensorType::get ({width}, type);
4242}
4343
44- class FloatTypeInterface
45- : public AutoDiffTypeInterface::ExternalModel<FloatTypeInterface,
46- FloatType > {
44+ template < typename ConcreteType>
45+ class FloatTypeInterface : public AutoDiffTypeInterface ::ExternalModel<
46+ FloatTypeInterface<ConcreteType>, ConcreteType > {
4747public:
4848 Value createNullValue (Type self, OpBuilder &builder, Location loc) const {
49- auto fltType = self.cast <FloatType >();
49+ auto fltType = self.cast <ConcreteType >();
5050 return builder.create <arith::ConstantFloatOp>(
5151 loc, APFloat (fltType.getFloatSemantics (), 0 ), fltType);
5252 }
@@ -200,10 +200,10 @@ class ComplexTypeInterface
200200void mlir::enzyme::registerBuiltinDialectAutoDiffInterface (
201201 DialectRegistry ®istry) {
202202 registry.addExtension (+[](MLIRContext *context, BuiltinDialect *) {
203- BFloat16Type::attachInterface<FloatTypeInterface>(*context);
204- Float16Type::attachInterface<FloatTypeInterface>(*context);
205- Float32Type::attachInterface<FloatTypeInterface>(*context);
206- Float64Type::attachInterface<FloatTypeInterface>(*context);
203+ BFloat16Type::attachInterface<FloatTypeInterface<BFloat16Type> >(*context);
204+ Float16Type::attachInterface<FloatTypeInterface<Float16Type> >(*context);
205+ Float32Type::attachInterface<FloatTypeInterface<Float32Type> >(*context);
206+ Float64Type::attachInterface<FloatTypeInterface<Float64Type> >(*context);
207207 IntegerType::attachInterface<IntegerTypeInterface<IntegerType>>(*context);
208208 IndexType::attachInterface<IntegerTypeInterface<IndexType>>(*context);
209209 UnrankedTensorType::attachInterface<TensorTypeInterface>(*context);
0 commit comments