diff --git a/src/ir/daphneir/Canonicalize.cpp b/src/ir/daphneir/Canonicalize.cpp index e769bdc3b..706f5fd57 100644 --- a/src/ir/daphneir/Canonicalize.cpp +++ b/src/ir/daphneir/Canonicalize.cpp @@ -19,6 +19,205 @@ #include "mlir/Support/LogicalResult.h" #include +mlir::LogicalResult mlir::daphne::AllAggSumOp::canonicalize(mlir::daphne::AllAggSumOp op, + mlir::PatternRewriter &rewriter) { + mlir::Value input = op.getOperand(); + mlir::Location location = op.getLoc(); + mlir::Type result_type = op.getResult().getType(); + auto unknownType = mlir::daphne::UnknownType::get(rewriter.getContext()); + + // Rule 1: sumAll(ewAdd(X, Y)) to ewAdd(sumAll(X), sumAll(Y)) + if (auto addOp = input.getDefiningOp()) { + // Checking the inputs are matrices + if (!addOp.getLhs().getType().isa() || + !addOp.getRhs().getType().isa()) { + return mlir::failure(); + } + + // Individual sums + mlir::Value lSum = rewriter.create(location, unknownType, addOp.getLhs()); + mlir::Value rSum = rewriter.create(location, unknownType, addOp.getRhs()); + mlir::Value scalar_add = rewriter.create(location, result_type, lSum, rSum); + + rewriter.replaceOp(op, scalar_add); + return mlir::success(); + } // Rule 2: sumAll(transpose(X)) to sumAll(X) + else if (auto transOp = input.getDefiningOp()) { + mlir::Value input_tr = transOp.getArg(); + + // Inputs should be matrices + if (!input_tr.getType().isa()) { + return mlir::failure(); + } + + mlir::Value simplf_sumOftranspose = rewriter.create(location, result_type, input_tr); + rewriter.replaceOp(op, simplf_sumOftranspose); + return mlir::success(); + } // Rule 3: sum(lambda * X) -> lambda * sum(X) + else if (auto lambdaMul = input.getDefiningOp()) { + mlir::Value left_o = lambdaMul.getLhs(); + mlir::Value right_o = lambdaMul.getRhs(); + + mlir::Value scalarOperand; + mlir::Value matrixOperand; + + bool lhsIsSca = CompilerUtils::hasScaType(left_o); + bool rhsIsSca = CompilerUtils::hasScaType(right_o); + + // Use .getType() only for matrix detection + bool lhsIsMatrix = left_o.getType().isa(); + bool rhsIsMatrix = right_o.getType().isa(); + + if (lhsIsSca && rhsIsMatrix) { + scalarOperand = left_o; + matrixOperand = right_o; + } else if (rhsIsSca && lhsIsMatrix) { + scalarOperand = right_o; + matrixOperand = left_o; + } else { + return mlir::failure(); // Unsupported combination + } + + mlir::Value innerSum = rewriter.create(location, unknownType, matrixOperand); + mlir::Value newMul = rewriter.create(location, result_type, scalarOperand, innerSum); + rewriter.replaceOp(op, newMul); + + } // Rule 4: trace(X @ Y) = sum(diagVector(X @ Y)) -> sum(X * transpose(Y)) + else if (auto diagVec = input.getDefiningOp()) { + mlir::Value input_dV = diagVec.getOperand(); // This should be a matrix (result of MatMul) + if (auto matMul = input_dV.getDefiningOp()) { + mlir::Value lhs = matMul.getLhs(); + mlir::Value rhs = matMul.getRhs(); + + if (!lhs.getType().isa() || !rhs.getType().isa()) { + return mlir::failure(); + } + + mlir::Value t_rhs = rewriter.create(location, unknownType, rhs); + mlir::Value ewMul_m = rewriter.create(location, unknownType, lhs, t_rhs); + mlir::Value simplifiedSum = rewriter.create(location, result_type, ewMul_m); + + rewriter.replaceOp(op, simplifiedSum); + return mlir::success(); + } + } + + return mlir::failure(); +} + +/** +* @brief Canonicalizes: +1)(X%*%Y)[7,3] → X[7,]%*%Y[,3] + +*/ +mlir::LogicalResult mlir::daphne::SliceColOp::canonicalize(mlir::daphne::SliceColOp op, + mlir::PatternRewriter &rewriter) { + mlir::Value input = op.getOperand(0); + mlir::Location location = op.getLoc(); + mlir::Type result_type = op.getResult().getType(); + auto unknownType = mlir::daphne::UnknownType::get(rewriter.getContext()); + + auto sliceRowOp = input.getDefiningOp(); + if (!sliceRowOp) { + return mlir::failure(); + } + + auto matMulOp = sliceRowOp.getOperand(0).getDefiningOp(); + if (!matMulOp) { + return mlir::failure(); + } + + // matrices + mlir::Value X = matMulOp.getLhs(); + mlir::Value Y = matMulOp.getRhs(); + + // lower-upper bounds for rows + mlir::Value row_l = sliceRowOp.getOperand(1); + mlir::Value row_u = sliceRowOp.getOperand(2); + + // lower-upper bounds for columns + mlir::Value col_l = op.getOperand(1); + mlir::Value col_u = op.getOperand(2); + + // to check if a matrix is transposed + mlir::Value t_X = matMulOp.getOperand(2); + mlir::Value t_Y = matMulOp.getOperand(3); + + bool isTransposedX = CompilerUtils::isConstant(t_X).second; + bool isTransposedY = CompilerUtils::isConstant(t_Y).second; + + mlir::Value row; + mlir::Value col; + + if (!isTransposedX && !isTransposedY) { + row = rewriter.create(location, unknownType, X, row_l, row_u); + col = rewriter.create(location, unknownType, Y, col_l, col_u); + + auto newMatMul = rewriter.create(location, result_type, row, col, t_X, t_Y); + rewriter.replaceOp(op, newMatMul.getResult()); + return mlir::success(); + + } else { + return mlir::failure(); + } +} + +/** @brief Canonicalizes: +1)X[a:b, c:d] = Y -> X=Y if dims(X) = dims(Y) +//only for matrices with matching element types +*/ +mlir::LogicalResult mlir::daphne::InsertRowOp::canonicalize(mlir::daphne::InsertRowOp op, + mlir::PatternRewriter &rewriter) { + mlir::Location location = op.getLoc(); + mlir::Type result_type = op.getResult().getType(); + + auto insertCol = op.getIns().getDefiningOp(); + if (!insertCol) { + return mlir::failure(); + } + + auto sliceRow = insertCol.getArg().getDefiningOp(); + if (!sliceRow) { + return mlir::failure(); + } + + mlir::Value sliceInput = sliceRow.getSource(); // X + mlir::Value insertColInput = insertCol.getIns(); // Y + if (!sliceInput.getType().isa() || + !insertColInput.getType().isa()) { + return mlir::failure(); + } + + auto sliceType = sliceInput.getType().dyn_cast(); + auto insertColInputType = insertColInput.getType().dyn_cast(); + auto opResultType = op.getResult().getType().dyn_cast(); + + if (!sliceType || !insertColInputType || !opResultType) { + return mlir::failure(); + } + + if (sliceType.getElementType() != insertColInputType.getElementType()) { + return mlir::failure(); + } + + int64_t numRows_X = sliceType.getNumRows(); + int64_t numCols_X = sliceType.getNumCols(); + int64_t numRows_Y = insertColInputType.getNumRows(); + int64_t numCols_Y = insertColInputType.getNumCols(); + + if (numRows_X == -1 || numCols_X == -1 || numRows_Y == -1 || numCols_Y == -1) { + return mlir::failure(); + } + + if (numRows_X != numRows_Y || numCols_X != numCols_Y) { + return mlir::failure(); + } + + auto renamed = rewriter.create(location, result_type, insertColInput); + rewriter.replaceOp(op, renamed.getResult()); + return mlir::success(); +} + mlir::LogicalResult mlir::daphne::VectorizedPipelineOp::canonicalize(mlir::daphne::VectorizedPipelineOp op, mlir::PatternRewriter &rewriter) { // // Find duplicate inputs diff --git a/src/ir/daphneir/DaphneOps.td b/src/ir/daphneir/DaphneOps.td index 13aa07a5a..855a94428 100644 --- a/src/ir/daphneir/DaphneOps.td +++ b/src/ir/daphneir/DaphneOps.td @@ -228,7 +228,8 @@ def Daphne_MatMulOp : Daphne_Op<"matMul", [ DataTypeMat, ValueTypeFromArgs, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, CUDASupport, FPGAOPENCLSupport, - CastFirstTwoArgsToResType, NoMemoryEffect + CastFirstTwoArgsToResType, NoMemoryEffect, + Pure ]> { let arguments = (ins MatrixOf<[NumScalar]>:$lhs, MatrixOf<[NumScalar]>:$rhs, BoolScalar:$transa, BoolScalar:$transb); let results = (outs MatrixOf<[NumScalar]>:$res); @@ -498,7 +499,9 @@ class Daphne_AllAggOp traits = []> let results = (outs scalarType:$res); } -def Daphne_AllAggSumOp : Daphne_AllAggOp<"sumAll", NumScalar, [ValueTypeFromFirstArg]>; +def Daphne_AllAggSumOp : Daphne_AllAggOp<"sumAll", NumScalar, [ValueTypeFromFirstArg]>{ + let hasCanonicalizeMethod = 1; +} def Daphne_AllAggMinOp : Daphne_AllAggOp<"minAll", NumScalar, [ValueTypeFromFirstArg]>; def Daphne_AllAggMaxOp : Daphne_AllAggOp<"maxAll", NumScalar, [ValueTypeFromFirstArg]>; def Daphne_AllAggMeanOp : Daphne_AllAggOp<"meanAll", NumScalar, [ValueTypeFromArgsFP]>; @@ -649,7 +652,8 @@ def Daphne_ExtractRowOp : Daphne_Op<"extractRow", [ def Daphne_SliceRowOp : Daphne_Op<"sliceRow", [ TypeFromFirstArg, - DeclareOpInterfaceMethods + DeclareOpInterfaceMethods, + Pure ]> { let summary = "Copies the specified rows from the argument to the result."; @@ -704,6 +708,7 @@ def Daphne_SliceColOp : Daphne_Op<"sliceCol", [ let arguments = (ins MatrixOrFrame:$source, SI64:$lowerIncl, SI64:$upperExcl); let results = (outs MatrixOrFrame:$res); + let hasCanonicalizeMethod = 1; } // TODO Create combined InsertOp (see #238). @@ -714,11 +719,13 @@ def Daphne_InsertRowOp : Daphne_Op<"insertRow", [ ]> { let arguments = (ins MatrixOrFrame:$arg, MatrixOrFrame:$ins, SI64:$rowLowerIncl, SI64:$rowUpperExcl); let results = (outs MatrixOrFrame:$res); + let hasCanonicalizeMethod = 1; } def Daphne_InsertColOp : Daphne_Op<"insertCol", [ TypeFromFirstArg, // this is debatable - ShapeFromArg + ShapeFromArg, + Pure ]> { let arguments = (ins MatrixOrFrame:$arg, MatrixOrFrame:$ins, SI64:$colLowerIncl, SI64:$colUpperExcl); let results = (outs MatrixOrFrame:$res); @@ -989,7 +996,8 @@ def Daphne_SoftmaxOp : Daphne_Op<"Softmax", [ DataTypeFromFirstArg, ValueTypeFro // **************************************************************************** def Daphne_DiagVectorOp : Daphne_Op<"diagVector", [ - TypeFromFirstArg, NumRowsFromArg, OneCol + TypeFromFirstArg, NumRowsFromArg, OneCol, + Pure ]> { let arguments = (ins MatrixOrU:$arg); let results = (outs MatrixOrU:$res); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 5d20ae772..4432f58be 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -18,7 +18,8 @@ include_directories(${PROJECT_SOURCE_DIR}/thirdparty/catch2) # for "catch.hpp" set(TEST_SOURCES run_tests.h run_tests.cpp - + + api/cli/expressions/SimplificationTest.cpp api/cli/algorithms/AlgorithmsTest.cpp api/cli/algorithms/DecisionTreeRandomForestTest.cpp api/cli/config/ConfigTest.cpp diff --git a/test/api/cli/expressions/SimplificationTest.cpp b/test/api/cli/expressions/SimplificationTest.cpp new file mode 100644 index 000000000..dd9d3494e --- /dev/null +++ b/test/api/cli/expressions/SimplificationTest.cpp @@ -0,0 +1,24 @@ +#include + +#include + +#include + +#include +#include + +const std::string dirPath = "test/api/cli/expressions/"; + +#define MAKE_TEST_CASE(name, count) \ + TEST_CASE(name, TAG_REWRITE) { \ + for (unsigned i = 1; i <= count; i++) { \ + DYNAMIC_SECTION(name "_" << i << ".daphne") { compareDaphneToRefSimple(dirPath, name, i); } \ + } \ + } + +MAKE_TEST_CASE("simplf_sumEwadd", 1) +MAKE_TEST_CASE("simplf_sumTranspose", 1) +MAKE_TEST_CASE("simplf_sumMulLambda", 1) +MAKE_TEST_CASE("simplf_sumTrace", 1) +MAKE_TEST_CASE("simplf_mmSlice", 1) +MAKE_TEST_CASE("simplf_dynInsert", 1) diff --git a/test/api/cli/expressions/simplf_dynInsert_1.daphne b/test/api/cli/expressions/simplf_dynInsert_1.daphne new file mode 100644 index 000000000..d58594d54 --- /dev/null +++ b/test/api/cli/expressions/simplf_dynInsert_1.daphne @@ -0,0 +1,15 @@ + +m1 = [ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, + 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0 +]; +m1 = reshape(m1, 5, 5); + +m2 = fill(0.0, 5, 5); +m2 [0:5, 0:5] = m1; +print(m1); +print(m2); + +m3 = fill(0.0, 7, 8); +m3 [0:5, 0:5] = m1; +print(m3); diff --git a/test/api/cli/expressions/simplf_dynInsert_1.txt b/test/api/cli/expressions/simplf_dynInsert_1.txt new file mode 100644 index 000000000..120fbcfe1 --- /dev/null +++ b/test/api/cli/expressions/simplf_dynInsert_1.txt @@ -0,0 +1,20 @@ +DenseMatrix(5x5, double) +1 2 3 4 5 +6 7 8 9 10 +11 12 13 14 15 +16 17 18 19 20 +21 22 23 24 25 +DenseMatrix(5x5, double) +1 2 3 4 5 +6 7 8 9 10 +11 12 13 14 15 +16 17 18 19 20 +21 22 23 24 25 +DenseMatrix(7x8, double) +1 2 3 4 5 0 0 0 +6 7 8 9 10 0 0 0 +11 12 13 14 15 0 0 0 +16 17 18 19 20 0 0 0 +21 22 23 24 25 0 0 0 +0 0 0 0 0 0 0 0 +0 0 0 0 0 0 0 0 diff --git a/test/api/cli/expressions/simplf_mmSlice_1.daphne b/test/api/cli/expressions/simplf_mmSlice_1.daphne new file mode 100644 index 000000000..029beb470 --- /dev/null +++ b/test/api/cli/expressions/simplf_mmSlice_1.daphne @@ -0,0 +1,82 @@ +// === FLOAT64 Matrices === +A_f = [ 1.0, 2.0, 3.0, 4.0 ]; // [[1,2],[3,4]] +B_f = [ 5.0, 6.0, 7.0, 8.0 ]; // [[5,6],[7,8]] +A_f = reshape(A_f, 2, 2); +B_f = reshape(B_f, 2, 2); + +// === INT64 Matrices === +A_i = [ 1, 2, 3, 4 ]; // [[1,2],[3,4]] +B_i = [ 5, 6, 7, 8 ]; // [[5,6],[7,8]] +A_i = reshape(A_i, 2, 2); +B_i = reshape(B_i, 2, 2); + +// === A_i (int) @ B_f (float) === +Aif_Bf_1 = (A_i @B_f)[1, 0]; +Aif_Bf_2 = (transpose(A_i) @transpose(B_f))[1, 0]; +Aif_Bf_3 = (A_i @transpose(B_f))[1, 0]; +Aif_Bf_4 = (transpose(A_i) @B_f)[1, 0]; + +// === A_f (float) @ B_i (int) === +Af_Bi_1 = (A_f @B_i)[1, 0]; +Af_Bi_2 = (transpose(A_f) @transpose(B_i))[1, 0]; +Af_Bi_3 = (A_f @transpose(B_i))[1, 0]; +Af_Bi_4 = (transpose(A_f) @B_i)[1, 0]; + +// === A_f @ B_f === +Aff_Bf_1 = (A_f @B_f)[1, 0]; +Aff_Bf_2 = (transpose(A_f) @transpose(B_f))[1, 0]; +Aff_Bf_3 = (A_f @transpose(B_f))[1, 0]; +Aff_Bf_4 = (transpose(A_f) @B_f)[1, 0]; + +// === A_i @ B_i === +Aii_Bi_1 = (A_i @B_i)[1, 0]; +Aii_Bi_2 = (transpose(A_i) @transpose(B_i))[1, 0]; +Aii_Bi_3 = (A_i @transpose(B_i))[1, 0]; +Aii_Bi_4 = (transpose(A_i) @B_i)[1, 0]; + +// === PRINT RESULTS === +print(as.scalar(Aff_Bf_1)); // Expect: 43 +print(as.scalar(Aff_Bf_2)); // Expect: 31 +print(as.scalar(Aff_Bf_3)); // Expect: 39 +print(as.scalar(Aff_Bf_4)); // Expect: 38 + +print(as.scalar(Aii_Bi_1)); // Expect: 43 +print(as.scalar(Aii_Bi_2)); // Expect: 31 +print(as.scalar(Aii_Bi_3)); // Expect: 39 +print(as.scalar(Aii_Bi_4)); // Expect: 38 + +print(as.scalar(Aif_Bf_1)); // Expect: 43 +print(as.scalar(Aif_Bf_2)); // Expect: 31 +print(as.scalar(Aif_Bf_3)); // Expect: 39 +print(as.scalar(Aif_Bf_4)); // Expect: 38 + +print(as.scalar(Af_Bi_1)); // Expect: 43 +print(as.scalar(Af_Bi_2)); // Expect: 31 +print(as.scalar(Af_Bi_3)); // Expect: 39 +print(as.scalar(Af_Bi_4)); // Expect: 38 + +// === With more than one row/col === +// Define two 5x5 matrices +m1 = [ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, + 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0 +]; +m1 = reshape(m1, 5, 5); + +m2 = [ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, + 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0 +]; +m2 = reshape(m2, 5, 5); + +// === Valid sliced matrix multiplications === +mm1 = (m1 @m2)[1:3, 2:4]; // plain +mm2 = (transpose(m1) @transpose(m2))[1:4, 2:4]; +mm3 = (m1 @transpose(m2))[1:3, 2:4]; +mm4 = (transpose(m1) @m2)[1:3, 2:4]; + +// Print the resulting submatrices (2x2) +print(mm1); +print(mm2); +print(mm3); +print(mm4); diff --git a/test/api/cli/expressions/simplf_mmSlice_1.txt b/test/api/cli/expressions/simplf_mmSlice_1.txt new file mode 100644 index 000000000..da4e09431 --- /dev/null +++ b/test/api/cli/expressions/simplf_mmSlice_1.txt @@ -0,0 +1,29 @@ +43 +34 +39 +38 +43 +34 +39 +38 +43 +34 +39 +38 +43 +34 +39 +38 +DenseMatrix(2x2, double) +570 610 +895 960 +DenseMatrix(3x2, double) +830 1130 +895 1220 +960 1310 +DenseMatrix(2x2, double) +530 730 +855 1180 +DenseMatrix(2x2, double) +1030 1090 +1095 1160 diff --git a/test/api/cli/expressions/simplf_sumEwadd_1.daphne b/test/api/cli/expressions/simplf_sumEwadd_1.daphne new file mode 100644 index 000000000..9a0045997 --- /dev/null +++ b/test/api/cli/expressions/simplf_sumEwadd_1.daphne @@ -0,0 +1,20 @@ +// This triggers canonicalization: sum(X + Y) => sum(X) + sum(Y) +X = [ 1.0, 2.0, 3.0, 4.0 ]; +X = reshape(X, 2, 2); + +Y = [ 4.0, 3.0, 2.0, 1.0 ]; +Y = reshape(Y, 2, 2); + +res = sum(X + Y); +print(as.scalar(res)); + +// with different element types +Q = [ 1, 2, 3, 4, 5, 6, 7, 8, 9 ]; +Q = reshape(Q, 3, 3); + +Z = [ 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5 ]; +Z = reshape(Z, 3, 3); + +// Sum of elementwise addition +res_2 = sum(Q + Z); +print(as.scalar(res_2)); diff --git a/test/api/cli/expressions/simplf_sumEwadd_1.txt b/test/api/cli/expressions/simplf_sumEwadd_1.txt new file mode 100644 index 000000000..055a84325 --- /dev/null +++ b/test/api/cli/expressions/simplf_sumEwadd_1.txt @@ -0,0 +1,2 @@ +20 +85.5 diff --git a/test/api/cli/expressions/simplf_sumMulLambda_1.daphne b/test/api/cli/expressions/simplf_sumMulLambda_1.daphne new file mode 100644 index 000000000..849008b5a --- /dev/null +++ b/test/api/cli/expressions/simplf_sumMulLambda_1.daphne @@ -0,0 +1,31 @@ +#f64 scalar *f64 matrix +a1 = 2.5; +X1 = [ 1.5, 2.0, 3.0, 4.0 ]; +X1 = reshape(X1, 2, 2); +res1 = sum(a1 * X1); +print(as.scalar(res1)); +#Expected : 26.25 + +#int scalar *f64 matrix + a2 = 2; +X2 = [ 1.5, 2.5, 3.5, 4.5 ]; +X2 = reshape(X2, 2, 2); +res2 = sum(a2 * X2); +print(as.scalar(res2)); +#Expected : 24 + +#f64 scalar *int matrix + a3 = 1.5; +X3 = [ 1, 2, 3, 4 ]; +X3 = reshape(X3, 2, 2); +res3 = sum(a3 * X3); +print(as.scalar(res3)); +#Expected : 15 + +#int scalar *int matrix + a4 = 3; +X4 = [ 2, 4, 6, 8 ]; +X4 = reshape(X4, 2, 2); +res4 = sum(a4 * X4); +print(as.scalar(res4)); +#Expected : 60 diff --git a/test/api/cli/expressions/simplf_sumMulLambda_1.txt b/test/api/cli/expressions/simplf_sumMulLambda_1.txt new file mode 100644 index 000000000..6c336cbde --- /dev/null +++ b/test/api/cli/expressions/simplf_sumMulLambda_1.txt @@ -0,0 +1,4 @@ +26.25 +24 +15 +60 diff --git a/test/api/cli/expressions/simplf_sumTrace_1.daphne b/test/api/cli/expressions/simplf_sumTrace_1.daphne new file mode 100644 index 000000000..d2ebb30d0 --- /dev/null +++ b/test/api/cli/expressions/simplf_sumTrace_1.daphne @@ -0,0 +1,39 @@ +// Float matrices (f64) +A_f64 = [ 1.0, 2.0, 3.0, 4.0 ]; +A_f64 = reshape(A_f64, 2, 2); + +B_f64 = [ 5.0, 6.0, 7.0, 8.0 ]; +B_f64 = reshape(B_f64, 2, 2); + +res1 = sum(diagVector(A_f64 @B_f64)); +print(as.scalar(res1), true, false); // Expect: 69 + +// Integer matrices (si64) +A_si64 = [ 1, 2, 3, 4 ]; +A_si64 = reshape(A_si64, 2, 2); + +B_si64 = [ 5, 6, 7, 8 ]; +B_si64 = reshape(B_si64, 2, 2); + +res2 = sum(diagVector(A_si64 @B_si64)); +print(as.scalar(res2), true, false); // Expect: 69 + +// Mixed: si64 @ f64 +A_mix1 = [ 1, 2, 3, 4 ]; // si64 +A_mix1 = reshape(A_mix1, 2, 2); + +B_mix1 = [ 5.0, 6.0, 7.0, 8.0 ]; // f64 +B_mix1 = reshape(B_mix1, 2, 2); + +res3 = sum(diagVector(A_mix1 @B_mix1)); +print(as.scalar(res3), true, false); // Expect: 69 + +// Mixed: f64 @ si64 +A_mix2 = [ 1.0, 2.0, 3.0, 4.0 ]; // f64 +A_mix2 = reshape(A_mix2, 2, 2); + +B_mix2 = [ 5, 6, 7, 8 ]; // si64 +B_mix2 = reshape(B_mix2, 2, 2); + +res4 = sum(diagVector(A_mix2 @B_mix2)); +print(as.scalar(res4), true, false); // Expect: 69 diff --git a/test/api/cli/expressions/simplf_sumTrace_1.txt b/test/api/cli/expressions/simplf_sumTrace_1.txt new file mode 100644 index 000000000..4914f8307 --- /dev/null +++ b/test/api/cli/expressions/simplf_sumTrace_1.txt @@ -0,0 +1,4 @@ +69 +69 +69 +69 diff --git a/test/api/cli/expressions/simplf_sumTranspose_1.daphne b/test/api/cli/expressions/simplf_sumTranspose_1.daphne new file mode 100644 index 000000000..62aaa54fa --- /dev/null +++ b/test/api/cli/expressions/simplf_sumTranspose_1.daphne @@ -0,0 +1,20 @@ +// This triggers canonicalization: sum(transpose(X)) => sum(X) + +// Matrix X: floats +X = [ 1.25, 2.5, 3.0, 4.0, 5.5, 6.0, 3.0, 0.0, 0.0 ]; +X = reshape(X, 3, 3); // Expect: 25.25 + +res = sum(transpose(X)); +print(as.scalar(res)); + +// Matrix B: integers +B = [ 1, 2, 3, 4 ]; +B = reshape(B, 2, 2); +res_B = sum(transpose(B)); +print(as.scalar(res_B)); // Expect: 10 + +// Matrix C: mixed values +C = [ 0.5, 1.5, 2.5, 3.5 ]; +C = reshape(C, 2, 2); +res_C = sum(transpose(C)); +print(as.scalar(res_C)); // Expect: 8 diff --git a/test/api/cli/expressions/simplf_sumTranspose_1.txt b/test/api/cli/expressions/simplf_sumTranspose_1.txt new file mode 100644 index 000000000..40b64567c --- /dev/null +++ b/test/api/cli/expressions/simplf_sumTranspose_1.txt @@ -0,0 +1,3 @@ +25.25 +10 +8 diff --git a/test/tags.h b/test/tags.h index 098de8ef0..0055244a5 100644 --- a/test/tags.h +++ b/test/tags.h @@ -22,6 +22,7 @@ // tag macros separated by whitespace, e.g., if TAG_A is "[a]" and TAG_B is // "[b]", then TAG_A TAG_B is "[a]" "[b]", which is equivalent to "[a][b]". +#define TAG_REWRITE "[rewrite]" #define TAG_ALGORITHMS "[algorithms]" #define TAG_CAST "[cast]" #define TAG_CODEGEN "[codegen]" diff --git a/test/util/simplf_dynInsert.mlir b/test/util/simplf_dynInsert.mlir new file mode 100644 index 000000000..acf352365 --- /dev/null +++ b/test/util/simplf_dynInsert.mlir @@ -0,0 +1,38 @@ +// RUN: daphne-opt --canonicalize %s | FileCheck %s --check-prefix=CHECK + +module { + func.func @main() { + %c0 = "daphne.constant"() {value = 0 : si64} : () -> si64 + %c5 = "daphne.constant"() {value = 5 : si64} : () -> si64 + %c10 = "daphne.constant"() {value = 1.000000e+01 : f64} : () -> f64 + %c100 = "daphne.constant"() {value = 1.000000e+02 : f64} : () -> f64 + %cone = "daphne.constant"() {value = 1.000000e+00 : f64} : () -> f64 + %seed = "daphne.constant"() {value = 42 : si64} : () -> si64 + %true = "daphne.constant"() {value = true} : () -> i1 + %false = "daphne.constant"() {value = false} : () -> i1 + + %idx5 = "daphne.cast"(%c5) : (si64) -> index + %zero = "daphne.constant"() {value = 0.000000e+00 : f64} : () -> f64 + + %X = "daphne.fill"(%zero, %idx5, %idx5) : (f64, index, index) -> !daphne.Matrix<5x5xf64> + %Y = "daphne.randMatrix"(%idx5, %idx5, %c10, %c100, %cone, %seed) + : (index, index, f64, f64, f64, si64) -> !daphne.Matrix<5x5xf64> + + %slice = "daphne.sliceRow"(%X, %c0, %c5) : (!daphne.Matrix<5x5xf64>, si64, si64) -> !daphne.Matrix<5x5xf64> + %insertCol = "daphne.insertCol"(%slice, %Y, %c0, %c5) + : (!daphne.Matrix<5x5xf64>, !daphne.Matrix<5x5xf64>, si64, si64) -> !daphne.Matrix<5x5xf64> + %insertRow = "daphne.insertRow"(%X, %insertCol, %c0, %c5) + : (!daphne.Matrix<5x5xf64>, !daphne.Matrix<5x5xf64>, si64, si64) -> !daphne.Matrix<5x5xf64> + + "daphne.print"(%insertRow, %true, %false) : (!daphne.Matrix<5x5xf64>, i1, i1) -> () + "daphne.return"() : () -> () + } +} + +// CHECK-LABEL: func.func @main() +// CHECK: %[[RAND:.*]] = "daphne.randMatrix" +// CHECK-NOT: daphne.insertRow +// CHECK-NOT: daphne.insertCol +// CHECK-NOT: daphne.sliceRow +// CHECK: "daphne.print"(%[[RAND]], {{.*}}, {{.*}}) + diff --git a/test/util/simplf_dynInsert_fail.mlir b/test/util/simplf_dynInsert_fail.mlir new file mode 100644 index 000000000..03d47768f --- /dev/null +++ b/test/util/simplf_dynInsert_fail.mlir @@ -0,0 +1,43 @@ +// RUN: daphne-opt --canonicalize %s | FileCheck %s --check-prefix=CHECK + +module { + func.func @main() { + %c0 = "daphne.constant"() {value = 0 : si64} : () -> si64 + %c5 = "daphne.constant"() {value = 5 : si64} : () -> si64 + %c7 = "daphne.constant"() {value = 7 : si64} : () -> si64 + %c8 = "daphne.constant"() {value = 8 : si64} : () -> si64 + %c10 = "daphne.constant"() {value = 1.000000e+01 : f64} : () -> f64 + %c100 = "daphne.constant"() {value = 1.000000e+02 : f64} : () -> f64 + %cone = "daphne.constant"() {value = 1.000000e+00 : f64} : () -> f64 + %seed = "daphne.constant"() {value = 42 : si64} : () -> si64 + %true = "daphne.constant"() {value = true} : () -> i1 + %false = "daphne.constant"() {value = false} : () -> i1 + + %idx7 = "daphne.cast"(%c7) : (si64) -> index + %idx8 = "daphne.cast"(%c8) : (si64) -> index + %idx5 = "daphne.cast"(%c5) : (si64) -> index + + %zero = "daphne.constant"() {value = 0.000000e+00 : f64} : () -> f64 + + %X = "daphne.fill"(%zero, %idx7, %idx8) : (f64, index, index) -> !daphne.Matrix<7x8xf64> + %Y = "daphne.randMatrix"(%idx5, %idx5, %c10, %c100, %cone, %seed) + : (index, index, f64, f64, f64, si64) -> !daphne.Matrix<5x5xf64> + + %slice = "daphne.sliceRow"(%X, %c0, %c5) + : (!daphne.Matrix<7x8xf64>, si64, si64) -> !daphne.Matrix + + %insertCol = "daphne.insertCol"(%slice, %Y, %c0, %c5) + : (!daphne.Matrix, !daphne.Matrix<5x5xf64>, si64, si64) -> !daphne.Matrix + + %insertRow = "daphne.insertRow"(%X, %insertCol, %c0, %c5) + : (!daphne.Matrix<7x8xf64>, !daphne.Matrix, si64, si64) -> !daphne.Matrix + + "daphne.print"(%insertRow, %true, %false) : (!daphne.Matrix, i1, i1) -> () + "daphne.return"() : () -> () + } +} + +// CHECK: daphne.sliceRow +// CHECK: daphne.insertCol +// CHECK: daphne.insertRow + diff --git a/test/util/simplf_mmSlice.mlir b/test/util/simplf_mmSlice.mlir new file mode 100644 index 000000000..34fda0536 --- /dev/null +++ b/test/util/simplf_mmSlice.mlir @@ -0,0 +1,50 @@ +// RUN: daphne-opt --canonicalize %s | FileCheck %s --check-prefix=CHECK -dump-input=always + +module { + func.func @main() { + // X @ Y (no transpose) + %rX = "daphne.constant"() {value = 4 : index} : () -> index + %cX = "daphne.constant"() {value = 6 : index} : () -> index + %cY = "daphne.constant"() {value = 5 : index} : () -> index + %low = "daphne.constant"() {value = 0.0 : f64} : () -> f64 + %high = "daphne.constant"() {value = 1.0 : f64} : () -> f64 + %fill = "daphne.constant"() {value = 1.0 : f64} : () -> f64 + %seed = "daphne.constant"() {value = 42 : si64} : () -> si64 + + %X = "daphne.randMatrix"(%rX, %cX, %low, %high, %fill, %seed) + : (index, index, f64, f64, f64, si64) -> !daphne.Matrix<4x6xf64> + %Y = "daphne.randMatrix"(%cX, %cY, %low, %high, %fill, %seed) + : (index, index, f64, f64, f64, si64) -> !daphne.Matrix<6x5xf64> + + %f = "daphne.constant"() {value = false} : () -> i1 + + %matmul = "daphne.matMul"(%X, %Y, %f, %f) + : (!daphne.Matrix<4x6xf64>, !daphne.Matrix<6x5xf64>, i1, i1) -> !daphne.Matrix<4x5xf64> + + %i3 = "daphne.constant"() {value = 3 : si64} : () -> si64 + %i1 = "daphne.constant"() {value = 4 : si64} : () -> si64 + %i3p = "daphne.constant"() {value = 4 : si64} : () -> si64 + %i1p = "daphne.constant"() {value = 5 : si64} : () -> si64 + + %slice_row = "daphne.sliceRow"(%matmul, %i3, %i3p) + : (!daphne.Matrix<4x5xf64>, si64, si64) -> !daphne.Matrix<1x5xf64> + %slice_col = "daphne.sliceCol"(%slice_row, %i1, %i1p) + : (!daphne.Matrix<1x5xf64>, si64, si64) -> !daphne.Matrix<1x1xf64> + + %bool_1 = "daphne.constant"() {value = true} : () -> i1 + %bool_2 = "daphne.constant"() {value = false} : () -> i1 + + "daphne.print"(%slice_col, %bool_1, %bool_2) : (!daphne.Matrix<1x1xf64>, i1, i1) -> () + "daphne.return"() : () -> () + } +} + + +// CHECK: %[[X:.*]] = "daphne.randMatrix"({{.*}}) : (index, index, f64, f64, f64, si64) -> !daphne.Matrix<4x6xf64> +// CHECK: %[[Y:.*]] = "daphne.randMatrix"({{.*}}) : (index, index, f64, f64, f64, si64) -> !daphne.Matrix<6x5xf64> +// CHECK: %[[ROW:.*]] = "daphne.sliceRow"(%[[X]], {{.*}}, {{.*}}) : (!daphne.Matrix<4x6xf64>, si64, si64) -> !daphne.Unknown +// CHECK: %[[COL:.*]] = "daphne.sliceCol"(%[[Y]], {{.*}}, {{.*}}) : (!daphne.Matrix<6x5xf64>, si64, si64) -> !daphne.Unknown +// CHECK: "daphne.matMul"(%[[ROW]], %[[COL]], {{.*}}, {{.*}}) : (!daphne.Unknown, !daphne.Unknown, i1, i1) -> !daphne.Matrix<1x1xf64> +// CHECK-NOT: "daphne.sliceCol"(%{{.*matMul.*}}) +// CHECK-NOT: "daphne.sliceRow"(%{{.*matMul.*}}) + diff --git a/test/util/simplf_sumEwadd.mlir b/test/util/simplf_sumEwadd.mlir new file mode 100644 index 000000000..4b7b76e3e --- /dev/null +++ b/test/util/simplf_sumEwadd.mlir @@ -0,0 +1,35 @@ +// RUN: daphne-opt --canonicalize %s | FileCheck %s --check-prefix=CHECK -dump-input=always + +module { + func.func @main() { + %c1 = "daphne.constant"() {value = 4 : index} : () -> index + %c2 = "daphne.constant"() {value = 1 : index} : () -> index + %low = "daphne.constant"() {value = 0.0 : f64} : () -> f64 + %high = "daphne.constant"() {value = 1.0 : f64} : () -> f64 + %fill = "daphne.constant"() {value = 1.0 : f64} : () -> f64 + %seed = "daphne.constant"() {value = 42 : si64} : () -> si64 + %printFlag1 = "daphne.constant"() {value = true} : () -> i1 + %printFlag2 = "daphne.constant"() {value = false} : () -> i1 + + %A = "daphne.randMatrix"(%c1, %c2, %low, %high, %fill, %seed) + : (index, index, f64, f64, f64, si64) -> !daphne.Matrix<4x1xf64> + %B = "daphne.randMatrix"(%c1, %c2, %low, %high, %fill, %seed) + : (index, index, f64, f64, f64, si64) -> !daphne.Matrix<4x1xf64> + + %add = "daphne.ewAdd"(%A, %B) + : (!daphne.Matrix<4x1xf64>, !daphne.Matrix<4x1xf64>) -> !daphne.Matrix<4x1xf64> + + %sum = "daphne.sumAll"(%add) : (!daphne.Matrix<4x1xf64>) -> f64 + + "daphne.print"(%sum, %printFlag1, %printFlag2) : (f64, i1, i1) -> () + "daphne.return"() : () -> () + } +} + +// CHECK: %[[A:.*]] = "daphne.randMatrix" +// CHECK: %[[B:.*]] = "daphne.randMatrix" +// CHECK: %[[A_SUM:.*]] = "daphne.sumAll"(%[[A]]) : (!daphne.Matrix<4x1xf64>) -> !daphne.Unknown +// CHECK: %[[B_SUM:.*]] = "daphne.sumAll"(%[[B]]) : (!daphne.Matrix<4x1xf64>) -> !daphne.Unknown +// CHECK: "daphne.ewAdd"(%[[A_SUM]], %[[B_SUM]]) : (!daphne.Unknown, !daphne.Unknown) -> f64 +// CHECK-NOT: "daphne.sumAll"({{.*ewAdd.*}}) + diff --git a/test/util/simplf_sumMulLambda.mlir b/test/util/simplf_sumMulLambda.mlir new file mode 100644 index 000000000..4fc68760c --- /dev/null +++ b/test/util/simplf_sumMulLambda.mlir @@ -0,0 +1,28 @@ + + +// RUN: daphne-opt --canonicalize %s | FileCheck %s --check-prefix=CHECK -dump-input=always + +module { + func.func @main() { + %0 = "daphne.constant"() {value = 3 : index} : () -> index + %1 = "daphne.constant"() {value = 4 : index} : () -> index + %2 = "daphne.constant"() {value = 4 : si64} : () -> si64 + %3 = "daphne.constant"() {value = 1.000000e+02 : f64} : () -> f64 + %4 = "daphne.constant"() {value = 3.000000e+02 : f64} : () -> f64 + %5 = "daphne.constant"() {value = 1.000000e+00 : f64} : () -> f64 + %6 = "daphne.randMatrix"(%1, %0, %3, %4, %5, %2) : (index, index, f64, f64, f64, si64) -> !daphne.Matrix + %7 = "daphne.constant"() {value = 5 : si64} : () -> si64 + %8 = "daphne.ewMul"(%7, %6) : (si64, !daphne.Matrix) -> !daphne.Matrix + %9 = "daphne.sumAll"(%8) : (!daphne.Matrix) -> f64 + %10 = "daphne.constant"() {value = true} : () -> i1 + %11 = "daphne.constant"() {value = false} : () -> i1 + "daphne.print"(%9, %10, %11) : (f64, i1, i1) -> () + "daphne.return"() : () -> () + } +} + +// CHECK: "daphne.constant"() {value = 5 : si64} : () -> si64 +// CHECK: %[[MATRIX:.*]] = "daphne.randMatrix"({{.*}}) +// CHECK: %[[SUM:.*]] = "daphne.sumAll"(%[[MATRIX]]) +// CHECK: "daphne.ewMul"(%[[SUM]], %{{.*}}) : (!daphne.Unknown, si64) -> f64 +// CHECK-NOT: "daphne.sumAll"({{.*}}ewMul) diff --git a/test/util/simplf_sumTrace.mlir b/test/util/simplf_sumTrace.mlir new file mode 100644 index 000000000..89dfd13f9 --- /dev/null +++ b/test/util/simplf_sumTrace.mlir @@ -0,0 +1,41 @@ +// RUN: daphne-opt --canonicalize %s | FileCheck %s --check-prefix=CHECK -dump-input=always + +module { + func.func @main() { + %rows = "daphne.constant"() {value = 2 : index} : () -> index + %cols = "daphne.constant"() {value = 2 : index} : () -> index + %low = "daphne.constant"() {value = 0.0 : f64} : () -> f64 + %high = "daphne.constant"() {value = 1.0 : f64} : () -> f64 + %fill = "daphne.constant"() {value = 1.0 : f64} : () -> f64 + %seed = "daphne.constant"() {value = 42 : si64} : () -> si64 + + %lhs = "daphne.randMatrix"(%rows, %cols, %low, %high, %fill, %seed) + : (index, index, f64, f64, f64, si64) -> !daphne.Matrix<2x2xf64> + %rhs = "daphne.randMatrix"(%rows, %cols, %low, %high, %fill, %seed) + : (index, index, f64, f64, f64, si64) -> !daphne.Matrix<2x2xf64> + + %tX = "daphne.constant"() {value = false} : () -> i1 + %tY = "daphne.constant"() {value = false} : () -> i1 + + %mm = "daphne.matMul"(%lhs, %rhs, %tX, %tY) + : (!daphne.Matrix<2x2xf64>, !daphne.Matrix<2x2xf64>, i1, i1) -> !daphne.Matrix<2x2xf64> + + %diag = "daphne.diagVector"(%mm) + : (!daphne.Matrix<2x2xf64>) -> !daphne.Matrix<2x1xf64> + + %sum = "daphne.sumAll"(%diag) + : (!daphne.Matrix<2x1xf64>) -> f64 + + "daphne.return"() : () -> () + } +} + + +// CHECK: %[[LHS:.*]] = "daphne.randMatrix"({{.*}}) : {{.*}} -> !daphne.Matrix<2x2xf64> +// CHECK: %[[RHS:.*]] = "daphne.randMatrix"({{.*}}) : {{.*}} -> !daphne.Matrix<2x2xf64> +// CHECK: %[[TRHS:.*]] = "daphne.transpose"(%[[RHS]]) : (!daphne.Matrix<2x2xf64>) -> !daphne.Unknown +// CHECK: %[[EWMUL:.*]] = "daphne.ewMul"(%[[LHS]], %[[TRHS]]) : (!daphne.Matrix<2x2xf64>, !daphne.Unknown) -> !daphne.Unknown +// CHECK: "daphne.sumAll"(%[[EWMUL]]) : (!daphne.Unknown) -> f64 +// CHECK-NOT: daphne.matMul +// CHECK-NOT: daphne.diagVector + diff --git a/test/util/simplf_sumTranspose.mlir b/test/util/simplf_sumTranspose.mlir new file mode 100644 index 000000000..c43b0d4ba --- /dev/null +++ b/test/util/simplf_sumTranspose.mlir @@ -0,0 +1,31 @@ +// RUN: daphne-opt --canonicalize %s | FileCheck %s --check-prefix=CHECK -dump-input=always + +module { + func.func @main() { + %c1 = "daphne.constant"() {value = 2 : index} : () -> index + %c2 = "daphne.constant"() {value = 3 : index} : () -> index + %low = "daphne.constant"() {value = 0.0 : f64} : () -> f64 + %high = "daphne.constant"() {value = 10.0 : f64} : () -> f64 + %fill = "daphne.constant"() {value = 1.0 : f64} : () -> f64 + %seed = "daphne.constant"() {value = 42 : si64} : () -> si64 + %true = "daphne.constant"() {value = true} : () -> i1 + %false = "daphne.constant"() {value = false} : () -> i1 + + %A = "daphne.randMatrix"(%c1, %c2, %low, %high, %fill, %seed) + : (index, index, f64, f64, f64, si64) -> !daphne.Matrix<2x3xf64> + + %AT = "daphne.transpose"(%A) + : (!daphne.Matrix<2x3xf64>) -> !daphne.Matrix<3x2xf64> + + %s = "daphne.sumAll"(%AT) + : (!daphne.Matrix<3x2xf64>) -> f64 + + "daphne.print"(%s, %true, %false) : (f64, i1, i1) -> () + "daphne.return"() : () -> () + } +} + +// CHECK: %[[A:.*]] = "daphne.randMatrix" +// CHECK-NOT: "daphne.transpose" +// CHECK: "daphne.sumAll"(%[[A]]) : (!daphne.Matrix<2x3xf64>) -> f64 +