Skip to content

Commit 19e26e3

Browse files
author
Maximilian Ehlers
committed
[668]: Adds pushdown arithmetics for Rand + EwMul,EwDiv,EwSub,EwPow
Signed-off-by: Maximilian Ehlers <daphnevm@sodawa.com>
1 parent f7b82d2 commit 19e26e3

File tree

5 files changed

+230
-0
lines changed

5 files changed

+230
-0
lines changed

src/ir/daphneir/Canonicalize.cpp

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,71 @@ mlir::daphne::RandMatrixOp pushDownRandomIntoEwAdd(mlir::daphne::RandMatrixOp ra
307307
return rewriter.create<mlir::daphne::RandMatrixOp>(op.getLoc(), op.getResult().getType(), width, height, newMin,
308308
newMax, sparsity, seed);
309309
}
310+
mlir::daphne::RandMatrixOp pushDownRandomIntoEwSub(mlir::daphne::RandMatrixOp randOp, mlir::daphne::EwSubOp op,
311+
mlir::Value scalar, mlir::PatternRewriter &rewriter) {
312+
auto max = randOp.getMax();
313+
auto min = randOp.getMin();
314+
auto height = randOp.getNumRows();
315+
auto width = randOp.getNumCols();
316+
auto sparsity = randOp.getSparsity();
317+
auto seed = randOp.getSeed();
318+
mlir::daphne::EwSubOp newMax = rewriter.create<mlir::daphne::EwSubOp>(op.getLoc(), max, scalar);
319+
mlir::daphne::EwSubOp newMin = rewriter.create<mlir::daphne::EwSubOp>(op.getLoc(), min, scalar);
320+
return rewriter.create<mlir::daphne::RandMatrixOp>(op.getLoc(), op.getResult().getType(), width, height, newMin,
321+
newMax, sparsity, seed);
322+
}
323+
mlir::daphne::RandMatrixOp pushDownRandomIntoEwMul(mlir::daphne::RandMatrixOp randOp, mlir::daphne::EwMulOp op,
324+
mlir::Value scalar, mlir::PatternRewriter &rewriter) {
325+
auto max = randOp.getMax();
326+
auto min = randOp.getMin();
327+
auto height = randOp.getNumRows();
328+
auto width = randOp.getNumCols();
329+
auto sparsity = randOp.getSparsity();
330+
auto seed = randOp.getSeed();
331+
mlir::daphne::EwMulOp newMax = rewriter.create<mlir::daphne::EwMulOp>(op.getLoc(), max, scalar);
332+
mlir::daphne::EwMulOp newMin = rewriter.create<mlir::daphne::EwMulOp>(op.getLoc(), min, scalar);
333+
return rewriter.create<mlir::daphne::RandMatrixOp>(op.getLoc(), op.getResult().getType(), width, height, newMin,
334+
newMax, sparsity, seed);
335+
}
336+
mlir::daphne::RandMatrixOp pushDownRandomIntoEwDiv(mlir::daphne::RandMatrixOp randOp, mlir::daphne::EwDivOp op,
337+
mlir::Value scalar, mlir::PatternRewriter &rewriter) {
338+
auto max = randOp.getMax();
339+
auto min = randOp.getMin();
340+
auto height = randOp.getNumRows();
341+
auto width = randOp.getNumCols();
342+
auto sparsity = randOp.getSparsity();
343+
auto seed = randOp.getSeed();
344+
mlir::daphne::EwDivOp newMax = rewriter.create<mlir::daphne::EwDivOp>(op.getLoc(), max, scalar);
345+
mlir::daphne::EwDivOp newMin = rewriter.create<mlir::daphne::EwDivOp>(op.getLoc(), min, scalar);
346+
return rewriter.create<mlir::daphne::RandMatrixOp>(op.getLoc(), op.getResult().getType(), width, height, newMin,
347+
newMax, sparsity, seed);
348+
}
349+
mlir::daphne::RandMatrixOp pushDownRandomIntoEwPow(mlir::daphne::RandMatrixOp randOp, mlir::daphne::EwPowOp op,
350+
mlir::Value scalar, mlir::PatternRewriter &rewriter) {
351+
auto max = randOp.getMax();
352+
auto min = randOp.getMin();
353+
auto height = randOp.getNumRows();
354+
auto width = randOp.getNumCols();
355+
auto sparsity = randOp.getSparsity();
356+
auto seed = randOp.getSeed();
357+
mlir::daphne::EwPowOp newMax = rewriter.create<mlir::daphne::EwPowOp>(op.getLoc(), max, scalar);
358+
mlir::daphne::EwPowOp newMin = rewriter.create<mlir::daphne::EwPowOp>(op.getLoc(), min, scalar);
359+
return rewriter.create<mlir::daphne::RandMatrixOp>(op.getLoc(), op.getResult().getType(), width, height, newMin,
360+
newMax, sparsity, seed);
361+
}
362+
mlir::daphne::RandMatrixOp pushDownRandomIntoEwLog(mlir::daphne::RandMatrixOp randOp, mlir::daphne::EwLogOp op,
363+
mlir::Value scalar, mlir::PatternRewriter &rewriter) {
364+
auto max = randOp.getMax();
365+
auto min = randOp.getMin();
366+
auto height = randOp.getNumRows();
367+
auto width = randOp.getNumCols();
368+
auto sparsity = randOp.getSparsity();
369+
auto seed = randOp.getSeed();
370+
mlir::daphne::EwLogOp newMax = rewriter.create<mlir::daphne::EwLogOp>(op.getLoc(), max, scalar);
371+
mlir::daphne::EwLogOp newMin = rewriter.create<mlir::daphne::EwLogOp>(op.getLoc(), min, scalar);
372+
return rewriter.create<mlir::daphne::RandMatrixOp>(op.getLoc(), op.getResult().getType(), width, height, newMin,
373+
newMax, sparsity, seed);
374+
}
310375

