Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 199 additions & 0 deletions src/ir/daphneir/Canonicalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,205 @@
#include "mlir/Support/LogicalResult.h"
#include <compiler/utils/CompilerUtils.h>

/**
 * @brief Canonicalization patterns for AllAggSumOp (sumAll).
 *
 * Applies the following algebraic simplifications:
 *   1) sumAll(ewAdd(X, Y))       -> ewAdd(sumAll(X), sumAll(Y))
 *   2) sumAll(transpose(X))      -> sumAll(X)
 *   3) sumAll(lambda * X)        -> lambda * sumAll(X)      (lambda a scalar)
 *   4) sumAll(diagVector(X @ Y)) -> sumAll(X * transpose(Y)) (trace rewrite)
 *
 * @param op the sumAll operation to canonicalize
 * @param rewriter the pattern rewriter used to create and replace ops
 * @return success() iff the IR was rewritten, failure() otherwise
 */
mlir::LogicalResult mlir::daphne::AllAggSumOp::canonicalize(mlir::daphne::AllAggSumOp op,
                                                            mlir::PatternRewriter &rewriter) {
    mlir::Value input = op.getOperand();
    mlir::Location location = op.getLoc();
    mlir::Type resultType = op.getResult().getType();
    auto unknownType = mlir::daphne::UnknownType::get(rewriter.getContext());

    // Rule 1: sumAll(ewAdd(X, Y)) -> ewAdd(sumAll(X), sumAll(Y))
    if (auto addOp = input.getDefiningOp<mlir::daphne::EwAddOp>()) {
        // Only valid when both operands are matrices (matrix-scalar ewAdd
        // broadcasts, for which the split-sums rewrite would be wrong).
        if (!addOp.getLhs().getType().isa<mlir::daphne::MatrixType>() ||
            !addOp.getRhs().getType().isa<mlir::daphne::MatrixType>()) {
            return mlir::failure();
        }

        // Sum each operand individually, then add the two scalar sums.
        mlir::Value lSum = rewriter.create<mlir::daphne::AllAggSumOp>(location, unknownType, addOp.getLhs());
        mlir::Value rSum = rewriter.create<mlir::daphne::AllAggSumOp>(location, unknownType, addOp.getRhs());
        mlir::Value scalarAdd = rewriter.create<mlir::daphne::EwAddOp>(location, resultType, lSum, rSum);

        rewriter.replaceOp(op, scalarAdd);
        return mlir::success();
    } // Rule 2: sumAll(transpose(X)) -> sumAll(X)
    else if (auto transOp = input.getDefiningOp<mlir::daphne::TransposeOp>()) {
        mlir::Value transInput = transOp.getArg();

        // The transpose argument must be a matrix.
        if (!transInput.getType().isa<mlir::daphne::MatrixType>()) {
            return mlir::failure();
        }

        // Summation is invariant under transposition, so drop the transpose.
        mlir::Value directSum = rewriter.create<mlir::daphne::AllAggSumOp>(location, resultType, transInput);
        rewriter.replaceOp(op, directSum);
        return mlir::success();
    } // Rule 3: sumAll(lambda * X) -> lambda * sumAll(X)
    else if (auto lambdaMul = input.getDefiningOp<mlir::daphne::EwMulOp>()) {
        mlir::Value lhs = lambdaMul.getLhs();
        mlir::Value rhs = lambdaMul.getRhs();

        mlir::Value scalarOperand;
        mlir::Value matrixOperand;

        bool lhsIsSca = CompilerUtils::hasScaType(lhs);
        bool rhsIsSca = CompilerUtils::hasScaType(rhs);

        // Use .getType() only for matrix detection.
        bool lhsIsMatrix = lhs.getType().isa<mlir::daphne::MatrixType>();
        bool rhsIsMatrix = rhs.getType().isa<mlir::daphne::MatrixType>();

        // Exactly one operand must be a scalar and the other a matrix.
        if (lhsIsSca && rhsIsMatrix) {
            scalarOperand = lhs;
            matrixOperand = rhs;
        } else if (rhsIsSca && lhsIsMatrix) {
            scalarOperand = rhs;
            matrixOperand = lhs;
        } else {
            return mlir::failure(); // Unsupported combination
        }

        mlir::Value innerSum = rewriter.create<mlir::daphne::AllAggSumOp>(location, unknownType, matrixOperand);
        mlir::Value newMul = rewriter.create<mlir::daphne::EwMulOp>(location, resultType, scalarOperand, innerSum);
        rewriter.replaceOp(op, newMul);
        // BUGFIX: must report success after mutating the IR; falling through to
        // failure() violates the pattern-rewriter contract (the driver assumes
        // nothing changed), which can cause assertions or endless rewriting.
        return mlir::success();
    } // Rule 4: trace(X @ Y) = sumAll(diagVector(X @ Y)) -> sumAll(X * transpose(Y))
    else if (auto diagVec = input.getDefiningOp<mlir::daphne::DiagVectorOp>()) {
        mlir::Value diagInput = diagVec.getOperand(); // should be the result of a MatMul
        if (auto matMul = diagInput.getDefiningOp<mlir::daphne::MatMulOp>()) {
            mlir::Value lhs = matMul.getLhs();
            mlir::Value rhs = matMul.getRhs();

            if (!lhs.getType().isa<mlir::daphne::MatrixType>() || !rhs.getType().isa<mlir::daphne::MatrixType>()) {
                return mlir::failure();
            }

            // trace(X @ Y) == sum of element-wise product of X and Y^T,
            // which avoids materializing the full matrix product.
            mlir::Value rhsT = rewriter.create<mlir::daphne::TransposeOp>(location, unknownType, rhs);
            mlir::Value ewProd = rewriter.create<mlir::daphne::EwMulOp>(location, unknownType, lhs, rhsT);
            mlir::Value traceSum = rewriter.create<mlir::daphne::AllAggSumOp>(location, resultType, ewProd);

            rewriter.replaceOp(op, traceSum);
            return mlir::success();
        }
    }

    return mlir::failure();
}

