
Commit 5b330a9

mlir: Add Enzyme ops removal on structured control flow (#2200)
* mlir: Add Enzyme ops removal on structured control flow
* format
* use AutoDiffTypeInterface for batching
* remove
* add test with unknown number of iterations
* don't push same value twice
* tensor extract/insert
* reserve the right size
* better batchType
* better comment
1 parent c759460 commit 5b330a9

12 files changed: +701 −148 lines


enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp

Lines changed: 24 additions & 22 deletions
@@ -25,6 +25,22 @@ using namespace mlir;
 using namespace mlir::enzyme;
 
 namespace {
+
+static mlir::Type batchType(mlir::Type type, int64_t width) {
+  if (width == 1)
+    return type;
+
+  if (auto TT = dyn_cast<mlir::TensorType>(type)) {
+    SmallVector<int64_t> shape;
+    shape.reserve(TT.getShape().size() + 1);
+    shape.push_back(width);
+    shape.append(TT.getShape().begin(), TT.getShape().end());
+    return TT.clone(shape);
+  }
+
+  return RankedTensorType::get({width}, type);
+}
+
 class FloatTypeInterface
     : public AutoDiffTypeInterface::ExternalModel<FloatTypeInterface,
                                                   FloatType> {
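
For orientation, the new batchType helper implements a single batching rule for every type covered below: width 1 is the identity, scalars become a rank-1 tensor of width copies, and tensors get width prepended to their shape. A minimal sketch of the resulting types (not part of the commit; assumes an MLIRContext *ctx and the MLIR builtin types):

    Type f32 = Float32Type::get(ctx);
    batchType(f32, 1);  // width 1: returned unchanged -> f32
    batchType(f32, 4);  // scalar: rank-1 tensor       -> tensor<4xf32>

    Type t = RankedTensorType::get({2, 3}, f32);
    batchType(t, 4);    // tensor: width prepended     -> tensor<4x2x3xf32>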
@@ -44,12 +60,8 @@ class FloatTypeInterface
     return a;
   }
 
-  Type getShadowType(Type self, unsigned width) const {
-    if (width > 1) {
-      return RankedTensorType::get({width}, self);
-    } else {
-      return self;
-    }
+  Type getShadowType(Type self, int64_t width) const {
+    return batchType(self, width);
   }
 
   bool isMutable(Type self) const { return false; }
@@ -108,16 +120,8 @@ class TensorTypeInterface
     return added;
   }
 
-  Type getShadowType(Type self, unsigned width) const {
-    if (width != 1) {
-      auto tenType = self.cast<TensorType>();
-      auto shape = tenType.getShape();
-      SmallVector<int64_t, 4> newShape;
-      newShape.push_back(width);
-      newShape.append(shape.begin(), shape.end());
-      return RankedTensorType::get(newShape, tenType.getElementType());
-    }
-    return self;
+  Type getShadowType(Type self, int64_t width) const {
+    return batchType(self, width);
   }
 
   bool isMutable(Type self) const { return false; }
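
Since FloatTypeInterface and TensorTypeInterface now both delegate to batchType, shadow types behave uniformly through the interface. A hedged usage sketch (someType is a placeholder; only the getShadowType entry point shown in this file is assumed):

    if (auto iface = dyn_cast<AutoDiffTypeInterface>(someType)) {
      // Same rule for scalars and tensors; width 1 returns the type itself.
      Type shadow = iface.getShadowType(/*width=*/4);
    }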
@@ -148,9 +152,8 @@ class IntegerTypeInterface
     return a;
   }
 
-  Type getShadowType(Type self, unsigned width) const {
-    assert(width == 1 && "unsupported width != 1");
-    return self;
+  Type getShadowType(Type self, int64_t width) const {
+    return batchType(self, width);
   }
 
   bool isMutable(Type self) const { return false; }
@@ -182,9 +185,8 @@ class ComplexTypeInterface
     return builder.create<complex::ConjOp>(loc, a)->getResult(0);
   }
 
-  Type getShadowType(Type self, unsigned width) const {
-    assert(width == 1 && "unsupported width != 1");
-    return self;
+  Type getShadowType(Type self, int64_t width) const {
+    return batchType(self, width);
   }
 
   bool isMutable(Type self) const { return false; }