311376
/**
312377
* @brief Replaces (1) `a + b` by `a concat b`, if `a` or `b` is a string,
@@ -431,6 +496,20 @@ mlir::LogicalResult mlir::daphne::EwSubOp::canonicalize(mlir::daphne::EwSubOp op
431496
rewriter.replaceOp(op, {newFill});
432497
return mlir::success();
433498
}
499+
// This will check for the rand operation to push down the arithmetic inside
500+
// of it
501+
mlir::daphne::RandMatrixOp lhsRand = lhs.getDefiningOp<mlir::daphne::RandMatrixOp>();
502+
mlir::daphne::RandMatrixOp rhsRand = rhs.getDefiningOp<mlir::daphne::RandMatrixOp>();
503+
if (lhsRand && rhsIsSca) {
504+
auto newRand = pushDownRandomIntoEwSub(lhsRand, op, rhs, rewriter);
505+
rewriter.replaceOp(op, {newRand});
506+
return mlir::success();
507+
}
508+
if (rhsRand && lhsIsSca) {
509+
auto newRand = pushDownRandomIntoEwSub(rhsRand, op, lhs, rewriter);
510+
rewriter.replaceOp(op, {newRand});
511+
return mlir::success();
512+
}
434513

435514
if (lhsIsSca && !rhsIsSca) {
436515
rewriter.replaceOpWithNewOp<mlir::daphne::EwAddOp>(
@@ -475,6 +554,21 @@ mlir::LogicalResult mlir::daphne::EwMulOp::canonicalize(mlir::daphne::EwMulOp op
475554
return mlir::success();
476555
}
477556

557+
// This will check for the rand operation to push down the arithmetic inside
558+
// of it
559+
mlir::daphne::RandMatrixOp lhsRand = lhs.getDefiningOp<mlir::daphne::RandMatrixOp>();
560+
mlir::daphne::RandMatrixOp rhsRand = rhs.getDefiningOp<mlir::daphne::RandMatrixOp>();
561+
if (lhsRand && rhsIsSca) {
562+
auto newRand = pushDownRandomIntoEwMul(lhsRand, op, rhs, rewriter);
563+
rewriter.replaceOp(op, {newRand});
564+
return mlir::success();
565+
}
566+
if (rhsRand && lhsIsSca) {
567+
auto newRand = pushDownRandomIntoEwMul(rhsRand, op, lhs, rewriter);
568+
rewriter.replaceOp(op, {newRand});
569+
return mlir::success();
570+
}
571+
478572
if (lhsIsSca && !rhsIsSca) {
479573
rewriter.replaceOpWithNewOp<mlir::daphne::EwMulOp>(op, op.getResult().getType(), rhs, lhs);
480574
return mlir::success();
@@ -515,6 +609,21 @@ mlir::LogicalResult mlir::daphne::EwDivOp::canonicalize(mlir::daphne::EwDivOp op
515609
return mlir::success();
516610
}
517611

612+
// This will check for the rand operation to push down the arithmetic inside
613+
// of it
614+
mlir::daphne::RandMatrixOp lhsRand = lhs.getDefiningOp<mlir::daphne::RandMatrixOp>();
615+
mlir::daphne::RandMatrixOp rhsRand = rhs.getDefiningOp<mlir::daphne::RandMatrixOp>();
616+
if (lhsRand && rhsIsSca) {
617+
auto newRand = pushDownRandomIntoEwDiv(lhsRand, op, rhs, rewriter);
618+
rewriter.replaceOp(op, {newRand});
619+
return mlir::success();
620+
}
621+
if (rhsRand && lhsIsSca) {
622+
auto newRand = pushDownRandomIntoEwDiv(rhsRand, op, lhs, rewriter);
623+
rewriter.replaceOp(op, {newRand});
624+
return mlir::success();
625+
}
626+
518627
const bool rhsIsFP = llvm::isa<mlir::FloatType>(CompilerUtils::getValueType(rhs.getType()));
519628
if (lhsIsSca && !rhsIsSca && rhsIsFP) {
520629
rewriter.replaceOpWithNewOp<mlir::daphne::EwMulOp>(
@@ -594,6 +703,20 @@ mlir::LogicalResult mlir::daphne::EwPowOp::canonicalize(mlir::daphne::EwPowOp op
594703
rewriter.replaceOp(op, {newFill});
595704
return mlir::success();
596705
}
706+
// This will check for the rand operation to push down the arithmetic inside
707+
// of it
708+
mlir::daphne::RandMatrixOp lhsRand = lhs.getDefiningOp<mlir::daphne::RandMatrixOp>();
709+
mlir::daphne::RandMatrixOp rhsRand = rhs.getDefiningOp<mlir::daphne::RandMatrixOp>();
710+
if (lhsRand && rhsIsSca) {
711+
auto newRand = pushDownRandomIntoEwPow(lhsRand, op, rhs, rewriter);
712+
rewriter.replaceOp(op, {newRand});
713+
return mlir::success();
714+
}
715+
if (rhsRand && lhsIsSca) {
716+
auto newRand = pushDownRandomIntoEwPow(rhsRand, op, lhs, rewriter);
717+
rewriter.replaceOp(op, {newRand});
718+
return mlir::success();
719+
}
597720
return mlir::failure();
598721
}
599722

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: daphne-opt --canonicalize --inline %s | FileCheck %s
2+
3+
module {
4+
func.func @main() {
5+
%0 = "daphne.constant"() {value = 3 : si64} : () -> si64
6+
%1 = "daphne.constant"() {value = 3 : si64} : () -> si64
7+
%2 = "daphne.constant"() {value = 0 : si64} : () -> si64
8+
%3 = "daphne.constant"() {value = 8 : si64} : () -> si64
9+
%4 = "daphne.constant"() {value = 1.000000e+00 : f64} : () -> f64
10+
%5 = "daphne.constant"() {value = 1 : si64} : () -> si64
11+
%6 = "daphne.ewMinus"(%5) : (si64) -> si64
12+
%7 = "daphne.cast"(%0) : (si64) -> index
13+
%8 = "daphne.cast"(%1) : (si64) -> index
14+
// CHECK: daphne.randMatrix
15+
// CHECK-NOT: daphne.randMatrix
16+
%9 = "daphne.randMatrix"(%7, %8, %2, %3, %4, %6) : (index, index, si64, si64, f64, si64) -> !daphne.Matrix<?x?xsi64>
17+
%10 = "daphne.constant"() {value = 2 : si64} : () -> si64
18+
// CHECK-NOT: daphne.ewDiv
19+
%11 = "daphne.ewDiv"(%9, %10) : (!daphne.Matrix<?x?xsi64>, si64) -> !daphne.Matrix<?x?xsi64>
20+
%12 = "daphne.constant"() {value = true} : () -> i1
21+
%13 = "daphne.constant"() {value = false} : () -> i1
22+
"daphne.print"(%11, %12, %13) : (!daphne.Matrix<?x?xsi64>, i1, i1) -> ()
23+
"daphne.return"() : () -> ()
24+
25+
}
26+
}
27+
28+
29+
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: daphne-opt --canonicalize --inline %s | FileCheck %s
2+
3+
module {
4+
func.func @main() {
5+
%0 = "daphne.constant"() {value = 3 : si64} : () -> si64
6+
%1 = "daphne.constant"() {value = 3 : si64} : () -> si64
7+
%2 = "daphne.constant"() {value = 0 : si64} : () -> si64
8+
%3 = "daphne.constant"() {value = 3 : si64} : () -> si64
9+
%4 = "daphne.constant"() {value = 1.000000e+00 : f64} : () -> f64
10+
%5 = "daphne.constant"() {value = 1 : si64} : () -> si64
11+
%6 = "daphne.ewMinus"(%5) : (si64) -> si64
12+
%7 = "daphne.cast"(%0) : (si64) -> index
13+
%8 = "daphne.cast"(%1) : (si64) -> index
14+
// CHECK: daphne.randMatrix
15+
// CHECK-NOT: daphne.randMatrix
16+
%9 = "daphne.randMatrix"(%7, %8, %2, %3, %4, %6) : (index, index, si64, si64, f64, si64) -> !daphne.Matrix<?x?xsi64>
17+
%10 = "daphne.constant"() {value = 2 : si64} : () -> si64
18+
// CHECK-NOT: daphne.ewMul
19+
%11 = "daphne.ewMul"(%9, %10) : (!daphne.Matrix<?x?xsi64>, si64) -> !daphne.Matrix<?x?xsi64>
20+
%12 = "daphne.constant"() {value = true} : () -> i1
21+
%13 = "daphne.constant"() {value = false} : () -> i1
22+
"daphne.print"(%11, %12, %13) : (!daphne.Matrix<?x?xsi64>, i1, i1) -> ()
23+
"daphne.return"() : () -> ()
24+
25+
}
26+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: daphne-opt --canonicalize --inline %s | FileCheck %s
2+
3+
module {
4+
func.func @main() {
5+
%0 = "daphne.constant"() {value = 3 : si64} : () -> si64
6+
%1 = "daphne.constant"() {value = 3 : si64} : () -> si64
7+
%2 = "daphne.constant"() {value = 0 : si64} : () -> si64
8+
%3 = "daphne.constant"() {value = 3 : si64} : () -> si64
9+
%4 = "daphne.constant"() {value = 1.000000e+00 : f64} : () -> f64
10+
%5 = "daphne.constant"() {value = 1 : si64} : () -> si64
11+
%6 = "daphne.ewMinus"(%5) : (si64) -> si64
12+
%7 = "daphne.cast"(%0) : (si64) -> index
13+
%8 = "daphne.cast"(%1) : (si64) -> index
14+
// CHECK: daphne.randMatrix
15+
// CHECK-NOT: daphne.randMatrix
16+
%9 = "daphne.randMatrix"(%7, %8, %2, %3, %4, %6) : (index, index, si64, si64, f64, si64) -> !daphne.Matrix<?x?xsi64>
17+
%10 = "daphne.constant"() {value = 2 : si64} : () -> si64
18+
// CHECK-NOT: daphne.ewPow
19+
%11 = "daphne.ewPow"(%9, %10) : (!daphne.Matrix<?x?xsi64>, si64) -> !daphne.Matrix<?x?xsi64>
20+
%12 = "daphne.constant"() {value = true} : () -> i1
21+
%13 = "daphne.constant"() {value = false} : () -> i1
22+
"daphne.print"(%11, %12, %13) : (!daphne.Matrix<?x?xsi64>, i1, i1) -> ()
23+
"daphne.return"() : () -> ()
24+
}
25+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: daphne-opt --canonicalize --inline %s | FileCheck %s
2+
3+
module {
4+
func.func @main() {
5+
%0 = "daphne.constant"() {value = 3 : si64} : () -> si64
6+
%1 = "daphne.constant"() {value = 3 : si64} : () -> si64
7+
%2 = "daphne.constant"() {value = 0 : si64} : () -> si64
8+
%3 = "daphne.constant"() {value = 3 : si64} : () -> si64
9+
%4 = "daphne.constant"() {value = 1.000000e+00 : f64} : () -> f64
10+
%5 = "daphne.constant"() {value = 1 : si64} : () -> si64
11+
%6 = "daphne.ewMinus"(%5) : (si64) -> si64
12+
%7 = "daphne.cast"(%0) : (si64) -> index
13+
%8 = "daphne.cast"(%1) : (si64) -> index
14+
// CHECK: daphne.randMatrix
15+
// CHECK-NOT: daphne.randMatrix
16+
%9 = "daphne.randMatrix"(%7, %8, %2, %3, %4, %6) : (index, index, si64, si64, f64, si64) -> !daphne.Matrix<?x?xsi64>
17+
%10 = "daphne.constant"() {value = 2 : si64} : () -> si64
18+
// CHECK-NOT: daphne.ewSub
19+
%11 = "daphne.ewSub"(%9, %10) : (!daphne.Matrix<?x?xsi64>, si64) -> !daphne.Matrix<?x?xsi64>
20+
%12 = "daphne.constant"() {value = true} : () -> i1
21+
%13 = "daphne.constant"() {value = false} : () -> i1
22+
"daphne.print"(%11, %12, %13) : (!daphne.Matrix<?x?xsi64>, i1, i1) -> ()
23+
"daphne.return"() : () -> ()
24+
}
25+
}
26+
27+

0 commit comments

Comments
 (0)