/**
 * @brief Canonicalizes:
 *   (X %*% Y)[r0:r1, c0:c1] -> X[r0:r1, ] %*% Y[, c0:c1]
 *
 * Slicing a matrix product can be pushed into the operands, so only the
 * needed rows of X and columns of Y are multiplied. Only applied when both
 * transpose flags of the MatMulOp are compile-time constant false.
 *
 * @param op the sliceCol operation (outermost op of the slice pair)
 * @param rewriter the pattern rewriter used to create and replace ops
 * @return success() iff the IR was rewritten, failure() otherwise
 */
mlir::LogicalResult mlir::daphne::SliceColOp::canonicalize(mlir::daphne::SliceColOp op,
                                                           mlir::PatternRewriter &rewriter) {
    mlir::Value input = op.getOperand(0);
    mlir::Location location = op.getLoc();
    mlir::Type resultType = op.getResult().getType();
    auto unknownType = mlir::daphne::UnknownType::get(rewriter.getContext());

    // The pattern is sliceCol(sliceRow(matMul(X, Y), ...), ...).
    auto sliceRowOp = input.getDefiningOp<mlir::daphne::SliceRowOp>();
    if (!sliceRowOp) {
        return mlir::failure();
    }

    auto matMulOp = sliceRowOp.getOperand(0).getDefiningOp<mlir::daphne::MatMulOp>();
    if (!matMulOp) {
        return mlir::failure();
    }

    // The two matrix operands of the product.
    mlir::Value X = matMulOp.getLhs();
    mlir::Value Y = matMulOp.getRhs();

    // Lower/upper bounds for the row slice.
    mlir::Value rowLower = sliceRowOp.getOperand(1);
    mlir::Value rowUpper = sliceRowOp.getOperand(2);

    // Lower/upper bounds for the column slice.
    mlir::Value colLower = op.getOperand(1);
    mlir::Value colUpper = op.getOperand(2);

    // Transpose flags of the MatMulOp (operands 2 and 3).
    mlir::Value transX = matMulOp.getOperand(2);
    mlir::Value transY = matMulOp.getOperand(3);

    // BUGFIX: isConstant() returns {isConstant, value}; previously only
    // .second was read, so a *non-constant* transpose flag was silently
    // treated as "not transposed" and the rewrite fired incorrectly.
    auto constX = CompilerUtils::isConstant<bool>(transX);
    auto constY = CompilerUtils::isConstant<bool>(transY);
    if (!constX.first || !constY.first) {
        return mlir::failure(); // transpose flags not known at compile time
    }

    // Only the non-transposed case is handled.
    if (constX.second || constY.second) {
        return mlir::failure();
    }

    // Slice the operands first, then multiply the (smaller) slices.
    mlir::Value row = rewriter.create<mlir::daphne::SliceRowOp>(location, unknownType, X, rowLower, rowUpper);
    mlir::Value col = rewriter.create<mlir::daphne::SliceColOp>(location, unknownType, Y, colLower, colUpper);

    auto newMatMul = rewriter.create<mlir::daphne::MatMulOp>(location, resultType, row, col, transX, transY);
    rewriter.replaceOp(op, newMatMul.getResult());
    return mlir::success();
}

