@@ -271,9 +271,7 @@ template <class Operation> bool tryPushDownUnary(Operation op, mlir::PatternRewr
271271 auto width = fill.getNumCols ();
272272 auto newOp =
273273 rewriter.create <Operation>(op.getLoc (), CompilerUtils::getValueType (op.getResult ().getType ()), fillValue);
274- auto newCombinedOpAfterPushDown =
275- rewriter.create <mlir::daphne::FillOp>(op.getLoc (), op.getResult ().getType (), newOp, height, width);
276- rewriter.replaceOp (op, {newCombinedOpAfterPushDown});
274+ rewriter.replaceOpWithNewOp <mlir::daphne::FillOp>(op, op.getResult ().getType (), newOp, height, width);
277275 return true ;
278276 }
279277 // This will check for the rand operation to push down the arithmetic inside
@@ -294,15 +292,13 @@ template <class Operation> bool tryPushDownUnary(Operation op, mlir::PatternRewr
294292
295293 if constexpr (std::is_same<Operation, mlir::daphne::EwMinusOp>()) {
296294 // max and min have to be swapped after being negated
297- auto newCombinedOpAfterPushDown = rewriter.create <mlir::daphne::RandMatrixOp>(
298- op.getLoc (), op.getResult ().getType (), height, width, newMax, newMin, sparsity, seed);
299- rewriter.replaceOp (op, {newCombinedOpAfterPushDown});
295+ rewriter.replaceOpWithNewOp <mlir::daphne::RandMatrixOp>(op, op.getResult ().getType (), height, width,
296+ newMax, newMin, sparsity, seed);
300297 return true ;
301298 }
302299
303- auto newCombinedOpAfterPushDown = rewriter.create <mlir::daphne::RandMatrixOp>(
304- op.getLoc (), op.getResult ().getType (), height, width, newMin, newMax, sparsity, seed);
305- rewriter.replaceOp (op, {newCombinedOpAfterPushDown});
300+ rewriter.replaceOpWithNewOp <mlir::daphne::RandMatrixOp>(op, op.getResult ().getType (), height, width, newMin,
301+ newMax, sparsity, seed);
306302 return true ;
307303 }
308304 // Handle the special case of RandMatrixOp and EwAbsOp which only works
@@ -334,9 +330,8 @@ template <class Operation> bool tryPushDownUnary(Operation op, mlir::PatternRewr
334330 op.getLoc (), CompilerUtils::getValueType (op.getResult ().getType ()), min);
335331 auto newMin = rewriter.create <Operation>(
336332 op.getLoc (), CompilerUtils::getValueType (op.getResult ().getType ()), max);
337- auto newCombinedOpAfterPushDown = rewriter.create <mlir::daphne::RandMatrixOp>(
338- op.getLoc (), op.getResult ().getType (), height, width, newMin, newMax, sparsity, seed);
339- rewriter.replaceOp (op, {newCombinedOpAfterPushDown});
333+ rewriter.replaceOpWithNewOp <mlir::daphne::RandMatrixOp>(op, op.getResult ().getType (), height, width,
334+ newMin, newMax, sparsity, seed);
340335 return true ;
341336 }
342337 }
@@ -386,9 +381,7 @@ template <class Operation> bool tryPushDownBinary(Operation op, mlir::PatternRew
386381 auto width = lhsFill.getNumCols ();
387382 auto newOp = rewriter.create <Operation>(op.getLoc (), CompilerUtils::getValueType (op.getResult ().getType ()),
388383 fillValue, rhs);
389- auto newCombinedOpAfterPushDown =
390- rewriter.create <mlir::daphne::FillOp>(op.getLoc (), op.getResult ().getType (), newOp, height, width);
391- rewriter.replaceOp (op, {newCombinedOpAfterPushDown});
384+ rewriter.replaceOpWithNewOp <mlir::daphne::FillOp>(op, op.getResult ().getType (), newOp, height, width);
392385 return true ;
393386 }
394387 // This will check for the rand operation to push down the arithmetic inside
@@ -410,9 +403,8 @@ template <class Operation> bool tryPushDownBinary(Operation op, mlir::PatternRew
410403 }
411404 auto newMax = rewriter.create <Operation>(op.getLoc (), max, rhs);
412405 auto newMin = rewriter.create <Operation>(op.getLoc (), min, rhs);
413- auto newCombinedOpAfterPushDown = rewriter.create <mlir::daphne::RandMatrixOp>(
414- op.getLoc (), op.getResult ().getType (), height, width, newMin, newMax, sparsity, seed);
415- rewriter.replaceOp (op, {newCombinedOpAfterPushDown});
406+ rewriter.replaceOpWithNewOp <mlir::daphne::RandMatrixOp>(op, op.getResult ().getType (), height, width, newMin,
407+ newMax, sparsity, seed);
416408 return true ;
417409 }
418410
@@ -440,15 +432,11 @@ template <class Operation> bool tryPushDownBinary(Operation op, mlir::PatternRew
440432 }
441433 auto newInc = rewriter.create <Operation>(op.getLoc (), rhs, inc);
442434
443- auto newCombinedOpAfterPushDown =
444- rewriter.create <mlir::daphne::SeqOp>(op.getLoc (), op.getResult ().getType (), newFrom, newTo, newInc);
445- rewriter.replaceOp (op, {newCombinedOpAfterPushDown});
435+ rewriter.replaceOpWithNewOp <mlir::daphne::SeqOp>(op, op.getResult ().getType (), newFrom, newTo, newInc);
446436 return true ;
447437
448438 } else {
449- auto newCombinedOpAfterPushDown =
450- rewriter.create <mlir::daphne::SeqOp>(op.getLoc (), op.getResult ().getType (), newFrom, newTo, inc);
451- rewriter.replaceOp (op, {newCombinedOpAfterPushDown});
439+ rewriter.replaceOpWithNewOp <mlir::daphne::SeqOp>(op, op.getResult ().getType (), newFrom, newTo, inc);
452440 return true ;
453441 }
454442 }
0 commit comments