2525#include " mlir/IR/Types.h"
2626#include " mlir/Interfaces/ControlFlowInterfaces.h"
2727#include " mlir/Support/LogicalResult.h"
28- #include " llvm/ADT/DenseSet.h"
2928#include " llvm/ADT/STLExtras.h"
3029#include < functional>
3130
@@ -67,15 +66,13 @@ struct ForOpEnzymeOpsRemover
6766 : public EnzymeOpsRemoverOpInterface::ExternalModel<ForOpEnzymeOpsRemover,
6867 scf::ForOp> {
6968
70- LogicalResult removeEnzymeOps (Operation *op) const {
69+ LogicalResult removeEnzymeOps (Operation *op,
70+ PatternRewriter &rewriter) const {
7171 auto forOp = cast<scf::ForOp>(op);
7272 scf::ForOp otherForOp; // where caches pops are
7373
74- if (removeOpsWithinBlock (forOp.getBody ()).failed ())
75- return failure ();
76-
7774 // Gradients whose values need to be passed as iteration variables.
78- llvm::SmallDenseSet <Value> updatedGradients;
75+ llvm::SetVector <Value> updatedGradients;
7976
8077 llvm::MapVector<Value, CacheInfo> cachesMap;
8178
@@ -92,7 +89,7 @@ struct ForOpEnzymeOpsRemover
9289
9390 Value pushedValue = info.pushedValue ();
9491 if (cachesMap.contains (pushedValue)) {
95- info = info.merge (cachesMap.lookup (pushedValue));
92+ info = info.merge (cachesMap.lookup (pushedValue), rewriter );
9693 }
9794 cachesMap[pushedValue] = info;
9895
@@ -110,43 +107,42 @@ struct ForOpEnzymeOpsRemover
110107 if (updatedGradients.empty () && caches.empty ())
111108 return success ();
112109
113- OpBuilder builder (forOp);
114110 for (auto &it : *body) {
115111 Operation *op = &it;
116112
117113 auto getOp = dyn_cast<enzyme::GetOp>(op);
118114 if (!getOp || updatedGradients.contains (getOp.getGradient ()))
119115 continue ;
120116
121- auto outerGet = builder .create <enzyme::GetOp>(
117+ auto outerGet = rewriter .create <enzyme::GetOp>(
122118 getOp->getLoc (),
123119 cast<enzyme::GradientType>(getOp.getResult ().getType ()).getBasetype (),
124120 getOp.getGradient ());
125121
126- getOp.getResult (). replaceAllUsesWith ( outerGet.getResult ());
127- getOp-> erase ( );
122+ rewriter. replaceAllUsesWith ( getOp.getResult (), outerGet.getResult ());
123+ rewriter. eraseOp (getOp );
128124 }
129125
130126 auto term = body->getTerminator ();
131127
132128 SmallVector<Value> newOperands (forOp.getInitArgs ());
133129 for (auto grad : updatedGradients) {
134130 auto Ty = cast<enzyme::GradientType>(grad.getType ()).getBasetype ();
135- auto outerGet = builder .create <enzyme::GetOp>(grad.getLoc (), Ty, grad);
131+ auto outerGet = rewriter .create <enzyme::GetOp>(grad.getLoc (), Ty, grad);
136132
137133 newOperands.push_back (outerGet.getResult ());
138134 auto newArg = body->addArgument (Ty, grad.getLoc ());
139135
140136 {
141- OpBuilder::InsertionGuard guard (builder );
137+ OpBuilder::InsertionGuard guard (rewriter );
142138
143- builder .setInsertionPointToStart (body);
144- builder .create <enzyme::SetOp>(grad.getLoc (), grad, newArg);
139+ rewriter .setInsertionPointToStart (body);
140+ rewriter .create <enzyme::SetOp>(grad.getLoc (), grad, newArg);
145141
146- builder .setInsertionPoint (term);
142+ rewriter .setInsertionPoint (term);
147143
148144 auto outputVal =
149- builder .create <enzyme::GetOp>(grad.getLoc (), Ty, grad).getResult ();
145+ rewriter .create <enzyme::GetOp>(grad.getLoc (), Ty, grad).getResult ();
150146 term->insertOperands (term->getNumOperands (), ValueRange (outputVal));
151147 }
152148 }
@@ -159,45 +155,45 @@ struct ForOpEnzymeOpsRemover
159155 inductionVariable = body->getArgument (0 );
160156 }
161157
162- for (auto info : caches) {
158+ for (auto & info : caches) {
163159 Value cache = info.initOp .getResult ();
164160
165161 // push does not depend on a value inside the loop, we can hoist the
166162 // push/pop before the for loops.
167- if (info.pushedValue ().getParentRegion () != forOp-> getRegion (0 )) {
168- auto newPush = builder .create <enzyme::PushOp>(cache.getLoc (), cache,
169- info.pushedValue ());
170- info.pushOp -> erase ( );
163+ if (info.pushedValue ().getParentRegion () != forOp. getRegion ()) {
164+ auto newPush = rewriter .create <enzyme::PushOp>(cache.getLoc (), cache,
165+ info.pushedValue ());
166+ rewriter. eraseOp ( info.pushOp );
171167 info.pushOp = newPush;
172168
173169 {
174- OpBuilder::InsertionGuard guard (builder );
175- builder .setInsertionPoint (info.popOp ->getParentOp ());
170+ OpBuilder::InsertionGuard guard (rewriter );
171+ rewriter .setInsertionPoint (info.popOp ->getParentOp ());
176172
177173 auto popVal = info.popOp .getResult ();
178- auto newPop = builder .create <enzyme::PopOp>(cache.getLoc (),
179- popVal.getType (), cache);
180- popVal .replaceAllUsesWith (newPop.getResult ());
181- info.popOp -> erase ( );
174+ auto newPop = rewriter .create <enzyme::PopOp>(cache.getLoc (),
175+ popVal.getType (), cache);
176+ rewriter .replaceAllUsesWith (popVal, newPop.getResult ());
177+ rewriter. eraseOp ( info.popOp );
182178 info.popOp = newPop;
183179 }
184180
185181 continue ;
186182 }
187183
188184 if (!inductionVariable) {
189- Value zero = builder .create <arith::ConstantOp>(forOp-> getLoc (),
190- builder .getIndexAttr (0 ));
185+ Value zero = rewriter .create <arith::ConstantOp>(
186+ forOp-> getLoc (), rewriter .getIndexAttr (0 ));
191187 newOperands.push_back (zero);
192188
193189 inductionVariable = body->addArgument (zero.getType (), forOp->getLoc ());
194190 {
195- OpBuilder::InsertionGuard guard (builder );
196- builder .setInsertionPoint (term);
191+ OpBuilder::InsertionGuard guard (rewriter );
192+ rewriter .setInsertionPoint (term);
197193
198- auto one = builder .create <arith::ConstantOp>(forOp-> getLoc (),
199- builder .getIndexAttr (1 ));
200- auto newInductionVar = builder .create <arith::AddIOp>(
194+ auto one = rewriter .create <arith::ConstantOp>(
195+ forOp-> getLoc (), rewriter .getIndexAttr (1 ));
196+ auto newInductionVar = rewriter .create <arith::AddIOp>(
201197 forOp->getLoc (), inductionVariable, one);
202198 term->insertOperands (term->getNumOperands (),
203199 ValueRange (newInductionVar));
@@ -215,25 +211,25 @@ struct ForOpEnzymeOpsRemover
215211 for (auto it : llvm::enumerate (newType.getShape ())) {
216212 if (ShapedType::isDynamic (it.value ())) {
217213 if (it.index () == 0 )
218- dynamicDims.push_back (getNumberOfIterations (builder , forOp));
214+ dynamicDims.push_back (getNumberOfIterations (rewriter , forOp));
219215 else
220216 return failure (); // TODO: find dynamic dims within the body.
221217 }
222218 }
223219
224- Value initValue = builder .create <tensor::EmptyOp>(info.initOp ->getLoc (),
225- newType, dynamicDims);
220+ Value initValue = rewriter .create <tensor::EmptyOp>(info.initOp ->getLoc (),
221+ newType, dynamicDims);
226222
227223 // cast<AutoDiffTypeInterface>(newType).createNullValue(
228- // builder , info.initOp->getLoc());
224+ // rewriter , info.initOp->getLoc());
229225
230226 newOperands.push_back (initValue);
231227
232228 auto cacheValue = body->addArgument (newType, info.pushOp ->getLoc ());
233229
234230 {
235- OpBuilder::InsertionGuard guard (builder );
236- builder .setInsertionPoint (info.pushOp );
231+ OpBuilder::InsertionGuard guard (rewriter );
232+ rewriter .setInsertionPoint (info.pushOp );
237233
238234 // TODO: if type is tensor, use insert_slice instead
239235 Value newCacheValue;
@@ -250,14 +246,14 @@ struct ForOpEnzymeOpsRemover
250246
251247 SmallVector<int64_t > strides (shape.size () + 1 , 1 );
252248
253- newCacheValue = builder .create <tensor::InsertSliceOp>(
249+ newCacheValue = rewriter .create <tensor::InsertSliceOp>(
254250 info.pushOp ->getLoc (), info.pushOp .getValue (), cacheValue,
255251 ValueRange (inductionVariable), ValueRange (), ValueRange (),
256- builder .getDenseI64ArrayAttr (offsets),
257- builder .getDenseI64ArrayAttr (sizes),
258- builder .getDenseI64ArrayAttr (strides));
252+ rewriter .getDenseI64ArrayAttr (offsets),
253+ rewriter .getDenseI64ArrayAttr (sizes),
254+ rewriter .getDenseI64ArrayAttr (strides));
259255 } else {
260- newCacheValue = builder .create <tensor::InsertOp>(
256+ newCacheValue = rewriter .create <tensor::InsertOp>(
261257 info.pushOp ->getLoc (), info.pushOp .getValue (), cacheValue,
262258 inductionVariable);
263259 }
@@ -267,72 +263,81 @@ struct ForOpEnzymeOpsRemover
267263 }
268264
269265 auto numInitArgs = forOp.getInitArgs ().size ();
270- auto newFor = builder .create <scf::ForOp>(
266+ auto newFor = rewriter .create <scf::ForOp>(
271267 op->getLoc (), forOp.getLowerBound (), forOp.getUpperBound (),
272268 forOp.getStep (), newOperands);
273269
274270 newFor.getRegion ().takeBody (forOp.getRegion ());
275271
272+ for (auto &&[res, newRes] :
273+ llvm::zip (forOp->getResults (), newFor->getResults ())) {
274+ rewriter.replaceAllUsesWith (res, newRes);
275+ }
276+
277+ rewriter.eraseOp (forOp);
278+ forOp = newFor;
279+ rewriter.setInsertionPointAfter (forOp);
280+
276281 unsigned resultIdx = numInitArgs;
277282 for (auto grad : updatedGradients) {
278283 // set the updated gradient after the new for op.
279- OpBuilder::InsertionGuard guard (builder );
280- builder .create <enzyme::SetOp>(grad.getLoc (), grad,
281- newFor->getResult (resultIdx));
284+ OpBuilder::InsertionGuard guard (rewriter );
285+ rewriter .create <enzyme::SetOp>(grad.getLoc (), grad,
286+ newFor->getResult (resultIdx));
282287 ++resultIdx;
283288 }
284289
285- if (inductionVariable && caches.size ()) {
290+ if (inductionVariable && ! caches.empty ()) {
286291 if (isa<BlockArgument>(inductionVariable) &&
287292 cast<BlockArgument>(inductionVariable).getArgNumber () != 0 )
288293 resultIdx++;
289294
290- OpBuilder::InsertionGuard guard (builder );
291- builder .setInsertionPoint (otherForOp);
295+ OpBuilder::InsertionGuard guard (rewriter );
296+ rewriter .setInsertionPoint (otherForOp);
292297 SmallVector<Value> operands (otherForOp.getInitArgs ().begin (),
293298 otherForOp.getInitArgs ().end ());
294299 operands.push_back (numIters.has_value ()
295- ? builder .create <arith::ConstantOp>(
300+ ? rewriter .create <arith::ConstantOp>(
296301 otherForOp->getLoc (),
297- builder .getIndexAttr (numIters.value () - 1 ))
298- : getNumberOfIterations (builder , forOp));
302+ rewriter .getIndexAttr (numIters.value () - 1 ))
303+ : getNumberOfIterations (rewriter , forOp));
299304
300305 Block *otherBody = otherForOp.getBody ();
301306 Value otherInductionVariable =
302- otherBody->addArgument (builder .getIndexType (), otherForOp->getLoc ());
307+ otherBody->addArgument (rewriter .getIndexType (), otherForOp->getLoc ());
303308 auto otherTerm = otherBody->getTerminator ();
304309
305- builder .setInsertionPoint (otherTerm);
310+ rewriter .setInsertionPoint (otherTerm);
306311
307312 otherInductionVariable =
308- builder
313+ rewriter
309314 .create <arith::SubIOp>(
310315 otherForOp->getLoc (), otherInductionVariable,
311- builder
316+ rewriter
312317 .create <arith::ConstantOp>(otherForOp->getLoc (),
313- builder .getIndexAttr (1 ))
318+ rewriter .getIndexAttr (1 ))
314319 .getResult ())
315320 .getResult ();
316321 otherTerm->insertOperands (otherTerm->getNumOperands (),
317322 ValueRange (otherInductionVariable));
318323
319- builder .setInsertionPoint (otherForOp);
320- auto newOtherForOp = builder .create <scf::ForOp>(
324+ rewriter .setInsertionPoint (otherForOp);
325+ auto newOtherForOp = rewriter .create <scf::ForOp>(
321326 otherForOp->getLoc (), otherForOp.getLowerBound (),
322327 otherForOp.getUpperBound (), otherForOp.getStep (), operands);
323328
324329 for (auto &&[res, newRes] :
325330 llvm::zip (otherForOp->getResults (), newOtherForOp->getResults ())) {
326- res .replaceAllUsesWith (newRes);
331+ rewriter .replaceAllUsesWith (res, newRes);
327332 }
328333 newOtherForOp.getRegion ().takeBody (otherForOp.getRegion ());
329334
330- otherForOp-> erase ( );
335+ rewriter. eraseOp (otherForOp );
331336 otherForOp = newOtherForOp;
332337 }
333338
334- for (auto info : caches) {
335- if (info.pushedValue ().getParentRegion () != newFor-> getRegion (0 ))
339+ for (auto & info : caches) {
340+ if (info.pushedValue ().getParentRegion () != newFor. getRegion ())
336341 continue ;
337342
338343 Value cache = info.initOp .getResult ();
@@ -341,34 +346,34 @@ struct ForOpEnzymeOpsRemover
341346 info.cachedType ().cast <AutoDiffTypeInterface>().getShadowType (
342347 numIters.value_or (ShapedType::kDynamic ));
343348 enzyme::InitOp newInit = ({
344- OpBuilder::InsertionGuard guard (builder );
345- builder .setInsertionPoint (info.initOp );
349+ OpBuilder::InsertionGuard guard (rewriter );
350+ rewriter .setInsertionPoint (info.initOp );
346351
347- builder .create <enzyme::InitOp>(
352+ rewriter .create <enzyme::InitOp>(
348353 info.initOp ->getLoc (),
349354 enzyme::CacheType::get (cache.getContext (), newType));
350355 });
351356 info.pushOp = ({
352- OpBuilder::InsertionGuard guard (builder );
353- builder .setInsertionPointAfter (newFor);
354- auto newPush = builder .create <enzyme::PushOp>(
357+ OpBuilder::InsertionGuard guard (rewriter );
358+ rewriter .setInsertionPointAfter (newFor);
359+ auto newPush = rewriter .create <enzyme::PushOp>(
355360 cache.getLoc (), newInit.getResult (), newFor->getResult (resultIdx));
356- info.pushOp -> erase ( );
361+ rewriter. eraseOp ( info.pushOp );
357362 newPush;
358363 });
359364
360365 resultIdx++;
361366
362367 {
363- OpBuilder::InsertionGuard guard (builder );
368+ OpBuilder::InsertionGuard guard (rewriter );
364369
365- builder .setInsertionPoint (otherForOp);
370+ rewriter .setInsertionPoint (otherForOp);
366371
367- auto popNewValue = builder .create <enzyme::PopOp>(
372+ auto popNewValue = rewriter .create <enzyme::PopOp>(
368373 info.popOp ->getLoc (), newType, newInit.getResult ());
369374
370375 Block *popBody = otherForOp.getBody ();
371- builder .setInsertionPoint (info.popOp );
376+ rewriter .setInsertionPoint (info.popOp );
372377
373378 Value newInductionVariable =
374379 popBody->getArgument (popBody->getNumArguments () - 1 );
@@ -387,29 +392,27 @@ struct ForOpEnzymeOpsRemover
387392 SmallVector<int64_t > strides (shape.size () + 1 , 1 );
388393
389394 popValue =
390- builder
395+ rewriter
391396 .create <tensor::ExtractSliceOp>(
392397 info.popOp ->getLoc (), TT, popNewValue,
393398 ValueRange (newInductionVariable), ValueRange (),
394- ValueRange (), builder .getDenseI64ArrayAttr (offsets),
395- builder .getDenseI64ArrayAttr (sizes),
396- builder .getDenseI64ArrayAttr (strides))
399+ ValueRange (), rewriter .getDenseI64ArrayAttr (offsets),
400+ rewriter .getDenseI64ArrayAttr (sizes),
401+ rewriter .getDenseI64ArrayAttr (strides))
397402 .getResult ();
398403 } else {
399404 popValue =
400- builder
405+ rewriter
401406 .create <tensor::ExtractOp>(info.popOp ->getLoc (), popNewValue,
402407 newInductionVariable)
403408 .getResult ();
404409 }
405410
406- info.popOp .getResult (). replaceAllUsesWith ( popValue);
407- info.popOp -> erase ( );
411+ rewriter. replaceAllUsesWith ( info.popOp .getResult (), popValue);
412+ rewriter. eraseOp ( info.popOp );
408413 }
409414 }
410415
411- forOp->erase ();
412-
413416 return success ();
414417 }
415418};
0 commit comments