@@ -25,6 +25,22 @@ using namespace mlir;
2525using namespace mlir ::enzyme;
2626
2727namespace {
28+
29+ static mlir::Type batchType (mlir::Type type, int64_t width) {
30+ if (width == 1 )
31+ return type;
32+
33+ if (auto TT = dyn_cast<mlir::TensorType>(type)) {
34+ SmallVector<int64_t > shape;
35+ shape.reserve (TT.getShape ().size () + 1 );
36+ shape.push_back (width);
37+ shape.append (TT.getShape ().begin (), TT.getShape ().end ());
38+ return TT.clone (shape);
39+ }
40+
41+ return RankedTensorType::get ({width}, type);
42+ }
43+
2844class FloatTypeInterface
2945 : public AutoDiffTypeInterface::ExternalModel<FloatTypeInterface,
3046 FloatType> {
@@ -44,12 +60,8 @@ class FloatTypeInterface
4460 return a;
4561 }
4662
47- Type getShadowType (Type self, unsigned width) const {
48- if (width > 1 ) {
49- return RankedTensorType::get ({width}, self);
50- } else {
51- return self;
52- }
63+ Type getShadowType (Type self, int64_t width) const {
64+ return batchType (self, width);
5365 }
5466
5567 bool isMutable (Type self) const { return false ; }
@@ -108,16 +120,8 @@ class TensorTypeInterface
108120 return added;
109121 }
110122
111- Type getShadowType (Type self, unsigned width) const {
112- if (width != 1 ) {
113- auto tenType = self.cast <TensorType>();
114- auto shape = tenType.getShape ();
115- SmallVector<int64_t , 4 > newShape;
116- newShape.push_back (width);
117- newShape.append (shape.begin (), shape.end ());
118- return RankedTensorType::get (newShape, tenType.getElementType ());
119- }
120- return self;
123+ Type getShadowType (Type self, int64_t width) const {
124+ return batchType (self, width);
121125 }
122126
123127 bool isMutable (Type self) const { return false ; }
@@ -148,9 +152,8 @@ class IntegerTypeInterface
148152 return a;
149153 }
150154
151- Type getShadowType (Type self, unsigned width) const {
152- assert (width == 1 && " unsupported width != 1" );
153- return self;
155+ Type getShadowType (Type self, int64_t width) const {
156+ return batchType (self, width);
154157 }
155158
156159 bool isMutable (Type self) const { return false ; }
@@ -182,9 +185,8 @@ class ComplexTypeInterface
182185 return builder.create <complex ::ConjOp>(loc, a)->getResult (0 );
183186 }
184187
185- Type getShadowType (Type self, unsigned width) const {
186- assert (width == 1 && " unsupported width != 1" );
187- return self;
188+ Type getShadowType (Type self, int64_t width) const {
189+ return batchType (self, width);
188190 }
189191
190192 bool isMutable (Type self) const { return false ; }
0 commit comments