/**
 * @brief Canonicalizes:
 *   X[a:b, c:d] = Y  ->  X = Y   if dims(X) == dims(Y)
 *
 * The DSL lowers the left-hand-side indexing to
 * insertRow(X, insertCol(sliceRow(X, a, b), Y, c, d), a, b); when X and Y
 * have identical (statically known) shapes, the insertion overwrites all of
 * X, so the result is simply Y (wrapped in a rename to keep SSA bookkeeping).
 * Only applied for matrices with matching element types.
 *
 * @param op the insertRow operation (outermost op of the insertion pair)
 * @param rewriter the pattern rewriter used to create and replace ops
 * @return success() iff the IR was rewritten, failure() otherwise
 */
mlir::LogicalResult mlir::daphne::InsertRowOp::canonicalize(mlir::daphne::InsertRowOp op,
                                                            mlir::PatternRewriter &rewriter) {
    mlir::Location location = op.getLoc();
    mlir::Type resultType = op.getResult().getType();

    // The pattern is insertRow(X, insertCol(sliceRow(X, ...), Y, ...), ...).
    auto insertCol = op.getIns().getDefiningOp<mlir::daphne::InsertColOp>();
    if (!insertCol) {
        return mlir::failure();
    }

    auto sliceRow = insertCol.getArg().getDefiningOp<mlir::daphne::SliceRowOp>();
    if (!sliceRow) {
        return mlir::failure();
    }

    mlir::Value sliceInput = sliceRow.getSource(); // X
    mlir::Value insertColInput = insertCol.getIns(); // Y

    // BUGFIX: the sliced matrix must be the very matrix we insert into;
    // previously only the shapes were compared, so an insertion built from a
    // slice of a *different* equally-sized matrix was rewritten incorrectly.
    if (sliceInput != op.getArg()) {
        return mlir::failure();
    }

    if (!sliceInput.getType().isa<mlir::daphne::MatrixType>() ||
        !insertColInput.getType().isa<mlir::daphne::MatrixType>()) {
        return mlir::failure();
    }

    auto sliceType = sliceInput.getType().dyn_cast<mlir::daphne::MatrixType>();
    auto insertColInputType = insertColInput.getType().dyn_cast<mlir::daphne::MatrixType>();
    auto opResultType = op.getResult().getType().dyn_cast<mlir::daphne::MatrixType>();

    if (!sliceType || !insertColInputType || !opResultType) {
        return mlir::failure();
    }

    // Element types must agree, otherwise the assignment is not a plain copy.
    if (sliceType.getElementType() != insertColInputType.getElementType()) {
        return mlir::failure();
    }

    int64_t numRowsX = sliceType.getNumRows();
    int64_t numColsX = sliceType.getNumCols();
    int64_t numRowsY = insertColInputType.getNumRows();
    int64_t numColsY = insertColInputType.getNumCols();

    // All four dimensions must be statically known (-1 means unknown).
    if (numRowsX == -1 || numColsX == -1 || numRowsY == -1 || numColsY == -1) {
        return mlir::failure();
    }

    // The insertion only covers all of X if the shapes match exactly.
    if (numRowsX != numRowsY || numColsX != numColsY) {
        return mlir::failure();
    }

    auto renamed = rewriter.create<mlir::daphne::RenameOp>(location, resultType, insertColInput);
    rewriter.replaceOp(op, renamed.getResult());
    return mlir::success();
}

