Skip to content

Commit f4106ca

Browse files
author
Maximilian Ehlers
committed
[668]: fixes tests, removes seqEwDiv which is not suitable for pushDown in integer space, adds docs to tryPushDown
Signed-off-by: Maximilian Ehlers <daphnevm@sodawa.com>
1 parent c889042 commit f4106ca

File tree

13 files changed

+63
-18
lines changed

13 files changed

+63
-18
lines changed

src/ir/daphneir/Canonicalize.cpp

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "ir/daphneir/Daphne.h"
1818
#include "ir/daphneir/DaphnePushDownTraits.h"
1919
#include "mlir/Dialect/SCF/IR/SCF.h"
20+
#include "mlir/IR/Builders.h"
2021
#include "mlir/Support/LogicalResult.h"
2122
#include <compiler/utils/CompilerUtils.h>
2223
#include <util/DaphneLogger.h>
@@ -329,7 +330,8 @@ template <class Operation> bool pushDownBinary(Operation op, mlir::PatternRewrit
329330
auto fillValue = lhsFill.getArg();
330331
auto height = lhsFill.getNumRows();
331332
auto width = lhsFill.getNumCols();
332-
auto newOp = rewriter.create<Operation>(op.getLoc(), fillValue, rhs);
333+
auto newOp = rewriter.create<Operation>(op.getLoc(), CompilerUtils::getValueType(op.getResult().getType()),
334+
fillValue, rhs);
333335
auto newCombinedOpAfterPushDown =
334336
rewriter.create<mlir::daphne::FillOp>(op.getLoc(), op.getResult().getType(), newOp, height, width);
335337
rewriter.replaceOp(op, {newCombinedOpAfterPushDown});
@@ -363,13 +365,14 @@ template <class Operation> bool pushDownBinary(Operation op, mlir::PatternRewrit
363365
auto newTo = rewriter.create<Operation>(op.getLoc(), to, rhs);
364366

365367
if (supportsPushDownWithIntervalUpdate) {
366-
mlir::daphne::EwMulOp newInc = rewriter.create<mlir::daphne::EwMulOp>(op.getLoc(), inc, rhs);
368+
auto newInc = rewriter.create<Operation>(op.getLoc(), rhs, inc);
369+
367370
auto newCombinedOpAfterPushDown =
368371
rewriter.create<mlir::daphne::SeqOp>(op.getLoc(), op.getResult().getType(), newFrom, newTo, newInc);
369372
rewriter.replaceOp(op, {newCombinedOpAfterPushDown});
370373
return true;
371-
} else {
372374

375+
} else {
373376
auto newCombinedOpAfterPushDown =
374377
rewriter.create<mlir::daphne::SeqOp>(op.getLoc(), op.getResult().getType(), newFrom, newTo, inc);
375378
rewriter.replaceOp(op, {newCombinedOpAfterPushDown});
@@ -379,6 +382,29 @@ template <class Operation> bool pushDownBinary(Operation op, mlir::PatternRewrit
379382
return false;
380383
}
381384

385+
/**
386+
* @brief PushDown optimizations will try to optimize aways intermediate results
387+
* when function results are used with arithmetic operations.
388+
* An example is `fill(2,2,3) + 3` which will evaluate to
389+
* 3,3
390+
* 3,3
391+
*
392+
* and the applies +3 to each element of the intermediate individually leading to
393+
*
394+
* 6,6
395+
* 6,6
396+
*
397+
* Instead we can push the math inside the function and have it evaluate first:
398+
*
399+
* fill(2,2,3 + 3) -> fill(2,2,6)
400+
*
401+
* which directly fills the matrix
402+
*
403+
* 6,6
404+
* 6,6
405+
*
406+
* and skips the intermediate.
407+
*/
382408
template <class Operation> bool tryPushDown(Operation op, mlir::PatternRewriter &rewriter) {
383409
if constexpr (std::is_same<Operation, mlir::daphne::EwAbsOp>() ||
384410
std::is_same<Operation, mlir::daphne::EwExpOp>() || std::is_same<Operation, mlir::daphne::EwLnOp>() ||

src/ir/daphneir/DaphneOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def Daphne_EwSubOp : Daphne_EwBinaryOp<"ewSub", NumScalar, [ValueTypeFromArgs
362362
def Daphne_EwMulOp : Daphne_EwBinaryOp<"ewMul", NumScalar, [ValueTypeFromArgs, CastArgsToResType, Commutative, EwSparseIfEither, CUDASupport, PushDown, PushDownLinear, PushDownWithIntervalUpdate]> {
363363
let hasCanonicalizeMethod = 1;
364364
}
365-
def Daphne_EwDivOp : Daphne_EwBinaryOp<"ewDiv", NumScalar, [ValueTypeFromArgs, CastArgsToResType, CUDASupport, PushDown, PushDownLinear, PushDownWithIntervalUpdate]> {
365+
def Daphne_EwDivOp : Daphne_EwBinaryOp<"ewDiv", NumScalar, [ValueTypeFromArgs, CastArgsToResType, CUDASupport, PushDown, PushDownLinear]> {
366366
let hasCanonicalizeMethod = 1;
367367
}
368368
def Daphne_EwPowOp : Daphne_EwBinaryOp<"ewPow", NumScalar, [ValueTypeFromArgs, CastArgsToResType, CUDASupport, PushDown]> {

test/api/cli/operations/CanonicalizationPushDownTest.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ MAKE_TEST_CASE("randEwSub", 1)
5353
MAKE_TEST_CASE("randEwMul", 1)
5454
MAKE_TEST_CASE("randEwDiv", 1)
5555
MAKE_TEST_CASE("randEwAbs", 2)
56-
MAKE_TEST_CASE("seqEwAdd", 1)
57-
MAKE_TEST_CASE("seqEwSub", 1)
58-
MAKE_TEST_CASE("seqEwMul", 1)
59-
MAKE_TEST_CASE("seqEwDiv", 1)
56+
MAKE_TEST_CASE("seqEwAdd", 2)
57+
MAKE_TEST_CASE("seqEwSub", 2)
58+
MAKE_TEST_CASE("seqEwMul", 2)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
randIntermediate = rand(2,2,-3,1,1.0,7);
1+
randIntermediate = rand(2,2,-3,-1,1.0,7);
22
result = abs(randIntermediate);
33

44
print(result);
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
DenseMatrix(2x2, int64_t)
2-
3 2
3-
2 1
2+
1 1
3+
3 1
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
result = seq(1,5) + 2;
2+
3+
print(result);
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
DenseMatrix(5x1, int64_t)
2+
3
3+
4
4+
5
5+
6
6+
7

test/api/cli/operations/seqEwDiv_1.daphne

Lines changed: 0 additions & 3 deletions
This file was deleted.

test/api/cli/operations/seqEwDiv_1.txt

Lines changed: 0 additions & 4 deletions
This file was deleted.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
result = seq(1,5) * 2;
2+
3+
print(result);

0 commit comments

Comments
 (0)