Skip to content

Commit df197be

Browse files
authored
Drive enzyme ops removal using pattern rewriter (#2229)
* run eopt for cachefunction test
* deterministic RemoveUnusedEnzymeOps
* add PatternRewriter to EnzymeOpsRemoverOpInterface
* builder -> rewriter
* docs
* workaround driver problems
* revert test changes
* use a post-order ordered driver
* remove removeOpsWithinBlock
* remove log
* update worklist
* credits
* remove unused pattern
* report failures
1 parent 0e5fa4a commit df197be

File tree

6 files changed

+274
-123
lines changed

6 files changed

+274
-123
lines changed

enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp

Lines changed: 90 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
#include "mlir/IR/Types.h"
2626
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2727
#include "mlir/Support/LogicalResult.h"
28-
#include "llvm/ADT/DenseSet.h"
2928
#include "llvm/ADT/STLExtras.h"
3029
#include <functional>
3130

@@ -67,15 +66,13 @@ struct ForOpEnzymeOpsRemover
6766
: public EnzymeOpsRemoverOpInterface::ExternalModel<ForOpEnzymeOpsRemover,
6867
scf::ForOp> {
6968

70-
LogicalResult removeEnzymeOps(Operation *op) const {
69+
LogicalResult removeEnzymeOps(Operation *op,
70+
PatternRewriter &rewriter) const {
7171
auto forOp = cast<scf::ForOp>(op);
7272
scf::ForOp otherForOp; // where caches pops are
7373

74-
if (removeOpsWithinBlock(forOp.getBody()).failed())
75-
return failure();
76-
7774
// Gradients whose values need to be passed as iteration variables.
78-
llvm::SmallDenseSet<Value> updatedGradients;
75+
llvm::SetVector<Value> updatedGradients;
7976

8077
llvm::MapVector<Value, CacheInfo> cachesMap;
8178

@@ -92,7 +89,7 @@ struct ForOpEnzymeOpsRemover
9289

9390
Value pushedValue = info.pushedValue();
9491
if (cachesMap.contains(pushedValue)) {
95-
info = info.merge(cachesMap.lookup(pushedValue));
92+
info = info.merge(cachesMap.lookup(pushedValue), rewriter);
9693
}
9794
cachesMap[pushedValue] = info;
9895

@@ -110,43 +107,42 @@ struct ForOpEnzymeOpsRemover
110107
if (updatedGradients.empty() && caches.empty())
111108
return success();
112109

113-
OpBuilder builder(forOp);
114110
for (auto &it : *body) {
115111
Operation *op = &it;
116112

117113
auto getOp = dyn_cast<enzyme::GetOp>(op);
118114
if (!getOp || updatedGradients.contains(getOp.getGradient()))
119115
continue;
120116

121-
auto outerGet = builder.create<enzyme::GetOp>(
117+
auto outerGet = rewriter.create<enzyme::GetOp>(
122118
getOp->getLoc(),
123119
cast<enzyme::GradientType>(getOp.getResult().getType()).getBasetype(),
124120
getOp.getGradient());
125121

126-
getOp.getResult().replaceAllUsesWith(outerGet.getResult());
127-
getOp->erase();
122+
rewriter.replaceAllUsesWith(getOp.getResult(), outerGet.getResult());
123+
rewriter.eraseOp(getOp);
128124
}
129125

130126
auto term = body->getTerminator();
131127

132128
SmallVector<Value> newOperands(forOp.getInitArgs());
133129
for (auto grad : updatedGradients) {
134130
auto Ty = cast<enzyme::GradientType>(grad.getType()).getBasetype();
135-
auto outerGet = builder.create<enzyme::GetOp>(grad.getLoc(), Ty, grad);
131+
auto outerGet = rewriter.create<enzyme::GetOp>(grad.getLoc(), Ty, grad);
136132

137133
newOperands.push_back(outerGet.getResult());
138134
auto newArg = body->addArgument(Ty, grad.getLoc());
139135

140136
{
141-
OpBuilder::InsertionGuard guard(builder);
137+
OpBuilder::InsertionGuard guard(rewriter);
142138

143-
builder.setInsertionPointToStart(body);
144-
builder.create<enzyme::SetOp>(grad.getLoc(), grad, newArg);
139+
rewriter.setInsertionPointToStart(body);
140+
rewriter.create<enzyme::SetOp>(grad.getLoc(), grad, newArg);
145141

146-
builder.setInsertionPoint(term);
142+
rewriter.setInsertionPoint(term);
147143

148144
auto outputVal =
149-
builder.create<enzyme::GetOp>(grad.getLoc(), Ty, grad).getResult();
145+
rewriter.create<enzyme::GetOp>(grad.getLoc(), Ty, grad).getResult();
150146
term->insertOperands(term->getNumOperands(), ValueRange(outputVal));
151147
}
152148
}
@@ -159,45 +155,45 @@ struct ForOpEnzymeOpsRemover
159155
inductionVariable = body->getArgument(0);
160156
}
161157

162-
for (auto info : caches) {
158+
for (auto &info : caches) {
163159
Value cache = info.initOp.getResult();
164160

165161
// push does not depend on a value inside the loop, we can hoist the
166162
// push/pop before the for loops.
167-
if (info.pushedValue().getParentRegion() != forOp->getRegion(0)) {
168-
auto newPush = builder.create<enzyme::PushOp>(cache.getLoc(), cache,
169-
info.pushedValue());
170-
info.pushOp->erase();
163+
if (info.pushedValue().getParentRegion() != forOp.getRegion()) {
164+
auto newPush = rewriter.create<enzyme::PushOp>(cache.getLoc(), cache,
165+
info.pushedValue());
166+
rewriter.eraseOp(info.pushOp);
171167
info.pushOp = newPush;
172168

173169
{
174-
OpBuilder::InsertionGuard guard(builder);
175-
builder.setInsertionPoint(info.popOp->getParentOp());
170+
OpBuilder::InsertionGuard guard(rewriter);
171+
rewriter.setInsertionPoint(info.popOp->getParentOp());
176172

177173
auto popVal = info.popOp.getResult();
178-
auto newPop = builder.create<enzyme::PopOp>(cache.getLoc(),
179-
popVal.getType(), cache);
180-
popVal.replaceAllUsesWith(newPop.getResult());
181-
info.popOp->erase();
174+
auto newPop = rewriter.create<enzyme::PopOp>(cache.getLoc(),
175+
popVal.getType(), cache);
176+
rewriter.replaceAllUsesWith(popVal, newPop.getResult());
177+
rewriter.eraseOp(info.popOp);
182178
info.popOp = newPop;
183179
}
184180

185181
continue;
186182
}
187183

188184
if (!inductionVariable) {
189-
Value zero = builder.create<arith::ConstantOp>(forOp->getLoc(),
190-
builder.getIndexAttr(0));
185+
Value zero = rewriter.create<arith::ConstantOp>(
186+
forOp->getLoc(), rewriter.getIndexAttr(0));
191187
newOperands.push_back(zero);
192188

193189
inductionVariable = body->addArgument(zero.getType(), forOp->getLoc());
194190
{
195-
OpBuilder::InsertionGuard guard(builder);
196-
builder.setInsertionPoint(term);
191+
OpBuilder::InsertionGuard guard(rewriter);
192+
rewriter.setInsertionPoint(term);
197193

198-
auto one = builder.create<arith::ConstantOp>(forOp->getLoc(),
199-
builder.getIndexAttr(1));
200-
auto newInductionVar = builder.create<arith::AddIOp>(
194+
auto one = rewriter.create<arith::ConstantOp>(
195+
forOp->getLoc(), rewriter.getIndexAttr(1));
196+
auto newInductionVar = rewriter.create<arith::AddIOp>(
201197
forOp->getLoc(), inductionVariable, one);
202198
term->insertOperands(term->getNumOperands(),
203199
ValueRange(newInductionVar));
@@ -215,25 +211,25 @@ struct ForOpEnzymeOpsRemover
215211
for (auto it : llvm::enumerate(newType.getShape())) {
216212
if (ShapedType::isDynamic(it.value())) {
217213
if (it.index() == 0)
218-
dynamicDims.push_back(getNumberOfIterations(builder, forOp));
214+
dynamicDims.push_back(getNumberOfIterations(rewriter, forOp));
219215
else
220216
return failure(); // TODO: find dynamic dims within the body.
221217
}
222218
}
223219

224-
Value initValue = builder.create<tensor::EmptyOp>(info.initOp->getLoc(),
225-
newType, dynamicDims);
220+
Value initValue = rewriter.create<tensor::EmptyOp>(info.initOp->getLoc(),
221+
newType, dynamicDims);
226222

227223
// cast<AutoDiffTypeInterface>(newType).createNullValue(
228-
// builder, info.initOp->getLoc());
224+
// rewriter, info.initOp->getLoc());
229225

230226
newOperands.push_back(initValue);
231227

232228
auto cacheValue = body->addArgument(newType, info.pushOp->getLoc());
233229

234230
{
235-
OpBuilder::InsertionGuard guard(builder);
236-
builder.setInsertionPoint(info.pushOp);
231+
OpBuilder::InsertionGuard guard(rewriter);
232+
rewriter.setInsertionPoint(info.pushOp);
237233

238234
// TODO: if type is tensor, use insert_slice instead
239235
Value newCacheValue;
@@ -250,14 +246,14 @@ struct ForOpEnzymeOpsRemover
250246

251247
SmallVector<int64_t> strides(shape.size() + 1, 1);
252248

253-
newCacheValue = builder.create<tensor::InsertSliceOp>(
249+
newCacheValue = rewriter.create<tensor::InsertSliceOp>(
254250
info.pushOp->getLoc(), info.pushOp.getValue(), cacheValue,
255251
ValueRange(inductionVariable), ValueRange(), ValueRange(),
256-
builder.getDenseI64ArrayAttr(offsets),
257-
builder.getDenseI64ArrayAttr(sizes),
258-
builder.getDenseI64ArrayAttr(strides));
252+
rewriter.getDenseI64ArrayAttr(offsets),
253+
rewriter.getDenseI64ArrayAttr(sizes),
254+
rewriter.getDenseI64ArrayAttr(strides));
259255
} else {
260-
newCacheValue = builder.create<tensor::InsertOp>(
256+
newCacheValue = rewriter.create<tensor::InsertOp>(
261257
info.pushOp->getLoc(), info.pushOp.getValue(), cacheValue,
262258
inductionVariable);
263259
}
@@ -267,72 +263,81 @@ struct ForOpEnzymeOpsRemover
267263
}
268264

269265
auto numInitArgs = forOp.getInitArgs().size();
270-
auto newFor = builder.create<scf::ForOp>(
266+
auto newFor = rewriter.create<scf::ForOp>(
271267
op->getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
272268
forOp.getStep(), newOperands);
273269

274270
newFor.getRegion().takeBody(forOp.getRegion());
275271

272+
for (auto &&[res, newRes] :
273+
llvm::zip(forOp->getResults(), newFor->getResults())) {
274+
rewriter.replaceAllUsesWith(res, newRes);
275+
}
276+
277+
rewriter.eraseOp(forOp);
278+
forOp = newFor;
279+
rewriter.setInsertionPointAfter(forOp);
280+
276281
unsigned resultIdx = numInitArgs;
277282
for (auto grad : updatedGradients) {
278283
// set the updated gradient after the new for op.
279-
OpBuilder::InsertionGuard guard(builder);
280-
builder.create<enzyme::SetOp>(grad.getLoc(), grad,
281-
newFor->getResult(resultIdx));
284+
OpBuilder::InsertionGuard guard(rewriter);
285+
rewriter.create<enzyme::SetOp>(grad.getLoc(), grad,
286+
newFor->getResult(resultIdx));
282287
++resultIdx;
283288
}
284289

285-
if (inductionVariable && caches.size()) {
290+
if (inductionVariable && !caches.empty()) {
286291
if (isa<BlockArgument>(inductionVariable) &&
287292
cast<BlockArgument>(inductionVariable).getArgNumber() != 0)
288293
resultIdx++;
289294

290-
OpBuilder::InsertionGuard guard(builder);
291-
builder.setInsertionPoint(otherForOp);
295+
OpBuilder::InsertionGuard guard(rewriter);
296+
rewriter.setInsertionPoint(otherForOp);
292297
SmallVector<Value> operands(otherForOp.getInitArgs().begin(),
293298
otherForOp.getInitArgs().end());
294299
operands.push_back(numIters.has_value()
295-
? builder.create<arith::ConstantOp>(
300+
? rewriter.create<arith::ConstantOp>(
296301
otherForOp->getLoc(),
297-
builder.getIndexAttr(numIters.value() - 1))
298-
: getNumberOfIterations(builder, forOp));
302+
rewriter.getIndexAttr(numIters.value() - 1))
303+
: getNumberOfIterations(rewriter, forOp));
299304

300305
Block *otherBody = otherForOp.getBody();
301306
Value otherInductionVariable =
302-
otherBody->addArgument(builder.getIndexType(), otherForOp->getLoc());
307+
otherBody->addArgument(rewriter.getIndexType(), otherForOp->getLoc());
303308
auto otherTerm = otherBody->getTerminator();
304309

305-
builder.setInsertionPoint(otherTerm);
310+
rewriter.setInsertionPoint(otherTerm);
306311

307312
otherInductionVariable =
308-
builder
313+
rewriter
309314
.create<arith::SubIOp>(
310315
otherForOp->getLoc(), otherInductionVariable,
311-
builder
316+
rewriter
312317
.create<arith::ConstantOp>(otherForOp->getLoc(),
313-
builder.getIndexAttr(1))
318+
rewriter.getIndexAttr(1))
314319
.getResult())
315320
.getResult();
316321
otherTerm->insertOperands(otherTerm->getNumOperands(),
317322
ValueRange(otherInductionVariable));
318323

319-
builder.setInsertionPoint(otherForOp);
320-
auto newOtherForOp = builder.create<scf::ForOp>(
324+
rewriter.setInsertionPoint(otherForOp);
325+
auto newOtherForOp = rewriter.create<scf::ForOp>(
321326
otherForOp->getLoc(), otherForOp.getLowerBound(),
322327
otherForOp.getUpperBound(), otherForOp.getStep(), operands);
323328

324329
for (auto &&[res, newRes] :
325330
llvm::zip(otherForOp->getResults(), newOtherForOp->getResults())) {
326-
res.replaceAllUsesWith(newRes);
331+
rewriter.replaceAllUsesWith(res, newRes);
327332
}
328333
newOtherForOp.getRegion().takeBody(otherForOp.getRegion());
329334

330-
otherForOp->erase();
335+
rewriter.eraseOp(otherForOp);
331336
otherForOp = newOtherForOp;
332337
}
333338

334-
for (auto info : caches) {
335-
if (info.pushedValue().getParentRegion() != newFor->getRegion(0))
339+
for (auto &info : caches) {
340+
if (info.pushedValue().getParentRegion() != newFor.getRegion())
336341
continue;
337342

338343
Value cache = info.initOp.getResult();
@@ -341,34 +346,34 @@ struct ForOpEnzymeOpsRemover
341346
info.cachedType().cast<AutoDiffTypeInterface>().getShadowType(
342347
numIters.value_or(ShapedType::kDynamic));
343348
enzyme::InitOp newInit = ({
344-
OpBuilder::InsertionGuard guard(builder);
345-
builder.setInsertionPoint(info.initOp);
349+
OpBuilder::InsertionGuard guard(rewriter);
350+
rewriter.setInsertionPoint(info.initOp);
346351

347-
builder.create<enzyme::InitOp>(
352+
rewriter.create<enzyme::InitOp>(
348353
info.initOp->getLoc(),
349354
enzyme::CacheType::get(cache.getContext(), newType));
350355
});
351356
info.pushOp = ({
352-
OpBuilder::InsertionGuard guard(builder);
353-
builder.setInsertionPointAfter(newFor);
354-
auto newPush = builder.create<enzyme::PushOp>(
357+
OpBuilder::InsertionGuard guard(rewriter);
358+
rewriter.setInsertionPointAfter(newFor);
359+
auto newPush = rewriter.create<enzyme::PushOp>(
355360
cache.getLoc(), newInit.getResult(), newFor->getResult(resultIdx));
356-
info.pushOp->erase();
361+
rewriter.eraseOp(info.pushOp);
357362
newPush;
358363
});
359364

360365
resultIdx++;
361366

362367
{
363-
OpBuilder::InsertionGuard guard(builder);
368+
OpBuilder::InsertionGuard guard(rewriter);
364369

365-
builder.setInsertionPoint(otherForOp);
370+
rewriter.setInsertionPoint(otherForOp);
366371

367-
auto popNewValue = builder.create<enzyme::PopOp>(
372+
auto popNewValue = rewriter.create<enzyme::PopOp>(
368373
info.popOp->getLoc(), newType, newInit.getResult());
369374

370375
Block *popBody = otherForOp.getBody();
371-
builder.setInsertionPoint(info.popOp);
376+
rewriter.setInsertionPoint(info.popOp);
372377

373378
Value newInductionVariable =
374379
popBody->getArgument(popBody->getNumArguments() - 1);
@@ -387,29 +392,27 @@ struct ForOpEnzymeOpsRemover
387392
SmallVector<int64_t> strides(shape.size() + 1, 1);
388393

389394
popValue =
390-
builder
395+
rewriter
391396
.create<tensor::ExtractSliceOp>(
392397
info.popOp->getLoc(), TT, popNewValue,
393398
ValueRange(newInductionVariable), ValueRange(),
394-
ValueRange(), builder.getDenseI64ArrayAttr(offsets),
395-
builder.getDenseI64ArrayAttr(sizes),
396-
builder.getDenseI64ArrayAttr(strides))
399+
ValueRange(), rewriter.getDenseI64ArrayAttr(offsets),
400+
rewriter.getDenseI64ArrayAttr(sizes),
401+
rewriter.getDenseI64ArrayAttr(strides))
397402
.getResult();
398403
} else {
399404
popValue =
400-
builder
405+
rewriter
401406
.create<tensor::ExtractOp>(info.popOp->getLoc(), popNewValue,
402407
newInductionVariable)
403408
.getResult();
404409
}
405410

406-
info.popOp.getResult().replaceAllUsesWith(popValue);
407-
info.popOp->erase();
411+
rewriter.replaceAllUsesWith(info.popOp.getResult(), popValue);
412+
rewriter.eraseOp(info.popOp);
408413
}
409414
}
410415

411-
forOp->erase();
412-
413416
return success();
414417
}
415418
};

enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ namespace mlir {
2121
class OpBuilder;
2222
class Operation;
2323
class IRMapping;
24+
class PatternRewriter;
2425

2526
namespace enzyme {
2627

0 commit comments

Comments
 (0)