Skip to content

Commit 58f19f9

Browse files
author
Maximilian Ehlers
committed
[668]: replace create and replaceOp with replaceOpWithNewOp()
Signed-off-by: Maximilian Ehlers <daphnevm@sodawa.com>
1 parent 8dfca70 commit 58f19f9

File tree

1 file changed

+12
-24
lines changed

1 file changed

+12
-24
lines changed

src/ir/daphneir/Canonicalize.cpp

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -271,9 +271,7 @@ template <class Operation> bool tryPushDownUnary(Operation op, mlir::PatternRewr
271271
auto width = fill.getNumCols();
272272
auto newOp =
273273
rewriter.create<Operation>(op.getLoc(), CompilerUtils::getValueType(op.getResult().getType()), fillValue);
274-
auto newCombinedOpAfterPushDown =
275-
rewriter.create<mlir::daphne::FillOp>(op.getLoc(), op.getResult().getType(), newOp, height, width);
276-
rewriter.replaceOp(op, {newCombinedOpAfterPushDown});
274+
rewriter.replaceOpWithNewOp<mlir::daphne::FillOp>(op, op.getResult().getType(), newOp, height, width);
277275
return true;
278276
}
279277
// This will check for the rand operation to push down the arithmetic inside
@@ -294,15 +292,13 @@ template <class Operation> bool tryPushDownUnary(Operation op, mlir::PatternRewr
294292

295293
if constexpr (std::is_same<Operation, mlir::daphne::EwMinusOp>()) {
296294
// max and min have to be swapped after being negated
297-
auto newCombinedOpAfterPushDown = rewriter.create<mlir::daphne::RandMatrixOp>(
298-
op.getLoc(), op.getResult().getType(), height, width, newMax, newMin, sparsity, seed);
299-
rewriter.replaceOp(op, {newCombinedOpAfterPushDown});
295+
rewriter.replaceOpWithNewOp<mlir::daphne::RandMatrixOp>(op, op.getResult().getType(), height, width,
296+
newMax, newMin, sparsity, seed);
300297
return true;
301298
}
302299

303-
auto newCombinedOpAfterPushDown = rewriter.create<mlir::daphne::RandMatrixOp>(
304-
op.getLoc(), op.getResult().getType(), height, width, newMin, newMax, sparsity, seed);
305-
rewriter.replaceOp(op, {newCombinedOpAfterPushDown});
300+
rewriter.replaceOpWithNewOp<mlir::daphne::RandMatrixOp>(op, op.getResult().getType(), height, width, newMin,
301+
newMax, sparsity, seed);
306302
return true;
307303
}
308304
// Handle the special case of RandMatrixOp and EwAbsOp which only works
@@ -334,9 +330,8 @@ template <class Operation> bool tryPushDownUnary(Operation op, mlir::PatternRewr
334330
op.getLoc(), CompilerUtils::getValueType(op.getResult().getType()), min);
335331
auto newMin = rewriter.create<Operation>(
336332
op.getLoc(), CompilerUtils::getValueType(op.getResult().getType()), max);
337-
auto newCombinedOpAfterPushDown = rewriter.create<mlir::daphne::RandMatrixOp>(
338-
op.getLoc(), op.getResult().getType(), height, width, newMin, newMax, sparsity, seed);
339-
rewriter.replaceOp(op, {newCombinedOpAfterPushDown});
333+
rewriter.replaceOpWithNewOp<mlir::daphne::RandMatrixOp>(op, op.getResult().getType(), height, width,
334+
newMin, newMax, sparsity, seed);
340335
return true;
341336
}
342337
}
@@ -386,9 +381,7 @@ template <class Operation> bool tryPushDownBinary(Operation op, mlir::PatternRew
386381
auto width = lhsFill.getNumCols();
387382
auto newOp = rewriter.create<Operation>(op.getLoc(), CompilerUtils::getValueType(op.getResult().getType()),
388383
fillValue, rhs);
389-
auto newCombinedOpAfterPushDown =
390-
rewriter.create<mlir::daphne::FillOp>(op.getLoc(), op.getResult().getType(), newOp, height, width);
391-
rewriter.replaceOp(op, {newCombinedOpAfterPushDown});
384+
rewriter.replaceOpWithNewOp<mlir::daphne::FillOp>(op, op.getResult().getType(), newOp, height, width);
392385
return true;
393386
}
394387
// This will check for the rand operation to push down the arithmetic inside
@@ -410,9 +403,8 @@ template <class Operation> bool tryPushDownBinary(Operation op, mlir::PatternRew
410403
}
411404
auto newMax = rewriter.create<Operation>(op.getLoc(), max, rhs);
412405
auto newMin = rewriter.create<Operation>(op.getLoc(), min, rhs);
413-
auto newCombinedOpAfterPushDown = rewriter.create<mlir::daphne::RandMatrixOp>(
414-
op.getLoc(), op.getResult().getType(), height, width, newMin, newMax, sparsity, seed);
415-
rewriter.replaceOp(op, {newCombinedOpAfterPushDown});
406+
rewriter.replaceOpWithNewOp<mlir::daphne::RandMatrixOp>(op, op.getResult().getType(), height, width, newMin,
407+
newMax, sparsity, seed);
416408
return true;
417409
}
418410

@@ -440,15 +432,11 @@ template <class Operation> bool tryPushDownBinary(Operation op, mlir::PatternRew
440432
}
441433
auto newInc = rewriter.create<Operation>(op.getLoc(), rhs, inc);
442434

443-
auto newCombinedOpAfterPushDown =
444-
rewriter.create<mlir::daphne::SeqOp>(op.getLoc(), op.getResult().getType(), newFrom, newTo, newInc);
445-
rewriter.replaceOp(op, {newCombinedOpAfterPushDown});
435+
rewriter.replaceOpWithNewOp<mlir::daphne::SeqOp>(op, op.getResult().getType(), newFrom, newTo, newInc);
446436
return true;
447437

448438
} else {
449-
auto newCombinedOpAfterPushDown =
450-
rewriter.create<mlir::daphne::SeqOp>(op.getLoc(), op.getResult().getType(), newFrom, newTo, inc);
451-
rewriter.replaceOp(op, {newCombinedOpAfterPushDown});
439+
rewriter.replaceOpWithNewOp<mlir::daphne::SeqOp>(op, op.getResult().getType(), newFrom, newTo, inc);
452440
return true;
453441
}
454442
}

0 commit comments

Comments
 (0)