Skip to content

Commit 316895c

Browse files
author
Maximilian Ehlers
committed
[668]: removes nonlinear rand pushdowns and updates the special randAbs case.
Signed-off-by: Maximilian Ehlers <daphnevm@sodawa.com>
1 parent a3b2acc commit 316895c

File tree

7 files changed

+52
-109
lines changed

7 files changed

+52
-109
lines changed

src/ir/daphneir/Canonicalize.cpp

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -255,22 +255,59 @@ template <class Operation> bool pushDownUnary(Operation op, mlir::PatternRewrite
255255
}
256256
// This will check for the rand operation to push down the arithmetic inside
257257
// of it
258-
if (rand && supportsPushDownLinear) {
258+
if (rand && supportsPushDown) {
259259
auto max = rand.getMax();
260260
auto min = rand.getMin();
261261
auto height = rand.getNumRows();
262262
auto width = rand.getNumCols();
263263
auto sparsity = rand.getSparsity();
264264
auto seed = rand.getSeed();
265-
auto newMax =
266-
rewriter.create<Operation>(op.getLoc(), CompilerUtils::getValueType(op.getResult().getType()), max);
267-
auto newMin =
268-
rewriter.create<Operation>(op.getLoc(), CompilerUtils::getValueType(op.getResult().getType()), min);
269265

270-
auto newCombinedOpAfterPushDown = rewriter.create<mlir::daphne::RandMatrixOp>(
271-
op.getLoc(), op.getResult().getType(), width, height, newMin, newMax, sparsity, seed);
272-
rewriter.replaceOp(op, {newCombinedOpAfterPushDown});
273-
return true;
266+
if (supportsPushDownLinear) {
267+
auto newMax =
268+
rewriter.create<Operation>(op.getLoc(), CompilerUtils::getValueType(op.getResult().getType()), max);
269+
auto newMin =
270+
rewriter.create<Operation>(op.getLoc(), CompilerUtils::getValueType(op.getResult().getType()), min);
271+
272+
auto newCombinedOpAfterPushDown = rewriter.create<mlir::daphne::RandMatrixOp>(
273+
op.getLoc(), op.getResult().getType(), width, height, newMin, newMax, sparsity, seed);
274+
rewriter.replaceOp(op, {newCombinedOpAfterPushDown});
275+
return true;
276+
}
277+
// Handle the special case of RandMatrixOp and EwAbsOp which only works
278+
// if the max and min are both positive (Abs is no-op) or both are
279+
// negative (swap them)
280+
if constexpr (std::is_same<Operation, mlir::daphne::EwAbsOp>()) {
281+
auto maxValueInt = CompilerUtils::isConstant<int>(max);
282+
auto minValueInt = CompilerUtils::isConstant<int>(min);
283+
auto maxValueDouble = CompilerUtils::isConstant<double>(max);
284+
auto minValueDouble = CompilerUtils::isConstant<double>(min);
285+
286+
// will be int or double. Whichever it isn't will default to 0
287+
// so they can simply be added together here
288+
289+
auto maxValue = maxValueDouble.second + maxValueInt.second;
290+
auto minValue = minValueDouble.second + minValueInt.second;
291+
if (minValue >= 0 && maxValue > minValue) {
292+
// simply remove Abs function
293+
294+
auto newCombinedOpAfterPushDown = rewriter.create<mlir::daphne::RandMatrixOp>(
295+
op.getLoc(), op.getResult().getType(), width, height, min, max, sparsity, seed);
296+
rewriter.replaceOp(op, {newCombinedOpAfterPushDown});
297+
return true;
298+
}
299+
if (minValue <= 0 && maxValue > minValue) {
300+
// swap max and min
301+
auto newMax =
302+
rewriter.create<Operation>(op.getLoc(), CompilerUtils::getValueType(op.getResult().getType()), min);
303+
auto newMin =
304+
rewriter.create<Operation>(op.getLoc(), CompilerUtils::getValueType(op.getResult().getType()), max);
305+
auto newCombinedOpAfterPushDown = rewriter.create<mlir::daphne::RandMatrixOp>(
306+
op.getLoc(), op.getResult().getType(), width, height, newMin, newMax, sparsity, seed);
307+
rewriter.replaceOp(op, {newCombinedOpAfterPushDown});
308+
return true;
309+
}
310+
}
274311
}
275312
return false;
276313
}
@@ -300,7 +337,7 @@ template <class Operation> bool pushDownBinary(Operation op, mlir::PatternRewrit
300337
}
301338
// This will check for the rand operation to push down the arithmetic inside
302339
// of it
303-
if (lhsRand && rhsIsSca && supportsPushDown) {
340+
if (lhsRand && rhsIsSca && supportsPushDown && supportsPushDownLinear) {
304341
auto max = lhsRand.getMax();
305342
auto min = lhsRand.getMin();
306343
auto height = lhsRand.getNumRows();
@@ -317,7 +354,7 @@ template <class Operation> bool pushDownBinary(Operation op, mlir::PatternRewrit
317354

318355
// This will check for the seq operation to push down the arithmetic inside
319356
// of it
320-
if (lhsSeq && rhsIsSca && supportsPushDownLinear) {
357+
if (lhsSeq && rhsIsSca && supportsPushDown && supportsPushDownLinear) {
321358
auto from = lhsSeq.getFrom();
322359
auto to = lhsSeq.getTo();
323360
auto inc = lhsSeq.getInc();
@@ -357,7 +394,6 @@ template <class Operation> bool tryPushDown(Operation op, mlir::PatternRewriter
357394
std::is_same<Operation, mlir::daphne::EwLogOp>() || std::is_same<Operation, mlir::daphne::EwModOp>()
358395

359396
) {
360-
spdlog::warn("binary");
361397
return pushDownBinary(op, rewriter);
362398
}
363399
return false;
@@ -381,7 +417,6 @@ template <class Operation> bool tryPushDown(Operation op, mlir::PatternRewriter
381417
mlir::LogicalResult mlir::daphne::EwAddOp::canonicalize(mlir::daphne::EwAddOp op, PatternRewriter &rewriter) {
382418
mlir::Value lhs = op.getLhs();
383419
mlir::Value rhs = op.getRhs();
384-
const bool rhsIsSca = CompilerUtils::isScaType(rhs.getType());
385420
if (tryPushDown<mlir::daphne::EwAddOp>(op, rewriter)) {
386421
return mlir::success();
387422
}
File renamed without changes.

test/ir/daphneir/pushdown_arithmetics/randEwSqrt.mlir renamed to test/ir/daphneir/pushdown_arithmetics/randEwAbs_2.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ module {
66
%1 = "daphne.constant"() {value = -1 : si64} : () -> si64
77
%2 = "daphne.constant"() {value = false} : () -> i1
88
%3 = "daphne.constant"() {value = true} : () -> i1
9-
%4 = "daphne.constant"() {value = 9 : si64} : () -> si64
10-
%5 = "daphne.constant"() {value = 4 : si64} : () -> si64
9+
%4 = "daphne.constant"() {value = -2 : si64} : () -> si64
10+
%5 = "daphne.constant"() {value = -3 : si64} : () -> si64
1111
%6 = "daphne.constant"() {value = 1.000000e+00 : f64} : () -> f64
1212
// CHECK: daphne.randMatrix
1313
// CHECK-NOT: daphne.randMatrix
1414
%7 = "daphne.randMatrix"(%0, %0, %5, %4, %6, %1) : (index, index, si64, si64, f64, si64) -> !daphne.Matrix<?x?xsi64>
15-
// CHECK-NOT: daphne.ewSqrt
16-
%8 = "daphne.ewSqrt"(%7) : (!daphne.Matrix<?x?xsi64>) -> !daphne.Matrix<?x?xsi64>
15+
// CHECK-NOT: daphne.ewAbs
16+
%8 = "daphne.ewAbs"(%7) : (!daphne.Matrix<?x?xsi64>) -> !daphne.Matrix<?x?xsi64>
1717
"daphne.print"(%8, %3, %2) : (!daphne.Matrix<?x?xsi64>, i1, i1) -> ()
1818
"daphne.return"() : () -> ()
1919
}

test/ir/daphneir/pushdown_arithmetics/randEwExp.mlir

Lines changed: 0 additions & 20 deletions
This file was deleted.

test/ir/daphneir/pushdown_arithmetics/randEwLn.mlir

Lines changed: 0 additions & 20 deletions
This file was deleted.

test/ir/daphneir/pushdown_arithmetics/randEwLog.mlir

Lines changed: 0 additions & 26 deletions
This file was deleted.

test/ir/daphneir/pushdown_arithmetics/randEwPow.mlir

Lines changed: 0 additions & 26 deletions
This file was deleted.

0 commit comments

Comments
 (0)