Skip to content

Commit 069cf4d

Browse files
author
Maximilian Ehlers
committed
[668]: adds fill rewrite for EwSubOp
Signed-off-by: Maximilian Ehlers <daphnevm@sodawa.com>
1 parent 35e0bf1 commit 069cf4d

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

src/ir/daphneir/Canonicalize.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,22 @@ mlir::LogicalResult mlir::daphne::EwAddOp::canonicalize(mlir::daphne::EwAddOp op
323323
mlir::LogicalResult mlir::daphne::EwSubOp::canonicalize(mlir::daphne::EwSubOp op, PatternRewriter &rewriter) {
324324
mlir::Value lhs = op.getLhs();
325325
mlir::Value rhs = op.getRhs();
326+
// This will check for the fill operation on the left hand side to push down the arithmetic inside
327+
// of it
328+
mlir::daphne::FillOp lhsFill = lhs.getDefiningOp<mlir::daphne::FillOp>();
329+
if (lhsFill) {
330+
auto fillValue = lhsFill.getArg();
331+
auto height = lhsFill.getNumRows();
332+
auto width = lhsFill.getNumCols();
333+
const bool rhsIsSca = CompilerUtils::isScaType(rhs.getType());
334+
if (rhsIsSca) {
335+
mlir::daphne::EwAddOp newSub = rewriter.create<mlir::daphne::EwSubOp>(op.getLoc(), fillValue, rhs);
336+
mlir::daphne::FillOp newFill =
337+
rewriter.create<mlir::daphne::FillOp>(op.getLoc(), op.getResult().getType(), newAdd, width, height);
338+
rewriter.replaceOp(op, {newFill});
339+
return mlir::success();
340+
}
341+
}
326342
const bool lhsIsSca = CompilerUtils::isScaType(lhs.getType());
327343
const bool rhsIsSca = CompilerUtils::isScaType(rhs.getType());
328344
if (lhsIsSca && !rhsIsSca) {
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: daphne-opt --canonicalize --inline %s | FileCheck %s
2+
3+
module {
4+
func.func @main() {
5+
%0 = "daphne.constant"() {value = 1 : si64} : () -> si64
6+
%1 = "daphne.constant"() {value = 4 : si64} : () -> si64
7+
%2 = "daphne.constant"() {value = 3 : si64} : () -> si64
8+
%3 = "daphne.constant"() {value = 3 : si64} : () -> si64
9+
%4 = "daphne.cast"(%2) : (si64) -> index
10+
%5 = "daphne.cast"(%3) : (si64) -> index
11+
// CHECK: daphne.fill
12+
// CHECK-NOT: daphne.fill
13+
%6 = "daphne.fill"(%1, %4, %5) : (si64, index, index) -> !daphne.Matrix<?x?xsi64>
14+
// CHECK-NOT: daphne.ewSub
15+
%7 = "daphne.ewSub"(%0, %6) : (si64, !daphne.Matrix<?x?xsi64>) -> !daphne.Matrix<?x?xsi64>
16+
%8 = "daphne.constant"() {value = true} : () -> i1
17+
%9 = "daphne.constant"() {value = false} : () -> i1
18+
"daphne.print"(%7, %8, %9) : (!daphne.Matrix<?x?xsi64>, i1, i1) -> ()
19+
"daphne.return"() : () -> ()
20+
}
21+
}
22+

0 commit comments

Comments
 (0)