mlir::LogicalResult mlir::daphne::VectorizedPipelineOp::canonicalize(mlir::daphne::VectorizedPipelineOp op,
mlir::PatternRewriter &rewriter) {
// // Find duplicate inputs
Expand Down
18 changes: 13 additions & 5 deletions src/ir/daphneir/DaphneOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,8 @@ def Daphne_MatMulOp : Daphne_Op<"matMul", [
DataTypeMat, ValueTypeFromArgs,
DeclareOpInterfaceMethods<InferShapeOpInterface>,
DeclareOpInterfaceMethods<InferSparsityOpInterface>, CUDASupport, FPGAOPENCLSupport,
CastFirstTwoArgsToResType, NoMemoryEffect
CastFirstTwoArgsToResType, NoMemoryEffect,
Pure
]> {
let arguments = (ins MatrixOf<[NumScalar]>:$lhs, MatrixOf<[NumScalar]>:$rhs, BoolScalar:$transa, BoolScalar:$transb);
let results = (outs MatrixOf<[NumScalar]>:$res);
Expand Down Expand Up @@ -498,7 +499,9 @@ class Daphne_AllAggOp<string name, Type scalarType, list<Trait> traits = []>
let results = (outs scalarType:$res);
}

def Daphne_AllAggSumOp : Daphne_AllAggOp<"sumAll", NumScalar, [ValueTypeFromFirstArg]>;
def Daphne_AllAggSumOp : Daphne_AllAggOp<"sumAll", NumScalar, [ValueTypeFromFirstArg]>{
let hasCanonicalizeMethod = 1;
}
def Daphne_AllAggMinOp : Daphne_AllAggOp<"minAll", NumScalar, [ValueTypeFromFirstArg]>;
def Daphne_AllAggMaxOp : Daphne_AllAggOp<"maxAll", NumScalar, [ValueTypeFromFirstArg]>;
def Daphne_AllAggMeanOp : Daphne_AllAggOp<"meanAll", NumScalar, [ValueTypeFromArgsFP]>;
Expand Down Expand Up @@ -649,7 +652,8 @@ def Daphne_ExtractRowOp : Daphne_Op<"extractRow", [

def Daphne_SliceRowOp : Daphne_Op<"sliceRow", [
TypeFromFirstArg,
DeclareOpInterfaceMethods<InferShapeOpInterface>
DeclareOpInterfaceMethods<InferShapeOpInterface>,
Pure
]> {
let summary = "Copies the specified rows from the argument to the result.";

Expand Down Expand Up @@ -704,6 +708,7 @@ def Daphne_SliceColOp : Daphne_Op<"sliceCol", [

let arguments = (ins MatrixOrFrame:$source, SI64:$lowerIncl, SI64:$upperExcl);
let results = (outs MatrixOrFrame:$res);
let hasCanonicalizeMethod = 1;
}

// TODO Create combined InsertOp (see #238).
Expand All @@ -714,11 +719,13 @@ def Daphne_InsertRowOp : Daphne_Op<"insertRow", [
]> {
let arguments = (ins MatrixOrFrame:$arg, MatrixOrFrame:$ins, SI64:$rowLowerIncl, SI64:$rowUpperExcl);
let results = (outs MatrixOrFrame:$res);
let hasCanonicalizeMethod = 1;
}

def Daphne_InsertColOp : Daphne_Op<"insertCol", [
TypeFromFirstArg, // this is debatable
ShapeFromArg
ShapeFromArg,
Pure
]> {
let arguments = (ins MatrixOrFrame:$arg, MatrixOrFrame:$ins, SI64:$colLowerIncl, SI64:$colUpperExcl);
let results = (outs MatrixOrFrame:$res);
Expand Down Expand Up @@ -989,7 +996,8 @@ def Daphne_SoftmaxOp : Daphne_Op<"Softmax", [ DataTypeFromFirstArg, ValueTypeFro
// ****************************************************************************

def Daphne_DiagVectorOp : Daphne_Op<"diagVector", [
TypeFromFirstArg, NumRowsFromArg, OneCol
TypeFromFirstArg, NumRowsFromArg, OneCol,
Pure
]> {
let arguments = (ins MatrixOrU:$arg);
let results = (outs MatrixOrU:$res);
Expand Down
3 changes: 2 additions & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ include_directories(${PROJECT_SOURCE_DIR}/thirdparty/catch2) # for "catch.hpp"
set(TEST_SOURCES
run_tests.h
run_tests.cpp


api/cli/expressions/SimplificationTest.cpp
api/cli/algorithms/AlgorithmsTest.cpp
api/cli/algorithms/DecisionTreeRandomForestTest.cpp
api/cli/config/ConfigTest.cpp
Expand Down
24 changes: 24 additions & 0 deletions test/api/cli/expressions/SimplificationTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// System tests for the algebraic simplification (canonicalization) rewrites:
// each case runs a DaphneDSL script and compares its output to a reference
// file of the same name in dirPath.
#include <api/cli/Utils.h>

#include <tags.h>

#include <catch.hpp>

#include <sstream>
#include <string>

// Directory containing the *.daphne scripts and their *.txt reference outputs.
const std::string dirPath = "test/api/cli/expressions/";

// Registers one TEST_CASE per rewrite; `count` is the number of numbered
// script variants (name_1.daphne .. name_count.daphne) to run and compare.
#define MAKE_TEST_CASE(name, count) \
TEST_CASE(name, TAG_REWRITE) { \
for (unsigned i = 1; i <= count; i++) { \
DYNAMIC_SECTION(name "_" << i << ".daphne") { compareDaphneToRefSimple(dirPath, name, i); } \
} \
}

// One script each for the canonicalization rules added in Canonicalize.cpp.
MAKE_TEST_CASE("simplf_sumEwadd", 1)
MAKE_TEST_CASE("simplf_sumTranspose", 1)
MAKE_TEST_CASE("simplf_sumMulLambda", 1)
MAKE_TEST_CASE("simplf_sumTrace", 1)
MAKE_TEST_CASE("simplf_mmSlice", 1)
MAKE_TEST_CASE("simplf_dynInsert", 1)
15 changes: 15 additions & 0 deletions test/api/cli/expressions/simplf_dynInsert_1.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

m1 = [
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0,
14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0
];
m1 = reshape(m1, 5, 5);

m2 = fill(0.0, 5, 5);
m2 [0:5, 0:5] = m1;
print(m1);
print(m2);

m3 = fill(0.0, 7, 8);
m3 [0:5, 0:5] = m1;
print(m3);
20 changes: 20 additions & 0 deletions test/api/cli/expressions/simplf_dynInsert_1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
DenseMatrix(5x5, double)
1 2 3 4 5
6 7 8 9 10
11 12 13 14 15
16 17 18 19 20
21 22 23 24 25
DenseMatrix(5x5, double)
1 2 3 4 5
6 7 8 9 10
11 12 13 14 15
16 17 18 19 20
21 22 23 24 25
DenseMatrix(7x8, double)
1 2 3 4 5 0 0 0
6 7 8 9 10 0 0 0
11 12 13 14 15 0 0 0
16 17 18 19 20 0 0 0
21 22 23 24 25 0 0 0
0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0
Loading