Skip to content

Commit b38c03b

Browse files
author
Maximilian Ehlers
committed
[668]: adds fill rewrite for EwSubOp,EwMulOp,EwDivOp based on EwAddOp implementation
Signed-off-by: Maximilian Ehlers <daphnevm@sodawa.com>
1 parent 35e0bf1 commit b38c03b

File tree

4 files changed

+115
-0
lines changed

4 files changed

+115
-0
lines changed

src/ir/daphneir/Canonicalize.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,22 @@ mlir::LogicalResult mlir::daphne::EwAddOp::canonicalize(mlir::daphne::EwAddOp op
323323
mlir::LogicalResult mlir::daphne::EwSubOp::canonicalize(mlir::daphne::EwSubOp op, PatternRewriter &rewriter) {
324324
mlir::Value lhs = op.getLhs();
325325
mlir::Value rhs = op.getRhs();
326+
// This will check for the fill operation on the left hand side to push down the arithmetic inside
327+
// of it
328+
mlir::daphne::FillOp lhsFill = lhs.getDefiningOp<mlir::daphne::FillOp>();
329+
if (lhsFill) {
330+
auto fillValue = lhsFill.getArg();
331+
auto height = lhsFill.getNumRows();
332+
auto width = lhsFill.getNumCols();
333+
const bool rhsIsSca = CompilerUtils::isScaType(rhs.getType());
334+
if (rhsIsSca) {
335+
mlir::daphne::EwSubOp newSub = rewriter.create<mlir::daphne::EwSubOp>(op.getLoc(), fillValue, rhs);
336+
mlir::daphne::FillOp newFill =
337+
rewriter.create<mlir::daphne::FillOp>(op.getLoc(), op.getResult().getType(), newSub, width, height);
338+
rewriter.replaceOp(op, {newFill});
339+
return mlir::success();
340+
}
341+
}
326342
const bool lhsIsSca = CompilerUtils::isScaType(lhs.getType());
327343
const bool rhsIsSca = CompilerUtils::isScaType(rhs.getType());
328344
if (lhsIsSca && !rhsIsSca) {
@@ -351,6 +367,23 @@ mlir::LogicalResult mlir::daphne::EwSubOp::canonicalize(mlir::daphne::EwSubOp op
351367
mlir::LogicalResult mlir::daphne::EwMulOp::canonicalize(mlir::daphne::EwMulOp op, PatternRewriter &rewriter) {
352368
mlir::Value lhs = op.getLhs();
353369
mlir::Value rhs = op.getRhs();
370+
// This will check for the fill operation on the left hand side to push down the arithmetic inside
371+
// of it
372+
mlir::daphne::FillOp lhsFill = lhs.getDefiningOp<mlir::daphne::FillOp>();
373+
if (lhsFill) {
374+
auto fillValue = lhsFill.getArg();
375+
auto height = lhsFill.getNumRows();
376+
auto width = lhsFill.getNumCols();
377+
const bool rhsIsSca = CompilerUtils::isScaType(rhs.getType());
378+
if (rhsIsSca) {
379+
mlir::daphne::EwMulOp newMul = rewriter.create<mlir::daphne::EwMulOp>(op.getLoc(), fillValue, rhs);
380+
mlir::daphne::FillOp newFill =
381+
rewriter.create<mlir::daphne::FillOp>(op.getLoc(), op.getResult().getType(), newMul , width, height);
382+
rewriter.replaceOp(op, {newFill});
383+
return mlir::success();
384+
}
385+
}
386+
354387
const bool lhsIsSca = CompilerUtils::isScaType(lhs.getType());
355388
const bool rhsIsSca = CompilerUtils::isScaType(rhs.getType());
356389
if (lhsIsSca && !rhsIsSca) {
@@ -376,6 +409,22 @@ mlir::LogicalResult mlir::daphne::EwMulOp::canonicalize(mlir::daphne::EwMulOp op
376409
mlir::LogicalResult mlir::daphne::EwDivOp::canonicalize(mlir::daphne::EwDivOp op, PatternRewriter &rewriter) {
377410
mlir::Value lhs = op.getLhs();
378411
mlir::Value rhs = op.getRhs();
412+
// This will check for the fill operation on the left hand side to push down the arithmetic inside
413+
// of it
414+
mlir::daphne::FillOp lhsFill = lhs.getDefiningOp<mlir::daphne::FillOp>();
415+
if (lhsFill) {
416+
auto fillValue = lhsFill.getArg();
417+
auto height = lhsFill.getNumRows();
418+
auto width = lhsFill.getNumCols();
419+
const bool rhsIsSca = CompilerUtils::isScaType(rhs.getType());
420+
if (rhsIsSca) {
421+
mlir::daphne::EwDivOp newDiv = rewriter.create<mlir::daphne::EwDivOp>(op.getLoc(), fillValue, rhs);
422+
mlir::daphne::FillOp newFill =
423+
rewriter.create<mlir::daphne::FillOp>(op.getLoc(), op.getResult().getType(), newDiv, width, height);
424+
rewriter.replaceOp(op, {newFill});
425+
return mlir::success();
426+
}
427+
}
379428
const bool lhsIsSca = CompilerUtils::isScaType(lhs.getType());
380429
const bool rhsIsSca = CompilerUtils::isScaType(rhs.getType());
381430
const bool rhsIsFP = llvm::isa<mlir::FloatType>(CompilerUtils::getValueType(rhs.getType()));
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: daphne-opt --canonicalize --inline %s | FileCheck %s
2+
3+
module {
4+
func.func @main() {
5+
%0 = "daphne.constant"() {value = 4 : si64} : () -> si64
6+
%1 = "daphne.constant"() {value = 3 : si64} : () -> si64
7+
%2 = "daphne.constant"() {value = 3 : si64} : () -> si64
8+
%3 = "daphne.cast"(%1) : (si64) -> index
9+
%4 = "daphne.cast"(%2) : (si64) -> index
10+
// CHECK: daphne.fill
11+
// CHECK-NOT: daphne.fill
12+
%5 = "daphne.fill"(%0, %3, %4) : (si64, index, index) -> !daphne.Matrix<?x?xsi64>
13+
%6 = "daphne.constant"() {value = 2 : si64} : () -> si64
14+
// CHECK-NOT: daphne.ewDiv
15+
%7 = "daphne.ewDiv"(%5, %6) : (!daphne.Matrix<?x?xsi64>, si64) -> !daphne.Matrix<?x?xsi64>
16+
%8 = "daphne.constant"() {value = true} : () -> i1
17+
%9 = "daphne.constant"() {value = false} : () -> i1
18+
"daphne.print"(%7, %8, %9) : (!daphne.Matrix<?x?xsi64>, i1, i1) -> ()
19+
"daphne.return"() : () -> ()
20+
}
21+
}
22+
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: daphne-opt --canonicalize --inline %s | FileCheck %s
2+
3+
module {
4+
func.func @main() {
5+
%0 = "daphne.constant"() {value = 4 : si64} : () -> si64
6+
%1 = "daphne.constant"() {value = 3 : si64} : () -> si64
7+
%2 = "daphne.constant"() {value = 3 : si64} : () -> si64
8+
%3 = "daphne.cast"(%1) : (si64) -> index
9+
%4 = "daphne.cast"(%2) : (si64) -> index
10+
// CHECK: daphne.fill
11+
// CHECK-NOT: daphne.fill
12+
%5 = "daphne.fill"(%0, %3, %4) : (si64, index, index) -> !daphne.Matrix<?x?xsi64>
13+
%6 = "daphne.constant"() {value = 2 : si64} : () -> si64
14+
// CHECK-NOT: daphne.ewMul
15+
%7 = "daphne.ewMul"(%5, %6) : (!daphne.Matrix<?x?xsi64>, si64) -> !daphne.Matrix<?x?xsi64>
16+
%8 = "daphne.constant"() {value = true} : () -> i1
17+
%9 = "daphne.constant"() {value = false} : () -> i1
18+
"daphne.print"(%7, %8, %9) : (!daphne.Matrix<?x?xsi64>, i1, i1) -> ()
19+
"daphne.return"() : () -> ()
20+
}
21+
}
22+
23+
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: daphne-opt --canonicalize --inline %s | FileCheck %s
2+
3+
module {
4+
func.func @main() {
5+
%0 = "daphne.constant"() {value = 4 : si64} : () -> si64
6+
%1 = "daphne.constant"() {value = 3 : si64} : () -> si64
7+
%2 = "daphne.constant"() {value = 3 : si64} : () -> si64
8+
%3 = "daphne.cast"(%1) : (si64) -> index
9+
%4 = "daphne.cast"(%2) : (si64) -> index
10+
// CHECK: daphne.fill
11+
// CHECK-NOT: daphne.fill
12+
%5 = "daphne.fill"(%0, %3, %4) : (si64, index, index) -> !daphne.Matrix<?x?xsi64>
13+
%6 = "daphne.constant"() {value = 2 : si64} : () -> si64
14+
// CHECK-NOT: daphne.ewSub
15+
%7 = "daphne.ewSub"(%5, %6) : (!daphne.Matrix<?x?xsi64>, si64) -> !daphne.Matrix<?x?xsi64>
16+
%8 = "daphne.constant"() {value = true} : () -> i1
17+
%9 = "daphne.constant"() {value = false} : () -> i1
18+
"daphne.print"(%7, %8, %9) : (!daphne.Matrix<?x?xsi64>, i1, i1) -> ()
19+
"daphne.return"() : () -> ()
20+
}
21+
}

0 commit comments

Comments
 (0)