Skip to content

Commit b6756e7

Browse files
author
Maximilian Ehlers
committed
[668]: Enables Pushdown for Rand and Fill with EwLog
Signed-off-by: Maximilian Ehlers <daphnevm@sodawa.com>
1 parent 5366c85 commit b6756e7

File tree

4 files changed

+78
-13
lines changed

4 files changed

+78
-13
lines changed

src/ir/daphneir/Canonicalize.cpp

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ mlir::daphne::FillOp pushDownFillIntoEwSub(mlir::daphne::FillOp fillOp, mlir::da
249249
mlir::daphne::EwSubOp newSub = rewriter.create<mlir::daphne::EwSubOp>(op.getLoc(), fillValue, scalar);
250250
return rewriter.create<mlir::daphne::FillOp>(op.getLoc(), op.getResult().getType(), newSub, width, height);
251251
}
252+
252253
mlir::daphne::FillOp pushDownFillIntoEwMul(mlir::daphne::FillOp fillOp, mlir::daphne::EwMulOp op, mlir::Value scalar,
253254
mlir::PatternRewriter &rewriter) {
254255
auto fillValue = fillOp.getArg();
@@ -257,6 +258,7 @@ mlir::daphne::FillOp pushDownFillIntoEwMul(mlir::daphne::FillOp fillOp, mlir::da
257258
mlir::daphne::EwMulOp newMul = rewriter.create<mlir::daphne::EwMulOp>(op.getLoc(), fillValue, scalar);
258259
return rewriter.create<mlir::daphne::FillOp>(op.getLoc(), op.getResult().getType(), newMul, width, height);
259260
}
261+
260262
mlir::daphne::FillOp pushDownFillIntoEwDiv(mlir::daphne::FillOp fillOp, mlir::daphne::EwDivOp op, mlir::Value scalar,
261263
mlir::PatternRewriter &rewriter) {
262264
auto fillValue = fillOp.getArg();
@@ -265,6 +267,7 @@ mlir::daphne::FillOp pushDownFillIntoEwDiv(mlir::daphne::FillOp fillOp, mlir::da
265267
mlir::daphne::EwDivOp newDiv = rewriter.create<mlir::daphne::EwDivOp>(op.getLoc(), fillValue, scalar);
266268
return rewriter.create<mlir::daphne::FillOp>(op.getLoc(), op.getResult().getType(), newDiv, width, height);
267269
}
270+
268271
mlir::daphne::FillOp pushDownFillIntoEwPow(mlir::daphne::FillOp fillOp, mlir::daphne::EwPowOp op, mlir::Value scalar,
269272
mlir::PatternRewriter &rewriter) {
270273
auto fillValue = fillOp.getArg();
@@ -273,6 +276,7 @@ mlir::daphne::FillOp pushDownFillIntoEwPow(mlir::daphne::FillOp fillOp, mlir::da
273276
mlir::daphne::EwPowOp newPow = rewriter.create<mlir::daphne::EwPowOp>(op.getLoc(), fillValue, scalar);
274277
return rewriter.create<mlir::daphne::FillOp>(op.getLoc(), op.getResult().getType(), newPow, width, height);
275278
}
279+
// AMLS_TODO: push down naming needs to be other way around
276280
mlir::daphne::FillOp pushDownFillIntoEwMod(mlir::daphne::FillOp fillOp, mlir::daphne::EwModOp op, mlir::Value scalar,
277281
mlir::PatternRewriter &rewriter) {
278282
auto fillValue = fillOp.getArg();
@@ -281,14 +285,17 @@ mlir::daphne::FillOp pushDownFillIntoEwMod(mlir::daphne::FillOp fillOp, mlir::da
281285
mlir::daphne::EwModOp newMod = rewriter.create<mlir::daphne::EwModOp>(op.getLoc(), fillValue, scalar);
282286
return rewriter.create<mlir::daphne::FillOp>(op.getLoc(), op.getResult().getType(), newMod, width, height);
283287
}
288+
284289
mlir::daphne::FillOp pushDownFillIntoEwLog(mlir::daphne::FillOp fillOp, mlir::daphne::EwLogOp op, mlir::Value scalar,
285290
mlir::PatternRewriter &rewriter) {
286291
auto fillValue = fillOp.getArg();
287292
auto height = fillOp.getNumRows();
288293
auto width = fillOp.getNumCols();
289-
// AMLS_TODO: this can lead to error:
294+
// AMLS_TODO: this can lead to error if the log resolves cleanly
295+
// e.g. 8 with base 2
290296
// no kernel for operation `fill` available for the required input types `(si64, index, index)` and output types
291297
// `(!daphne.Matrix<?x?xf64>)
298+
// Problem with Log function?
292299

293300
mlir::daphne::EwLogOp newLog = rewriter.create<mlir::daphne::EwLogOp>(op.getLoc(), fillValue, scalar);
294301
return rewriter.create<mlir::daphne::FillOp>(op.getLoc(), op.getResult().getType(), newLog, width, height);
@@ -642,18 +649,28 @@ mlir::LogicalResult mlir::daphne::EwDivOp::canonicalize(mlir::daphne::EwDivOp op
642649
*/
643650
mlir::LogicalResult mlir::daphne::EwLogOp::canonicalize(mlir::daphne::EwLogOp op, PatternRewriter &rewriter) {
644651
// AMLS_TODO: reactivate
645-
// mlir::Value lhs = op.getLhs();
646-
// mlir::Value rhs = op.getRhs();
647-
// // This will check for the fill operation to push down the arithmetic inside
648-
// // of it
649-
// // Since the rhs is the base, the FillOp can only appear legally in lhs
650-
// mlir::daphne::FillOp lhsFill = lhs.getDefiningOp<mlir::daphne::FillOp>();
651-
// const bool rhsIsSca = CompilerUtils::isScaType(rhs.getType());
652-
// if (lhsFill && rhsIsSca) {
653-
// auto newFill = pushDownFillIntoEwLog(lhsFill, op, rhs, rewriter);
654-
// rewriter.replaceOp(op, {newFill});
655-
// return mlir::success();
656-
// }
652+
mlir::Value lhs = op.getLhs();
653+
mlir::Value rhs = op.getRhs();
654+
// This will check for the fill operation to push down the arithmetic inside
655+
// of it
656+
// Since the rhs is the base, the FillOp can only appear legally in lhs
657+
mlir::daphne::FillOp lhsFill = lhs.getDefiningOp<mlir::daphne::FillOp>();
658+
const bool rhsIsSca = CompilerUtils::isScaType(rhs.getType());
659+
if (lhsFill && rhsIsSca) {
660+
auto newFill = pushDownFillIntoEwLog(lhsFill, op, rhs, rewriter);
661+
rewriter.replaceOp(op, {newFill});
662+
return mlir::success();
663+
}
664+
665+
// This will check for the rand operation to push down the arithmetic inside
666+
// of it
667+
// Since the rhs is the base, the RandOp can only appear legally in lhs
668+
mlir::daphne::RandMatrixOp lhsRand = lhs.getDefiningOp<mlir::daphne::RandMatrixOp>();
669+
if (lhsRand && rhsIsSca) {
670+
auto newRand = pushDownRandomIntoEwLog(lhsRand, op, rhs, rewriter);
671+
rewriter.replaceOp(op, {newRand});
672+
return mlir::success();
673+
}
657674
return mlir::failure();
658675
}
659676
/**
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: daphne-opt --canonicalize --inline %s | FileCheck %s
2+
3+
module {
4+
func.func @main() {
5+
%0 = "daphne.constant"() {value = 8.000000e+00 : f64} : () -> f64
6+
%1 = "daphne.constant"() {value = 3 : si64} : () -> si64
7+
%2 = "daphne.constant"() {value = 3 : si64} : () -> si64
8+
%3 = "daphne.cast"(%1) : (si64) -> index
9+
%4 = "daphne.cast"(%2) : (si64) -> index
10+
// CHECK: daphne.fill
11+
// CHECK-NOT: daphne.fill
12+
%5 = "daphne.fill"(%0, %3, %4) : (f64, index, index) -> !daphne.Matrix<?x?xf64>
13+
%6 = "daphne.constant"() {value = 2 : si64} : () -> si64
14+
// CHECK-NOT: daphne.ewLog
15+
%7 = "daphne.ewLog"(%5, %6) : (!daphne.Matrix<?x?xf64>, si64) -> !daphne.Matrix<?x?xf64>
16+
%8 = "daphne.constant"() {value = true} : () -> i1
17+
%9 = "daphne.constant"() {value = false} : () -> i1
18+
"daphne.print"(%7, %8, %9) : (!daphne.Matrix<?x?xf64>, i1, i1) -> ()
19+
"daphne.return"() : () -> ()
20+
21+
}
22+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: daphne-opt --canonicalize --inline %s | FileCheck %s
2+
3+
module {
4+
func.func @main() {
5+
%0 = "daphne.constant"() {value = 3 : si64} : () -> si64
6+
%1 = "daphne.constant"() {value = 3 : si64} : () -> si64
7+
%2 = "daphne.constant"() {value = 2 : si64} : () -> si64
8+
%3 = "daphne.constant"() {value = 8 : si64} : () -> si64
9+
%4 = "daphne.constant"() {value = 1.000000e+00 : f64} : () -> f64
10+
%5 = "daphne.constant"() {value = 1 : si64} : () -> si64
11+
%6 = "daphne.ewMinus"(%5) : (si64) -> si64
12+
%7 = "daphne.cast"(%0) : (si64) -> index
13+
%8 = "daphne.cast"(%1) : (si64) -> index
14+
// CHECK: daphne.randMatrix
15+
// CHECK-NOT: daphne.randMatrix
16+
%9 = "daphne.randMatrix"(%7, %8, %2, %3, %4, %6) : (index, index, si64, si64, f64, si64) -> !daphne.Matrix<?x?xsi64>
17+
%10 = "daphne.constant"() {value = 2 : si64} : () -> si64
18+
// CHECK-NOT: daphne.ewLog
19+
%11 = "daphne.ewLog"(%9, %10) : (!daphne.Matrix<?x?xsi64>, si64) -> !daphne.Matrix<?x?xf64>
20+
%12 = "daphne.constant"() {value = true} : () -> i1
21+
%13 = "daphne.constant"() {value = false} : () -> i1
22+
"daphne.print"(%11, %12, %13) : (!daphne.Matrix<?x?xf64>, i1, i1) -> ()
23+
"daphne.return"() : () -> ()
24+
}
25+
}

test/ir/daphneir/pushdown_arithmetics/randEwMul.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ module {
1616
%9 = "daphne.randMatrix"(%7, %8, %2, %3, %4, %6) : (index, index, si64, si64, f64, si64) -> !daphne.Matrix<?x?xsi64>
1717
%10 = "daphne.constant"() {value = 2 : si64} : () -> si64
1818
// CHECK-NOT: daphne.ewMul
19+
// CHECK-NOT: daphne.ewMul
1920
%11 = "daphne.ewMul"(%9, %10) : (!daphne.Matrix<?x?xsi64>, si64) -> !daphne.Matrix<?x?xsi64>
2021
%12 = "daphne.constant"() {value = true} : () -> i1
2122
%13 = "daphne.constant"() {value = false} : () -> i1

0 commit comments

Comments
 (0)