diff --git a/scripts/examples/gemm-codegen.daph b/scripts/examples/gemm-codegen.daph new file mode 100644 index 000000000..3bf641822 --- /dev/null +++ b/scripts/examples/gemm-codegen.daph @@ -0,0 +1,15 @@ +# bench.daph +size=$size; +sparsity=$sparsity; + +alpha = 2; +beta = 3; +A = rand(size, size, 1.0, 1.0, sparsity, -1); +B = rand(size, size, 1.0, 1.0, sparsity, -1); +C = rand(size, size, 1.0, 1.0, sparsity, -1); +start = now(); +D = beta * C + alpha * A @ B ; +end = now(); +print((end-start) / 1000000000.0); +x = aggMax(D); +print(x); \ No newline at end of file diff --git a/src/compiler/execution/DaphneIrExecutor.cpp b/src/compiler/execution/DaphneIrExecutor.cpp index 0e5c0c4fd..d35e78d1b 100644 --- a/src/compiler/execution/DaphneIrExecutor.cpp +++ b/src/compiler/execution/DaphneIrExecutor.cpp @@ -32,6 +32,7 @@ #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -186,6 +187,9 @@ bool DaphneIrExecutor::runPasses(mlir::ModuleOp module) { if (userConfig_.explain_kernels) pm.addPass(mlir::daphne::createPrintIRPass("IR after kernel lowering:")); + pm.addPass(mlir::memref::createExpandStridedMetadataPass()); + pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); + pm.addPass(mlir::createConvertSCFToCFPass()); pm.addNestedPass(mlir::LLVM::createRequestCWrappersPass()); pm.addPass(mlir::daphne::createLowerToLLVMPass(userConfig_)); @@ -271,6 +275,12 @@ void DaphneIrExecutor::buildCodegenPipeline(mlir::PassManager &pm) { pm.addPass(mlir::daphne::createAggDimOpLoweringPass()); pm.addPass(mlir::daphne::createMapOpLoweringPass()); pm.addPass(mlir::daphne::createTransposeOpLoweringPass()); + + 
//pm.addPass(mlir::daphne::createSliceRowOpLoweringPass()); + //pm.addPass(mlir::daphne::createSliceColOpLoweringPass()); + pm.addPass(mlir::daphne::createSliceOpLoweringPass()); + //pm.addPass(mlir::daphne::createExtractOpLoweringPass()); + pm.addPass(mlir::createInlinerPass()); pm.addNestedPass(mlir::createLoopFusionPass()); @@ -304,6 +314,8 @@ void DaphneIrExecutor::buildCodegenPipeline(mlir::PassManager &pm) { mlir::LowerVectorToLLVMOptions lowerVectorToLLVMOptions; pm.addPass(mlir::createConvertVectorToLLVMPass(lowerVectorToLLVMOptions)); + + if (userConfig_.explain_mlir_codegen) pm.addPass(mlir::daphne::createPrintIRPass("IR after codegen pipeline")); if (userConfig_.explain_mlir_codegen_mlir_specific) diff --git a/src/compiler/lowering/CMakeLists.txt b/src/compiler/lowering/CMakeLists.txt index fcbc1c995..9b169396f 100644 --- a/src/compiler/lowering/CMakeLists.txt +++ b/src/compiler/lowering/CMakeLists.txt @@ -36,6 +36,10 @@ add_mlir_dialect_library(MLIRDaphneTransforms TransposeOpLowering.cpp SparsityExploitationPass.cpp + SliceRowOpLowering.cpp + SliceColOpLowering.cpp + SliceOpLowering.cpp + DEPENDS MLIRDaphneOpsIncGen MLIRDaphneTransformsIncGen diff --git a/src/compiler/lowering/EwOpsLowering.cpp b/src/compiler/lowering/EwOpsLowering.cpp index 256795529..64ecda24c 100644 --- a/src/compiler/lowering/EwOpsLowering.cpp +++ b/src/compiler/lowering/EwOpsLowering.cpp @@ -31,6 +31,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinDialect.h" @@ -46,6 +47,7 @@ #include "mlir/Transforms/DialectConversion.h" using namespace mlir; +using namespace std; // **************************************************************************** // Rewriter Templates (Elemwise Unary, Elemwise Binary) @@ -77,6 +79,63 @@ template struct UnaryOpLowering : publi return 
mlir::success(); } + LogicalResult matchAndRewriteSparseMat(UnaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + + Location loc = op->getLoc(); + auto sparseMatType = adaptor.getArg().getType().template dyn_cast(); + Type matrixElementType = sparseMatType.getElementType(); + ssize_t numRows = sparseMatType.getNumRows(); + ssize_t numCols = sparseMatType.getNumCols(); + + if (numRows < 0 || numCols < 0) { + throw ErrorHandler::compilerError( + loc, "EwOpsLowering (BinaryOp)", + "ewOps codegen currently only works with matrix dimensions that are known at compile time"); + } + + MemRefType sparseValuesMemRefType = + MemRefType::get({ShapedType::kDynamic}, matrixElementType); + + Value argValuesMemref = rewriter.create( + loc, sparseValuesMemRefType, adaptor.getArg()); + + Value one = rewriter.create(loc, 1); + Value resMemref = rewriter.create( + loc, sparseValuesMemRefType, ValueRange{one}); + + SmallVector indexMaps = {AffineMap::getMultiDimIdentityMap(1, rewriter.getContext()), + AffineMap::getMultiDimIdentityMap(1, rewriter.getContext())}; + SmallVector iterTypes = {utils::IteratorType::parallel}; + + rewriter.create( + loc, TypeRange{}, ValueRange{argValuesMemref}, ValueRange{resMemref}, indexMaps, iterTypes, + [&](OpBuilder &OpBuilderNested, Location locNested, ValueRange arg) { + Value resValue = unaryFunc(OpBuilderNested, locNested, this->typeConverter, arg[0]); + OpBuilderNested.create(locNested, resValue); + }); + + MemRefType sparseColIdxsMemRefType = MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); + MemRefType sparseRowOffsetsMemRefType = MemRefType::get({numRows + 1}, rewriter.getIndexType()); + + Value argColIdxsMemref = rewriter.create( + loc, sparseColIdxsMemRefType, adaptor.getArg()); + Value argRowOffsetsMemref = rewriter.create( + loc, sparseRowOffsetsMemRefType, adaptor.getArg()); + + Value maxNumRowsValue = rewriter.create(loc, numRows); + Value numColsValue = rewriter.create(loc, numCols); + Value 
maxNumNonZerosValue = rewriter.create(loc, numCols * numRows); + + auto resCSRMatrix = convertMemRefToCSRMatrix(loc, rewriter, + resMemref, argColIdxsMemref, argRowOffsetsMemref, + //maxNumRowsValue, numColsValue, maxNumNonZerosValue, op.getType()); + maxNumRowsValue, numColsValue, maxNumNonZerosValue, adaptor.getArg().getType()); + + rewriter.replaceOp(op, resCSRMatrix); + + return mlir::success(); + } + LogicalResult matchAndRewrite(UnaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); @@ -87,6 +146,10 @@ template struct UnaryOpLowering : publi return matchAndRewriteScalarVal(op, adaptor, rewriter); } + if (matrixType.getRepresentation() == daphne::MatrixRepresentation::Sparse) { + return matchAndRewriteSparseMat(op, adaptor, rewriter); + } + Type matrixElementType = matrixType.getElementType(); ssize_t numRows = matrixType.getNumRows(); ssize_t numCols = matrixType.getNumCols(); @@ -216,6 +279,44 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { Type matrixElementType = lhsMatrixType.getElementType(); + if (lhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Sparse) + { + MemRefType valuesMemRefType = MemRefType::get({ShapedType::kDynamic}, matrixElementType); + MemRefType colIdxsMemRefType = MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); + MemRefType rowOffsetsMemRefType = MemRefType::get({lhsRows + 1}, rewriter.getIndexType()); + + auto lhsValuesMemref = rewriter.create(loc, valuesMemRefType, lhs); + auto lhsColIdxsMemref = rewriter.create(loc, colIdxsMemRefType, lhs); + auto lhsRowOffsetsMemref = rewriter.create(loc, rowOffsetsMemRefType, lhs); + + Value one = rewriter.create(loc, 1); + Value resMemref = rewriter.create(loc, valuesMemRefType, ValueRange{one}); + + SmallVector indexMaps = {AffineMap::getMultiDimIdentityMap(1, rewriter.getContext()), + AffineMap::getMultiDimIdentityMap(1, rewriter.getContext())}; + SmallVector iterTypes = 
{utils::IteratorType::parallel}; + + rewriter.create( + loc, TypeRange{}, ValueRange{lhsValuesMemref}, ValueRange{resMemref}, indexMaps, iterTypes, + [&](OpBuilder &OpBuilderNested, Location locNested, ValueRange arg) { + Value resValue = binaryFunc(OpBuilderNested, locNested, this->typeConverter, arg[0], rhs); + OpBuilderNested.create(locNested, resValue); + }); + + Value maxNumRowsValue = rewriter.create(loc, lhsRows); + Value numColsValue = rewriter.create(loc, lhsCols); + Value maxNumNonZerosValue = rewriter.create(loc, lhsCols * lhsRows); + + auto resCSRMatrix = convertMemRefToCSRMatrix(loc, rewriter, + resMemref, lhsColIdxsMemref, lhsRowOffsetsMemref, + maxNumRowsValue, numColsValue, maxNumNonZerosValue, op.getType()); + + rewriter.replaceOp(op, resCSRMatrix); + + return mlir::success(); + + } + MemRefType argMemRefType = MemRefType::get({lhsRows, lhsCols}, matrixElementType); auto lhsMemref = rewriter.create(loc, argMemRefType, lhs); @@ -238,6 +339,496 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { return mlir::success(); } + LogicalResult matchAndRewriteSparseDenseMat(BinaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + + auto sparseLhsMatrixType = lhs.getType().template dyn_cast(); + auto denseRhsMatrixType = rhs.getType().template dyn_cast(); + + ssize_t sparseLhsRows = sparseLhsMatrixType.getNumRows(); + ssize_t sparseLhsCols = sparseLhsMatrixType.getNumCols(); + ssize_t denseRhsRows = denseRhsMatrixType.getNumRows(); + ssize_t denseRhsCols = denseRhsMatrixType.getNumCols(); + + MemRefType sparseLhsValuesMemRefType = + MemRefType::get({ShapedType::kDynamic}, sparseLhsMatrixType.getElementType()); + MemRefType sparseLhsColIdxsMemRefType = + MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); + MemRefType sparseLhsRowOffsetsMemRefType = + MemRefType::get({sparseLhsRows + 1}, rewriter.getIndexType()); + 
MemRefType denseRhsMemRefType = + MemRefType::get({denseRhsRows, denseRhsCols}, denseRhsMatrixType.getElementType()); + + auto sparseLhsValuesMemRef = + rewriter.create(loc, sparseLhsValuesMemRefType, lhs); + auto sparseLhsColIdxsMemRef = + rewriter.create(loc, sparseLhsColIdxsMemRefType, lhs); + auto sparseLhsRowOffsetsMemRef = + rewriter.create(loc, sparseLhsRowOffsetsMemRefType, lhs); + auto denseRhsMemRef = + rewriter.create(loc, denseRhsMemRefType, rhs); + + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + auto numSparseLhsRowsValue = rewriter.create(loc, sparseLhsRows); + + // return a dense matrix if the op is add + auto resDenseMemRef = rewriter.create(loc, denseRhsMemRefType); + rewriter.create(loc, denseRhsMemRef, resDenseMemRef); + // return a sparse matrix if the op is mul + auto resSparseMemRef = rewriter.create(loc, sparseLhsValuesMemRefType, ValueRange{one}); + + rewriter.create( + loc, zero, numSparseLhsRowsValue, one, ValueRange{}, + [&](OpBuilder &OpBuilderNested, Location locNested, Value loopIdx, ValueRange loopInvariants) + { + auto rowPtr = loopIdx; + auto nextRowPtr = OpBuilderNested.create(locNested, rowPtr, one); + + auto colIdxLowerIncl = OpBuilderNested.create( + locNested, sparseLhsRowOffsetsMemRef, ValueRange{rowPtr}); + auto colIdxUpperExcl = OpBuilderNested.create( + locNested, sparseLhsRowOffsetsMemRef, ValueRange{nextRowPtr}); + + OpBuilderNested.create( + locNested, colIdxLowerIncl, colIdxUpperExcl, one, ValueRange{}, + [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested, Value loopIdxNested, ValueRange loopInvariants) + { + auto rowIdx = rowPtr; + auto colIdx = OpBuilderTwiceNested.create( + locTwiceNested, sparseLhsColIdxsMemRef, ValueRange{loopIdxNested}); + + auto sparseLhsValue = OpBuilderTwiceNested.create( + locTwiceNested, sparseLhsValuesMemRef, ValueRange{loopIdxNested}); + + auto denseRhsValue = OpBuilderTwiceNested.create( + locTwiceNested, denseRhsMemRef, ValueRange{rowIdx, colIdx}); 
+ + Value resValue = binaryFunc( + OpBuilderTwiceNested, locTwiceNested, this->typeConverter, sparseLhsValue, denseRhsValue); + + if (llvm::isa(op)) + { + OpBuilderTwiceNested.create( + locTwiceNested, resValue, resDenseMemRef, ValueRange{rowIdx, colIdx}); + } + else if (llvm::isa(op)) + { + OpBuilderTwiceNested.create( + locTwiceNested, resValue, resSparseMemRef, ValueRange{loopIdxNested}); + } + else + { + throw ErrorHandler::compilerError(loc, "EwOpsLowering (BinaryOp)", "Unsupported ewOps codegen"); + } + OpBuilderTwiceNested.create(locTwiceNested, resValue); + } + ); + OpBuilderNested.create(locNested); + } + ); + + if (llvm::isa(op)) + { + Value resDenseMatrix = convertMemRefToDenseMatrix(loc, rewriter, resDenseMemRef, op.getType()); + rewriter.replaceOp(op, resDenseMatrix); + + return mlir::success(); + } + else if (llvm::isa(op)) + { + Value maxNumRowsValue = rewriter.create(loc, sparseLhsRows); + Value numColsValue = rewriter.create(loc, sparseLhsCols); + Value maxNumNonZerosValue = rewriter.create(loc, sparseLhsCols * sparseLhsRows); + + Value resCSRMatrix = convertMemRefToCSRMatrix(loc, rewriter, + resSparseMemRef, sparseLhsColIdxsMemRef, sparseLhsRowOffsetsMemRef, + maxNumRowsValue, numColsValue, maxNumNonZerosValue, op.getType()); + + if (!resCSRMatrix) { + llvm::errs() << "Error: resCSRMatrix is null!\n"; + } + rewriter.replaceOp(op, resCSRMatrix); + return mlir::success(); + } + else + { + throw ErrorHandler::compilerError(loc, "EwOpsLowering (BinaryOp)", "Unsupported ewOps codegen"); + } + } + + LogicalResult matchAndRewriteSparseSparseMat(BinaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + + auto lhsMatrixType = lhs.getType().template dyn_cast(); + auto rhsMatrixType = rhs.getType().template dyn_cast(); + + ssize_t lhsRows = lhsMatrixType.getNumRows(); + ssize_t lhsCols = lhsMatrixType.getNumCols(); + ssize_t rhsRows = 
rhsMatrixType.getNumRows(); + ssize_t rhsCols = rhsMatrixType.getNumCols(); + + if (lhsRows != rhsRows || lhsCols != rhsCols) + throw ErrorHandler::compilerError( + loc, "EwOpsLowering (BinaryOp Sparse Sparse)", "lhs and rhs must have the same dimensions."); + + auto numRows = lhsRows; + auto numCols = lhsCols; + + MemRefType lhsValuesMemRefType = + MemRefType::get({ShapedType::kDynamic}, lhsMatrixType.getElementType()); + MemRefType rhsValuesMemRefType = + MemRefType::get({ShapedType::kDynamic}, rhsMatrixType.getElementType()); + MemRefType colIdxsMemRefType = + MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); + MemRefType rowOffsetsMemRefType = + MemRefType::get({numRows + 1}, rewriter.getIndexType()); + + auto lhsValuesMemRef = + rewriter.create(loc, lhsValuesMemRefType, lhs); + auto lhsColIdxsMemRef = + rewriter.create(loc, colIdxsMemRefType, lhs); + auto lhsRowOffsetsMemRef = + rewriter.create(loc, rowOffsetsMemRefType, lhs); + auto rhsValuesMemRef = + rewriter.create(loc, rhsValuesMemRefType, rhs); + auto rhsColIdxsMemRef = + rewriter.create(loc, colIdxsMemRefType, rhs); + auto rhsRowOffsetsMemRef = + rewriter.create(loc, rowOffsetsMemRefType, rhs); + + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + auto numRowsValue = rewriter.create(loc, numRows); + + auto resValuesMemRef = rewriter.create(loc, lhsValuesMemRefType, ValueRange{one}); + auto resColIdxsMemRef = rewriter.create(loc, colIdxsMemRefType, ValueRange{one}); + auto resRowOffsetsMemRef = rewriter.create(loc, rowOffsetsMemRefType); + rewriter.create(loc, zero, resRowOffsetsMemRef, ValueRange{zero}); + + rewriter.create( + loc, zero, numRowsValue, one, ValueRange{zero}, + [&](OpBuilder &OpBuilderNested, Location locNested, Value loopIdx, ValueRange loopIterArgs) + { + auto rowPtr = loopIdx; + auto nextRowPtr = OpBuilderNested.create(locNested, rowPtr, one); + + auto resValuesPtr = loopIterArgs[0]; + + auto lhsColIdxLowerIncl = OpBuilderNested.create( + 
locNested, lhsRowOffsetsMemRef, ValueRange{rowPtr}); + auto lhsColIdxUpperExcl = OpBuilderNested.create( + locNested, lhsRowOffsetsMemRef, ValueRange{nextRowPtr}); + auto rhsColIdxLowerIncl = OpBuilderNested.create( + locNested, rhsRowOffsetsMemRef, ValueRange{rowPtr}); + auto rhsColIdxUpperExcl = OpBuilderNested.create( + locNested, rhsRowOffsetsMemRef, ValueRange{nextRowPtr}); + + auto lhsColIdxUpperIncl = OpBuilderNested.create( + locNested, lhsColIdxUpperExcl, one); + auto lhsColUpper = OpBuilderNested.create( + locNested, lhsColIdxsMemRef, ValueRange{lhsColIdxUpperIncl}); + auto rhsColIdxUpperIncl = OpBuilderNested.create( + locNested, rhsColIdxUpperExcl, one); + auto rhsColUpper = OpBuilderNested.create( + locNested, rhsColIdxsMemRef, ValueRange{rhsColIdxUpperIncl}); + + + auto lhsEndFirst = OpBuilderNested.create( + locNested, arith::CmpIPredicate::ult, lhsColUpper, rhsColUpper); + + auto lhsAllZero = OpBuilderNested.create( + locNested, arith::CmpIPredicate::eq, lhsColIdxLowerIncl, lhsColIdxUpperExcl); + auto rhsAllZero = OpBuilderNested.create( + locNested, arith::CmpIPredicate::eq, rhsColIdxLowerIncl, rhsColIdxUpperExcl); + + auto operation = OpBuilderNested.create( + locNested, lhsAllZero, + [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested) + { + auto thenRegion = OpBuilderTwiceNested.create( + locTwiceNested, rhsAllZero, + [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) + { + //if lhs and rhs are all-zero in this row, move to next row + auto newResValuesPtr = resValuesPtr; + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{newResValuesPtr}); + }, + [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) + { + //if lhs is all-zero in this row but rhs is not + if (llvm::isa(op)){ + auto forLoop = OpBuilderThreetimesNested.create( + locThreetimesNested, rhsColIdxLowerIncl, rhsColIdxUpperExcl, one, ValueRange{resValuesPtr}, + [&](OpBuilder &OpBuilderFourtimesNested, Location 
locFourtimesNested, Value loopIdx, ValueRange loopIterArgs) + { + //copy this row of rhs to the res memref if the op is add + auto resValue = OpBuilderFourtimesNested.create( + locFourtimesNested, rhsValuesMemRef, ValueRange{loopIdx}); + auto resCol = OpBuilderFourtimesNested.create( + locFourtimesNested, rhsColIdxsMemRef, ValueRange{loopIdx}); + auto resIndex = loopIterArgs[0]; + OpBuilderFourtimesNested.create( + locFourtimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderFourtimesNested.create( + locFourtimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); + auto newResValuesPtr = OpBuilderFourtimesNested.create( + locFourtimesNested, resIndex, one); + OpBuilderFourtimesNested.create( + locFourtimesNested, ValueRange{newResValuesPtr}); + } + ); + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{forLoop.getResult(0)}); + } + else + { + //else move to the next row + auto newResValuesPtr = resValuesPtr; + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{newResValuesPtr}); + } + } + ); + OpBuilderTwiceNested.create(locTwiceNested, ValueRange{thenRegion.getResult(0)}); + }, + [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested) + { + auto elseRegion = OpBuilderTwiceNested.create( + locNested, rhsAllZero, + [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) + { + if (llvm::isa(op)){ + auto forLoop = OpBuilderThreetimesNested.create( + locThreetimesNested, lhsColIdxLowerIncl, lhsColIdxUpperExcl, one, ValueRange{resValuesPtr}, + [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested, Value loopIdx, ValueRange loopIterArgs) + { + auto resValue = OpBuilderFourtimesNested.create( + locFourtimesNested, lhsValuesMemRef, ValueRange{loopIdx}); + auto resCol = OpBuilderFourtimesNested.create( + locFourtimesNested, lhsColIdxsMemRef, ValueRange{loopIdx}); + auto resIndex = loopIterArgs[0]; + OpBuilderFourtimesNested.create( + locFourtimesNested, resValue, resValuesMemRef, 
ValueRange{resIndex}); + OpBuilderFourtimesNested.create( + locFourtimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); + auto newResValuesPtr = OpBuilderFourtimesNested.create( + locFourtimesNested, resIndex, one); + OpBuilderFourtimesNested.create( + locFourtimesNested, ValueRange{newResValuesPtr}); + } + ); + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{forLoop.getResult(0)}); + } + else + { + auto newResValuesPtr = resValuesPtr; + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{newResValuesPtr}); + } + }, + [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) + { + auto whileLoop = OpBuilderThreetimesNested.create( + locThreetimesNested, + TypeRange{ + OpBuilderThreetimesNested.getIndexType(), + OpBuilderThreetimesNested.getIndexType(), + OpBuilderThreetimesNested.getIndexType()}, + ValueRange{lhsColIdxLowerIncl, rhsColIdxLowerIncl, resValuesPtr}, + [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested, ValueRange args) + { + auto cond1 = OpBuilderFourtimesNested.create( + locFourtimesNested, arith::CmpIPredicate::ult, args[0], lhsColIdxUpperExcl); + auto cond2 = OpBuilderFourtimesNested.create( + locFourtimesNested, arith::CmpIPredicate::ult, args[1], rhsColIdxUpperExcl); + auto cond = OpBuilderFourtimesNested.create(locFourtimesNested, cond1, cond2); + OpBuilderFourtimesNested.create(locFourtimesNested, cond, args); + }, + [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested, ValueRange args) + { + auto lhsCol = OpBuilderFourtimesNested.create( + locFourtimesNested, lhsColIdxsMemRef, ValueRange{args[0]}); + auto rhsCol = OpBuilderFourtimesNested.create( + locFourtimesNested, rhsColIdxsMemRef, ValueRange{args[1]}); + + auto case1 = OpBuilderFourtimesNested.create( + locFourtimesNested, arith::CmpIPredicate::ult, lhsCol, rhsCol); + auto case2 = OpBuilderFourtimesNested.create( + locFourtimesNested, arith::CmpIPredicate::ult, rhsCol, lhsCol); + // copy the element 
whose col num is smaller to the res if the op is add + // then load the next element of that side + auto newArg = OpBuilderFourtimesNested.create( + locFourtimesNested, case1, + [&](OpBuilder &OpBuilderFivetimesNested, Location locFivetimesNested) + { + auto newResValuesPtr = args[2]; + if (llvm::isa(op)) + { + auto resValue = OpBuilderFivetimesNested.create( + locFivetimesNested, lhsValuesMemRef, ValueRange{args[0]}); + auto resIndex = args[2]; + OpBuilderFivetimesNested.create( + locFivetimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderFivetimesNested.create( + locFivetimesNested, lhsCol, resColIdxsMemRef, ValueRange{resIndex}); + newResValuesPtr = OpBuilderFivetimesNested.create( + locFivetimesNested, resIndex, one); + } + + auto newArg0 = OpBuilderFivetimesNested.create(locFivetimesNested, args[0], one); + auto newArg1 = args[1]; + OpBuilderFivetimesNested.create( + locFivetimesNested, + ValueRange{newArg0, newArg1, newResValuesPtr}); + }, + [&](OpBuilder &OpBuilderFivetimesNested, Location locFivetimesNested) + { + auto case2Region = OpBuilderFivetimesNested.create( + locFivetimesNested, case2, + [&](OpBuilder &OpBuilderSixtimesNested, Location locSixtimesNested) + { + auto newResValuesPtr = args[2]; + if (llvm::isa(op)) + { + auto resValue = OpBuilderSixtimesNested.create( + locSixtimesNested, rhsValuesMemRef, ValueRange{args[1]}); + auto resIndex = args[2]; + OpBuilderSixtimesNested.create( + locSixtimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderSixtimesNested.create( + locSixtimesNested, rhsCol, resColIdxsMemRef, ValueRange{resIndex}); + newResValuesPtr = OpBuilderSixtimesNested.create( + locSixtimesNested, resIndex, one); + } + auto newArg0 = args[0]; + auto newArg1 = OpBuilderSixtimesNested.create(locSixtimesNested, args[1], one); + OpBuilderSixtimesNested.create( + locSixtimesNested, ValueRange{newArg0, newArg1, newResValuesPtr}); + }, + [&](OpBuilder &OpBuilderSixtimesNested, Location locSixtimesNested) 
+ { + //perform computation on elements if their num col is equal to each other + auto lhsValue = OpBuilderSixtimesNested.create( + locSixtimesNested, lhsValuesMemRef, ValueRange{args[0]}); + auto rhsValue = OpBuilderSixtimesNested.create( + locSixtimesNested, rhsValuesMemRef, ValueRange{args[1]}); + auto resValue = binaryFunc( + OpBuilderSixtimesNested, locSixtimesNested, this->typeConverter, lhsValue, rhsValue); + auto resIndex = args[2]; + OpBuilderSixtimesNested.create( + locSixtimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderSixtimesNested.create( + locSixtimesNested, rhsCol, resColIdxsMemRef, ValueRange{resIndex}); + auto newResValuesPtr = OpBuilderSixtimesNested.create( + locSixtimesNested, resIndex, one); + auto newArg0 = OpBuilderSixtimesNested.create(locSixtimesNested, args[0], one); + auto newArg1 = OpBuilderSixtimesNested.create(locSixtimesNested, args[1], one); + OpBuilderSixtimesNested.create( + locSixtimesNested, ValueRange{newArg0, newArg1, newResValuesPtr}); + } + ); + OpBuilderFivetimesNested.create( + locFivetimesNested, + ValueRange{case2Region.getResult(0), case2Region.getResult(1), case2Region.getResult(2)}); + } + ); + auto newArg0 = newArg.getResult(0); + auto newArg1 = newArg.getResult(1); + auto newArg2 = newArg.getResult(2); + + OpBuilderFourtimesNested.create( + locFourtimesNested, ValueRange{newArg0, newArg1, newArg2}); + } + ); + // if lhs ends first, the rest will be in rhs and copy them to the res if the op is add + if (llvm::isa(op)) + { + auto rest = OpBuilderThreetimesNested.create( + locThreetimesNested, lhsEndFirst, + [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested) + { + auto rhsRest = OpBuilderFourtimesNested.create( + locFourtimesNested, whileLoop.getResult(1), rhsColIdxUpperExcl, one, ValueRange{whileLoop.getResult(2)}, + [&](OpBuilder &OpBuilderFivetimesNested, Location locFivetimesNested, Value loopIdx, ValueRange loopIterArgs) + { + auto resValue = 
OpBuilderFivetimesNested.create( + locFivetimesNested, rhsValuesMemRef, ValueRange{loopIdx}); + auto resCol = OpBuilderFivetimesNested.create( + locFivetimesNested, rhsColIdxsMemRef, ValueRange{loopIdx}); + auto resIndex = loopIterArgs[0]; + OpBuilderFivetimesNested.create( + locFivetimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderFivetimesNested.create( + locFivetimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); + auto newResValuesPtr = OpBuilderFivetimesNested.create( + locFivetimesNested, resIndex, one); + OpBuilderFivetimesNested.create(locFivetimesNested, ValueRange{newResValuesPtr}); + } + ); + OpBuilderFourtimesNested.create(locFourtimesNested, ValueRange{rhsRest.getResult(0)}); + }, + [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested) + { + auto lhsRest = OpBuilderFourtimesNested.create( + locFourtimesNested, whileLoop.getResult(0), lhsColIdxUpperExcl, one, ValueRange{whileLoop.getResult(2)}, + [&](OpBuilder &OpBuilderFivetimesNested, Location locFivetimesNested, Value loopIdx, ValueRange loopIterArgs) + { + auto resValue = OpBuilderFivetimesNested.create( + locFivetimesNested, lhsValuesMemRef, ValueRange{loopIdx}); + auto resCol = OpBuilderFivetimesNested.create( + locFivetimesNested, lhsColIdxsMemRef, ValueRange{loopIdx}); + auto resIndex = loopIterArgs[0]; + OpBuilderFivetimesNested.create( + locFivetimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderFivetimesNested.create( + locFivetimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); + auto newResValuesPtr = OpBuilderFivetimesNested.create( + locFivetimesNested, resIndex, one); + OpBuilderFivetimesNested.create(locFivetimesNested, ValueRange{newResValuesPtr}); + } + ); + OpBuilderFourtimesNested.create(locFourtimesNested, ValueRange{lhsRest.getResult(0)}); + } + ); + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{rest.getResult(0)}); + } + else //TODO: Support ops other than add + 
OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{whileLoop.getResult(2)}); + } + ); + OpBuilderTwiceNested.create(locTwiceNested, ValueRange{elseRegion.getResult(0)}); + } + ); + + OpBuilderNested.create( + locNested, + operation.getResult(0), + resRowOffsetsMemRef, + ValueRange{nextRowPtr}); + + OpBuilderNested.create(locNested, ValueRange{operation.getResult(0)}); + } + ); + + Value maxNumRowsValue = rewriter.create(loc, numRows); + Value numColsValue = rewriter.create(loc, numCols); + Value maxNumNonZerosValue = rewriter.create(loc, numCols * numRows); + + Value resCSRMatrix = convertMemRefToCSRMatrix(loc, rewriter, + resValuesMemRef, resColIdxsMemRef, resRowOffsetsMemRef, + maxNumRowsValue, numColsValue, maxNumNonZerosValue, op.getType()); + + if (!resCSRMatrix) { + llvm::errs() << "Error: resCSRMatrix is null!\n"; + } + + rewriter.replaceOp(op, resCSRMatrix); + + return mlir::success(); + } + LogicalResult matchAndRewrite(BinaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); Value lhs = adaptor.getLhs(); @@ -255,6 +846,14 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { return matchAndRewriteBroadcastScalarRhs(op, adaptor, rewriter, rhs); } + if (lhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Sparse && + rhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Dense) + return matchAndRewriteSparseDenseMat(op, adaptor, rewriter); + + if (lhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Sparse && + rhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Sparse) + return matchAndRewriteSparseSparseMat(op, adaptor, rewriter); + Type matrixElementType = lhsMatrixType.getElementType(); ssize_t lhsRows = lhsMatrixType.getNumRows(); @@ -467,7 +1066,7 @@ struct EwOpLoweringPass : public mlir::PassWrapper(); + daphne::DaphneDialect, mlir::math::MathDialect, mlir::arith::ArithDialect, mlir::scf::SCFDialect>(); } void 
runOnOperation() final; @@ -528,7 +1127,7 @@ void EwOpLoweringPass::runOnOperation() { target.addLegalDialect(); + mlir::math::MathDialect, mlir::linalg::LinalgDialect, mlir::scf::SCFDialect>(); // UnaryOps target.addDynamicallyLegalOp(); - if (matType && matType.getRepresentation() == daphne::MatrixRepresentation::Dense) { + if (matType && (matType.getRepresentation() == daphne::MatrixRepresentation::Dense || matType.getRepresentation() == daphne::MatrixRepresentation::Sparse)) { return false; } return true; @@ -567,6 +1166,22 @@ void EwOpLoweringPass::runOnOperation() { (rhsMatType && rhsMatType.getRepresentation() == daphne::MatrixRepresentation::Dense)) { return false; } + + if ((llvm::isa(rhs) || llvm::isa(rhs)) && + (lhsMatType && lhsMatType.getRepresentation() == daphne::MatrixRepresentation::Sparse)) { + return false; + } + + if ((lhsMatType && lhsMatType.getRepresentation() == daphne::MatrixRepresentation::Sparse) && + (rhsMatType && rhsMatType.getRepresentation() == daphne::MatrixRepresentation::Dense)) { + return false; + } + + if ((lhsMatType && lhsMatType.getRepresentation() == daphne::MatrixRepresentation::Sparse) && + (rhsMatType && rhsMatType.getRepresentation() == daphne::MatrixRepresentation::Sparse)) { + return false; + } + return true; }); diff --git a/src/compiler/lowering/MatMulOpLowering.cpp b/src/compiler/lowering/MatMulOpLowering.cpp index c0aef46db..c4576290c 100644 --- a/src/compiler/lowering/MatMulOpLowering.cpp +++ b/src/compiler/lowering/MatMulOpLowering.cpp @@ -300,13 +300,355 @@ class MatMulLowering : public OpConversionPattern { loops.push_back(fmaLoop); return loops; } + + template + Value binaryWithConversionFunc(OpBuilder &rewriter, Location loc, TypeConverter *typeConverter, Value lhs, Value rhs) const { + Type resType = lhs.getType(); + Value res{}; + if (llvm::isa(resType)) { + lhs = convertToSignlessInt(rewriter, loc, typeConverter, lhs, resType); + rhs = convertToSignlessInt(rewriter, loc, typeConverter, rhs, resType); + res 
= rewriter.create(loc, lhs, rhs).getResult(); + res = typeConverter->materializeTargetConversion(rewriter, loc, resType, res); + } else { + res = rewriter.create(loc, lhs, rhs).getResult(); + } + return res; + } + + template + Value cmpWithConversionFunc(OpBuilder &rewriter, Location loc, TypeConverter *typeConverter, Value lhs, Value rhs) const { + Type resType = lhs.getType(); + Value res{}; + if (llvm::isa(resType)) { + lhs = convertToSignlessInt(rewriter, loc, typeConverter, lhs, resType); + rhs = convertToSignlessInt(rewriter, loc, typeConverter, rhs, resType); + res = rewriter.create(loc, cmpIPredicate, lhs, rhs).getResult(); + res = typeConverter->materializeTargetConversion(rewriter, loc, resType, res); + } else { + res = rewriter.create(loc, cmpFPredicate, lhs, rhs).getResult(); + } + return res; + } + + Value csrIndex(OpBuilder &rewriter, Location loc, + Value valuesMemRef, Value colIdxsMemRef, Value rowOffsetsMemRef, Value row, Value col, Type type) const + { + auto zeroElem = rewriter.create(loc, rewriter.getZeroAttr(type)); + auto one = rewriter.create(loc, 1); + auto zero = rewriter.create(loc, 0); + auto rowPtr = row; + auto nextRowPtr = rewriter.create(loc, row, one); + auto colIdxLowerIncl = rewriter.create( + loc, rowOffsetsMemRef, ValueRange{rowPtr}); + auto colIdxUpperExcl = rewriter.create( + loc, rowOffsetsMemRef, ValueRange{nextRowPtr}); + + auto search = rewriter.create( + loc,TypeRange{ + rewriter.getIndexType(), + rewriter.getIndexType(), + type}, + ValueRange{colIdxLowerIncl, one, zeroElem}, + [&](OpBuilder &OpBuilderNested, Location locNested, ValueRange args) + { + auto cond1 = OpBuilderNested.create( + locNested, arith::CmpIPredicate::ult, args[0], colIdxUpperExcl); + auto cond2 = OpBuilderNested.create( + locNested, arith::CmpIPredicate::eq, args[1], one); + auto cond = OpBuilderNested.create(locNested, cond1, cond2); + OpBuilderNested.create(locNested, cond, args); + }, + [&](OpBuilder &OpBuilderNested, Location locNested, ValueRange 
args) + { + auto getCol = OpBuilderNested.create(locNested, colIdxsMemRef, ValueRange{args[0]}); + + auto cond = OpBuilderNested.create(locNested, arith::CmpIPredicate::eq, getCol, col); + // return the value of non-zero element if exists, else return a zero value + auto res = OpBuilderNested.create( + locNested, cond, + [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested) + { + auto getValue = OpBuilderTwiceNested.create( + locTwiceNested, valuesMemRef, ValueRange{args[0]}); + OpBuilderTwiceNested.create(locTwiceNested, ValueRange{zero, getValue}); + }, + [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested) + { + OpBuilderTwiceNested.create(locTwiceNested, ValueRange{one, zeroElem}); + } + ); + auto nextPtr = OpBuilderNested.create(locNested, args[0], one); + OpBuilderNested.create(locNested, ValueRange{nextPtr, res.getResult(0), res.getResult(1)}); + } + ); + + return search.getResult(2); + } + + LogicalResult matchAndRewriteSparseSparseMat(daphne::MatMulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + + mlir::daphne::MatrixType lhsMatrixType = lhs.getType().dyn_cast(); + mlir::daphne::MatrixType rhsMatrixType = rhs.getType().dyn_cast(); + + auto lhsRows = lhsMatrixType.getNumRows(); + auto lhsCols = lhsMatrixType.getNumCols(); + + auto rhsRows = rhsMatrixType.getNumRows(); + auto rhsCols = rhsMatrixType.getNumCols(); + + auto matrixElementType = lhsMatrixType.getElementType(); + auto lhsValuesMemRefType = + MemRefType::get({ShapedType::kDynamic}, matrixElementType); + auto lhsColIdxsMemRefType = + MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); + auto lhsRowOffsetsMemRefType = + MemRefType::get({lhsRows + 1}, rewriter.getIndexType()); + auto resValuesMemRefType = + MemRefType::get({ShapedType::kDynamic}, matrixElementType); + auto resColIdxsMemRefType = + MemRefType::get({ShapedType::kDynamic}, 
rewriter.getIndexType()); + auto resRowOffsetsMemRefType = + MemRefType::get({lhsRows + 1}, rewriter.getIndexType()); + + auto lhsValuesMemRef = + rewriter.create(loc, lhsValuesMemRefType, lhs); + auto lhsColIdxsMemRef = + rewriter.create(loc, lhsColIdxsMemRefType, lhs); + auto lhsRowOffsetsMemRef = + rewriter.create(loc, lhsRowOffsetsMemRefType, lhs); + auto rhsValuesMemRef = + rewriter.create(loc, lhsValuesMemRefType, rhs); + auto rhsColIdxsMemRef = + rewriter.create(loc, lhsColIdxsMemRefType, rhs); + auto rhsRowOffsetsMemRef = + rewriter.create(loc, lhsRowOffsetsMemRefType, rhs); + + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + auto zeroElement = rewriter.create(loc, rewriter.getZeroAttr(matrixElementType)); + auto numLhsRowsValue = rewriter.create(loc, lhsRows); + auto numRhsColsValue = rewriter.create(loc, rhsCols); + + auto resValuesMemRef = rewriter.create(loc, resValuesMemRefType, ValueRange{one}); + auto resColIdxsMemRef = rewriter.create(loc, resColIdxsMemRefType, ValueRange{one}); + auto resRowOffsetsMemRef = rewriter.create(loc, resRowOffsetsMemRefType); + rewriter.create(loc, zero, resRowOffsetsMemRef, ValueRange{zero}); + + auto lhsRowLoop = rewriter.create( + loc, zero, numLhsRowsValue, one, ValueRange{zero}, + [&](OpBuilder &OpBuilderNested, Location locNested, Value loopIdx, ValueRange loopIterArgs) + { + auto rowPtr = loopIdx; + auto lhsRow = loopIdx; + auto nextRowPtr = OpBuilderNested.create(locNested, loopIdx, one); + auto resValuesPtr = loopIterArgs[0]; + + auto lhsColIdxLowerIncl = OpBuilderNested.create( + locNested, lhsRowOffsetsMemRef, ValueRange{rowPtr}); + auto lhsColIdxUpperExcl = OpBuilderNested.create( + locNested, lhsRowOffsetsMemRef, ValueRange{nextRowPtr}); + + auto rhsColLoop = OpBuilderNested.create( + locNested, zero, numRhsColsValue, one, ValueRange{resValuesPtr}, + [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested, Value loopIdx, ValueRange loopIterArgs) + { + auto resValuesPtr = 
loopIterArgs[0]; + auto rhsCol = loopIdx; + auto lhsColLoop = OpBuilderTwiceNested.create( + locTwiceNested, lhsColIdxLowerIncl, lhsColIdxUpperExcl, one, ValueRange{zeroElement}, + [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested, Value loopIdx, ValueRange loopIterArgs) + { + auto acc = loopIterArgs[0]; + + auto lhsElemRow = lhsRow; + auto lhsElemCol = OpBuilderThreetimesNested.create( + locThreetimesNested, lhsColIdxsMemRef, ValueRange{loopIdx}); + auto lhsElemValue = OpBuilderThreetimesNested.create( + locThreetimesNested, lhsValuesMemRef, ValueRange{loopIdx}); + + auto rhsElemRow = lhsElemCol; + auto rhsElemCol = rhsCol; + //auto rhsElemCol = lhsElemRow; + // locate the required element in rhs corresponding to the lhs element + auto rhsElemValue = csrIndex( + OpBuilderThreetimesNested, locThreetimesNested, + rhsValuesMemRef, rhsColIdxsMemRef, rhsRowOffsetsMemRef, + rhsElemRow, rhsElemCol, matrixElementType); + + auto product = binaryWithConversionFunc( + OpBuilderThreetimesNested, locThreetimesNested, this->typeConverter, lhsElemValue, rhsElemValue); + auto newAcc = binaryWithConversionFunc( + OpBuilderThreetimesNested, locThreetimesNested, this->typeConverter, product, acc); + + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{newAcc}); + + } + ); + auto cond = OpBuilderTwiceNested.create( + locTwiceNested, arith::CmpFPredicate::OEQ, lhsColLoop.getResult(0), zeroElement); + // store the result if it is not zero + auto newPtr = OpBuilderTwiceNested.create( + locTwiceNested, cond, + [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) + { + auto newResValuesPtr = resValuesPtr; + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{newResValuesPtr}); + }, + [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) + { + auto newResValuesPtr = OpBuilderThreetimesNested.create( + locThreetimesNested, resValuesPtr, one); + OpBuilderThreetimesNested.create( + locThreetimesNested, 
lhsColLoop.getResult(0), resValuesMemRef, ValueRange{resValuesPtr}); + OpBuilderThreetimesNested.create( + locThreetimesNested, rhsCol, resColIdxsMemRef, ValueRange{resValuesPtr}); + OpBuilderThreetimesNested.create( + locThreetimesNested, ValueRange{newResValuesPtr}); + } + ); + OpBuilderTwiceNested.create(locTwiceNested, ValueRange{newPtr.getResult(0)}); + } + ); + auto newResValuesPtr = rhsColLoop.getResult(0); + + OpBuilderNested.create( + locNested, + newResValuesPtr, + resRowOffsetsMemRef, + ValueRange{nextRowPtr}); + OpBuilderNested.create(locNested, ValueRange{newResValuesPtr}); + } + ); + + Value maxNumRowsValue = rewriter.create(loc, lhsRows); + Value numColsValue = rewriter.create(loc, rhsCols); + Value maxNumNonZerosValue = rewriter.create(loc, lhsRows * rhsCols); + + Value resCSRMatrix = convertMemRefToCSRMatrix(loc, rewriter, + resValuesMemRef, resColIdxsMemRef, resRowOffsetsMemRef, + maxNumRowsValue, numColsValue, maxNumNonZerosValue, op.getType()); + + if (!resCSRMatrix) { + llvm::errs() << "Error: resCSRMatrix is null!\n"; + } + + rewriter.replaceOp(op, resCSRMatrix); + + return mlir::success(); + } + + LogicalResult matchAndRewriteSparseDenseMat(daphne::MatMulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + + mlir::daphne::MatrixType lhsMatrixType = lhs.getType().dyn_cast(); + mlir::daphne::MatrixType rhsMatrixType = rhs.getType().dyn_cast(); + + auto lhsRows = lhsMatrixType.getNumRows(); + auto lhsCols = lhsMatrixType.getNumCols(); + + auto rhsRows = rhsMatrixType.getNumRows(); + auto rhsCols = rhsMatrixType.getNumCols(); + + auto matrixElementType = lhsMatrixType.getElementType(); + + auto lhsValuesMemRefType = + MemRefType::get({ShapedType::kDynamic}, matrixElementType); + auto lhsColIdxsMemRefType = + MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); + auto lhsRowOffsetsMemRefType = + MemRefType::get({lhsRows + 
1}, rewriter.getIndexType()); + auto rhsMemRefType = mlir::MemRefType::get({rhsRows, rhsCols}, matrixElementType); + auto resMemRefType = mlir::MemRefType::get({lhsRows, rhsCols}, matrixElementType); + + auto lhsValuesMemRef = + rewriter.create(loc, lhsValuesMemRefType, lhs); + auto lhsColIdxsMemRef = + rewriter.create(loc, lhsColIdxsMemRefType, lhs); + auto lhsRowOffsetsMemRef = + rewriter.create(loc, lhsRowOffsetsMemRefType, lhs); + auto rhsMemRef = + rewriter.create(loc, rhsMemRefType, rhs); + auto resMemRef = rewriter.create(loc, resMemRefType); + + auto zeroElement = rewriter.create(loc, rewriter.getZeroAttr(matrixElementType)); + rewriter.create(loc, ValueRange{zeroElement}, ValueRange{resMemRef}); + + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + auto numLhsRowsValue = rewriter.create(loc, lhsRows); + auto numRhsColsValue = rewriter.create(loc, rhsCols); + + auto lhsRowLoop = rewriter.create( + loc, zero, numLhsRowsValue, one, ValueRange{}, + [&](OpBuilder &OpBuilderNested, Location locNested, Value loopIdx, ValueRange loopIterArgs) + { + auto rowPtr = loopIdx; + auto nextRowPtr = OpBuilderNested.create(locNested, rowPtr, one); + auto rhsColLoop = OpBuilderNested.create( + locNested, zero, numRhsColsValue, one, ValueRange{}, + [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested, Value loopIdx, ValueRange loopIterArgs) + { + auto rhsCol = loopIdx; + auto lhsColIdxsLowerIncl = OpBuilderTwiceNested.create( + locTwiceNested, lhsRowOffsetsMemRef, ValueRange{rowPtr}); + auto lhsColIdxsUpperExcl = OpBuilderTwiceNested.create( + locTwiceNested, lhsRowOffsetsMemRef, ValueRange{nextRowPtr}); + + auto resValueLoop = OpBuilderTwiceNested.create( + locTwiceNested, lhsColIdxsLowerIncl, lhsColIdxsUpperExcl, one, ValueRange{zeroElement}, + [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested, Value loopIdx, ValueRange loopIterArgs) + { + auto lhsValue = OpBuilderThreetimesNested.create( + locThreetimesNested, 
lhsValuesMemRef, ValueRange{loopIdx}); + auto rhsRow = OpBuilderThreetimesNested.create( + locThreetimesNested, lhsColIdxsMemRef, ValueRange{loopIdx}); + auto rhsValue = OpBuilderThreetimesNested.create( + locThreetimesNested, rhsMemRef, ValueRange{rhsRow, rhsCol}); + + auto resValue = binaryWithConversionFunc( + OpBuilderThreetimesNested, locThreetimesNested, this->typeConverter, lhsValue, rhsValue); + + auto accResValue = binaryWithConversionFunc( + OpBuilderThreetimesNested, locThreetimesNested, this->typeConverter, loopIterArgs[0], resValue); + + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{accResValue}); + } + ); + OpBuilderTwiceNested.create( + locTwiceNested, resValueLoop.getResult(0), resMemRef, ValueRange{rowPtr, rhsCol}); + OpBuilderTwiceNested.create(locTwiceNested, ValueRange{}); + } + ); + OpBuilderNested.create(locNested, ValueRange{}); + } + ); + + Value resDenseMatrix = convertMemRefToDenseMatrix(loc, rewriter, resMemRef, rhs.getType()); + rewriter.replaceOp(op, resDenseMatrix); + return mlir::success(); + } + LogicalResult matchAndRewrite(daphne::MatMulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); mlir::daphne::MatrixType lhsMatrixType = adaptor.getLhs().getType().dyn_cast(); mlir::daphne::MatrixType rhsMatrixType = adaptor.getRhs().getType().dyn_cast(); + if (lhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Sparse && + rhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Dense) + return matchAndRewriteSparseDenseMat(op, adaptor, rewriter); + + if (lhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Sparse && + rhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Sparse) + return matchAndRewriteSparseSparseMat(op, adaptor, rewriter); + auto lhsRows = lhsMatrixType.getNumRows(); auto lhsCols = lhsMatrixType.getNumCols(); diff --git a/src/compiler/lowering/SliceColOpLowering.cpp 
b/src/compiler/lowering/SliceColOpLowering.cpp new file mode 100644 index 000000000..5d19ddcf3 --- /dev/null +++ b/src/compiler/lowering/SliceColOpLowering.cpp @@ -0,0 +1,178 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "compiler/utils/LoweringUtils.h" +#include "ir/daphneir/Daphne.h" +#include "ir/daphneir/Passes.h" + +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include 
"mlir/IR/UseDefLists.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" + + +using namespace mlir; +using namespace std; + +//template +class SliceColOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + explicit SliceColOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) + : mlir::OpConversionPattern(typeConverter, ctx, PatternBenefit(1)) { + this->setDebugName("SliceColOpLowering"); + } + + /** + * @brief Replaces a Transpose operation with a Linalg TransposeOp if possible. + * + * @return mlir::success if Transpose has been replaced, else mlir::failure. + */ + LogicalResult matchAndRewrite(daphne::SliceColOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + daphne::MatrixType matrixType = adaptor.getSource().getType().dyn_cast(); + if (!matrixType) { + return failure(); + } + + Location loc = op->getLoc(); + + Type matrixElementType = matrixType.getElementType(); + ssize_t numRows = matrixType.getNumRows(); + ssize_t numCols = matrixType.getNumCols(); + + if (numRows < 0 || numCols < 0) { + return rewriter.notifyMatchFailure( + op, "sliceColOp codegen currently only works with matrix dimensions that are known at compile time"); + } + + Value argMemref = rewriter.create( + loc, MemRefType::get({numRows, numCols}, matrixElementType), adaptor.getSource()); + + auto lowerIncl = adaptor.getLowerIncl().getDefiningOp().getValue().dyn_cast().getSInt(); + auto upperExcl = adaptor.getUpperExcl().getDefiningOp().getValue().dyn_cast().getSInt(); + + Value resMemref = rewriter.create(loc, MemRefType::get({numRows, (upperExcl-lowerIncl)}, matrixElementType)); + + DenseI64ArrayAttr offset = rewriter.getDenseI64ArrayAttr({0, lowerIncl}); + DenseI64ArrayAttr sizes = rewriter.getDenseI64ArrayAttr({numRows, 
(upperExcl-lowerIncl)}); + DenseI64ArrayAttr strides = rewriter.getDenseI64ArrayAttr({1, 1}); + + Value selMemref = rewriter.create(loc, argMemref, offset, sizes, strides); + + SmallVector indexMaps{AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()), + AffineMap::getMultiDimIdentityMap(2, rewriter.getContext())}; + + SmallVector iterTypes{utils::IteratorType::parallel, + utils::IteratorType::parallel}; + + rewriter.create(loc, TypeRange{}, ValueRange{selMemref}, ValueRange{resMemref}, + indexMaps, iterTypes, + [&](OpBuilder &OpBuilderNested, Location locNested, ValueRange arg) { + OpBuilderNested.create(locNested, arg[0]); + }); + + Value resDenseMatrix = convertMemRefToDenseMatrix(loc, rewriter, resMemref, op.getType()); + + rewriter.replaceOp(op, resDenseMatrix); + + return success(); + } +}; + +namespace { +/** + * @brief Lowers the daphne::SliceCol operator to a MemRef SubViewOp. + * + * The selected column range is taken as a SubView of the source memref and + * copied into a newly allocated result memref via a linalg generic op.
+ */ +struct SliceColLoweringPass : public mlir::PassWrapper> { + explicit SliceColLoweringPass() {} + + StringRef getArgument() const final { return "lower-slice-col"; } + StringRef getDescription() const final { return "Lowers SliceCol operators to a Memref SubViewOp."; } + + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() final; +}; +} // end anonymous namespace + +void SliceColLoweringPass::runOnOperation() { + mlir::ConversionTarget target(getContext()); + mlir::RewritePatternSet patterns(&getContext()); + LowerToLLVMOptions llvmOptions(&getContext()); + LLVMTypeConverter typeConverter(&getContext(), llvmOptions); + + typeConverter.addConversion(convertInteger); + typeConverter.addConversion(convertFloat); + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addArgumentMaterialization(materializeCastFromIllegal); + typeConverter.addSourceMaterialization(materializeCastToIllegal); + typeConverter.addTargetMaterialization(materializeCastFromIllegal); + + target.addLegalDialect(); + + target.addDynamicallyLegalOp([](Operation *op) { + Type operand = op->getOperand(0).getType(); + daphne::MatrixType matType = operand.dyn_cast(); + if (matType && matType.getRepresentation() == daphne::MatrixRepresentation::Dense) { + return false; + } + return true; + }); + + patterns.insert(typeConverter, &getContext()); + auto module = getOperation(); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + signalPassFailure(); + } +} + +std::unique_ptr daphne::createSliceColOpLoweringPass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/src/compiler/lowering/SliceOpLowering.cpp b/src/compiler/lowering/SliceOpLowering.cpp new file mode 100644 index 000000000..57197a211 --- /dev/null +++ b/src/compiler/lowering/SliceOpLowering.cpp @@ -0,0 +1,171 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "compiler/utils/LoweringUtils.h" +#include "ir/daphneir/Daphne.h" +#include "ir/daphneir/Passes.h" + +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/UseDefLists.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" + + +using namespace mlir; +using 
namespace std; + +static constexpr size_t ROW = 0; +static constexpr size_t COL = 1; + +template +class SliceOpLowering : public OpConversionPattern { + public: + using OpAdaptor = typename OpConversionPattern::OpAdaptor; + + explicit SliceOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) + : OpConversionPattern(typeConverter, ctx, PatternBenefit(1)) { + this->setDebugName("SliceOpLowering"); + } + + /** + * @brief Replaces a Slice operation with a MemRef SubviewOp if possible. + * + * @return mlir::success if Slice has been replaced, else mlir::failure. + */ + LogicalResult matchAndRewrite(SliceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + daphne::MatrixType matrixType = adaptor.getSource().getType().template dyn_cast(); + if (!matrixType) { + return failure(); + } + + Location loc = op->getLoc(); + + Type matrixElementType = matrixType.getElementType(); + ssize_t numRows = matrixType.getNumRows(); + ssize_t numCols = matrixType.getNumCols(); + + if (numRows < 0 || numCols < 0) { + return rewriter.notifyMatchFailure( + op, "sliceOp codegen currently only works with matrix dimensions that are known at compile time"); + } + + Value argMemref = rewriter.create( + loc, MemRefType::get({numRows, numCols}, matrixElementType), adaptor.getSource()); + + auto lowerIncl = adaptor.getLowerIncl().template getDefiningOp().getValue().template dyn_cast().getSInt(); + auto upperExcl = adaptor.getUpperExcl().template getDefiningOp().getValue().template dyn_cast().getSInt(); + + DenseI64ArrayAttr offset = sliceAlongDim == ROW ? rewriter.getDenseI64ArrayAttr({lowerIncl, 0}) + : rewriter.getDenseI64ArrayAttr({0, lowerIncl}); + + DenseI64ArrayAttr sizes = sliceAlongDim == ROW ? 
rewriter.getDenseI64ArrayAttr({(upperExcl-lowerIncl), numCols}) + : rewriter.getDenseI64ArrayAttr({numRows, (upperExcl-lowerIncl)}); + + DenseI64ArrayAttr strides = rewriter.getDenseI64ArrayAttr({1, 1}); + + Value resMemref = rewriter.create(loc, argMemref, offset, sizes, strides); + + Value resDenseMatrix = convertMemRefToDenseMatrix(loc, rewriter, resMemref, op.getType()); + + rewriter.replaceOp(op, resDenseMatrix); + + return success(); + } +}; + +using SliceRowOpLowering = SliceOpLowering; +using SliceColOpLowering = SliceOpLowering; + +namespace { +/** + * @brief Lowers the daphne::Slice operator to a Memref SubviewOp. + */ +struct SliceLoweringPass : public mlir::PassWrapper> { + explicit SliceLoweringPass() {} + + StringRef getArgument() const final { return "lower-slice"; } + StringRef getDescription() const final { return "Lowers SliceRow/SliceCol operators to a Memref SubViewOp."; } + + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() final; +}; +} // end anonymous namespace + +void SliceLoweringPass::runOnOperation() { + mlir::ConversionTarget target(getContext()); + mlir::RewritePatternSet patterns(&getContext()); + LowerToLLVMOptions llvmOptions(&getContext()); + LLVMTypeConverter typeConverter(&getContext(), llvmOptions); + + typeConverter.addConversion(convertInteger); + typeConverter.addConversion(convertFloat); + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addArgumentMaterialization(materializeCastFromIllegal); + typeConverter.addSourceMaterialization(materializeCastToIllegal); + typeConverter.addTargetMaterialization(materializeCastFromIllegal); + + target.addLegalDialect(); + + target.addDynamicallyLegalOp([](Operation *op) { + Type operand = op->getOperand(0).getType(); + daphne::MatrixType matType = operand.dyn_cast(); + if (matType && matType.getRepresentation() == daphne::MatrixRepresentation::Dense) { + return false; + } + return true; + 
}); + + patterns.insert(typeConverter, &getContext()); + auto module = getOperation(); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + signalPassFailure(); + } +} + +std::unique_ptr daphne::createSliceOpLoweringPass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/src/compiler/lowering/SliceRowOpLowering.cpp b/src/compiler/lowering/SliceRowOpLowering.cpp new file mode 100644 index 000000000..b831c43b0 --- /dev/null +++ b/src/compiler/lowering/SliceRowOpLowering.cpp @@ -0,0 +1,179 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "compiler/utils/LoweringUtils.h" +#include "ir/daphneir/Daphne.h" +#include "ir/daphneir/Passes.h" + +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/UseDefLists.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" + + +using namespace mlir; +using namespace std; + +//template +class SliceRowOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + explicit SliceRowOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) + : mlir::OpConversionPattern(typeConverter, ctx, PatternBenefit(1)) { + this->setDebugName("SliceRowOpLowering"); + } + + /** + * @brief Replaces a Transpose operation with a Linalg TransposeOp if possible. + * + * @return mlir::success if Transpose has been replaced, else mlir::failure. 
+ */ + LogicalResult matchAndRewrite(daphne::SliceRowOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + daphne::MatrixType matrixType = adaptor.getSource().getType().dyn_cast(); + if (!matrixType) { + return failure(); + } + + Location loc = op->getLoc(); + + Type matrixElementType = matrixType.getElementType(); + ssize_t numRows = matrixType.getNumRows(); + ssize_t numCols = matrixType.getNumCols(); + + if (numRows < 0 || numCols < 0) { + return rewriter.notifyMatchFailure( + op, "sliceRowOp codegen currently only works with matrix dimensions that are known at compile time"); + } + + Value argMemref = rewriter.create( + loc, MemRefType::get({numRows, numCols}, matrixElementType), adaptor.getSource()); + + auto lowerIncl = adaptor.getLowerIncl().getDefiningOp().getValue().dyn_cast().getSInt(); + auto upperExcl = adaptor.getUpperExcl().getDefiningOp().getValue().dyn_cast().getSInt(); + + // Value resMemref = rewriter.create(loc, MemRefType::get({(upperExcl-lowerIncl), numCols}, matrixElementType)); + + DenseI64ArrayAttr offset = rewriter.getDenseI64ArrayAttr({lowerIncl, 0}); + DenseI64ArrayAttr sizes = rewriter.getDenseI64ArrayAttr({(upperExcl-lowerIncl), numCols}); + DenseI64ArrayAttr strides = rewriter.getDenseI64ArrayAttr({1, 1}); + + // Value selMemref = rewriter.create(loc, argMemref, offset, sizes, strides); + Value resMemref = rewriter.create(loc, argMemref, offset, sizes, strides); + + // SmallVector indexMaps{AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()), + // AffineMap::getMultiDimIdentityMap(2, rewriter.getContext())}; + + // SmallVector iterTypes{utils::IteratorType::parallel, + // utils::IteratorType::parallel}; + + // rewriter.create(loc, TypeRange{}, ValueRange{selMemref}, ValueRange{resMemref}, + // indexMaps, iterTypes, + // [&](OpBuilder &OpBuilderNested, Location locNested, ValueRange arg) { + // OpBuilderNested.create(locNested, arg[0]); + // }); + + Value resDenseMatrix = 
convertMemRefToDenseMatrix(loc, rewriter, resMemref, op.getType()); + + rewriter.replaceOp(op, resDenseMatrix); + + return success(); + } +}; + +namespace { +/** + * @brief Lowers the daphne::SliceRow operator to a MemRef SubViewOp. + * + * The selected row range is materialized as a SubView of the source memref + * and converted back to a DenseMatrix. + */ +struct SliceRowLoweringPass : public mlir::PassWrapper> { + explicit SliceRowLoweringPass() {} + + StringRef getArgument() const final { return "lower-slice-row"; } + StringRef getDescription() const final { return "Lowers SliceRow operators to a Memref SubViewOp."; } + + void getDependentDialects(mlir::DialectRegistry &registry) const override { + registry.insert(); + } + void runOnOperation() final; +}; +} // end anonymous namespace + +void SliceRowLoweringPass::runOnOperation() { + mlir::ConversionTarget target(getContext()); + mlir::RewritePatternSet patterns(&getContext()); + LowerToLLVMOptions llvmOptions(&getContext()); + LLVMTypeConverter typeConverter(&getContext(), llvmOptions); + + typeConverter.addConversion(convertInteger); + typeConverter.addConversion(convertFloat); + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addArgumentMaterialization(materializeCastFromIllegal); + typeConverter.addSourceMaterialization(materializeCastToIllegal); + typeConverter.addTargetMaterialization(materializeCastFromIllegal); + + target.addLegalDialect(); + + target.addDynamicallyLegalOp([](Operation *op) { + Type operand = op->getOperand(0).getType(); + daphne::MatrixType matType = operand.dyn_cast(); + if (matType && matType.getRepresentation() == daphne::MatrixRepresentation::Dense) { + return false; + } + return true; + }); + + patterns.insert(typeConverter, &getContext()); + auto module = getOperation(); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + signalPassFailure(); + } +} + +std::unique_ptr daphne::createSliceRowOpLoweringPass() { + return
std::make_unique(); +} \ No newline at end of file diff --git a/src/compiler/utils/LoweringUtils.cpp b/src/compiler/utils/LoweringUtils.cpp index 74fba85ab..089c58ac6 100644 --- a/src/compiler/utils/LoweringUtils.cpp +++ b/src/compiler/utils/LoweringUtils.cpp @@ -85,6 +85,24 @@ mlir::Value convertMemRefToDenseMatrix(mlir::Location loc, mlir::ConversionPatte strides[0], strides[1]); } +mlir::Value convertMemRefToCSRMatrix(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, + mlir::Value valuesMemRef, mlir::Value colIdxsMemRef, mlir::Value rowOffsetsMemRef, + mlir::Value maxNumRows, mlir::Value numCols, mlir::Value maxNumNonZeros, mlir::Type type) +{ + //auto extractStridedMetadataOp = rewriter.create(loc, memRef); + // aligned ptr (memref.data) + mlir::Value alignedValuesPtr = rewriter.create + (loc, valuesMemRef); + mlir::Value alignedColIdxsPtr = rewriter.create + (loc, colIdxsMemRef); + mlir::Value alignedRowOffsetsPtr = rewriter.create + (loc, rowOffsetsMemRef); + + return rewriter.create(loc, type, + alignedValuesPtr, alignedColIdxsPtr, alignedRowOffsetsPtr, + maxNumRows, numCols, maxNumNonZeros); +} + mlir::Type convertFloat(mlir::FloatType floatType) { return mlir::IntegerType::get(floatType.getContext(), floatType.getIntOrFloatBitWidth()); } diff --git a/src/compiler/utils/LoweringUtils.h b/src/compiler/utils/LoweringUtils.h index a723492f9..80137a264 100644 --- a/src/compiler/utils/LoweringUtils.h +++ b/src/compiler/utils/LoweringUtils.h @@ -43,6 +43,10 @@ void affineFillMemRef(mlir::Value value, mlir::ConversionPatternRewriter &rewrit mlir::Value convertMemRefToDenseMatrix(mlir::Location, mlir::ConversionPatternRewriter &, mlir::Value memRef, mlir::Type); +mlir::Value convertMemRefToCSRMatrix(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, + mlir::Value valuesMemRef, mlir::Value colIdxsMemRef, mlir::Value rowOffsetsMemRef, + mlir::Value maxNumRows, mlir::Value numCols, mlir::Value maxNumNonZeros, mlir::Type type); + llvm::Optional 
materializeCastFromIllegal(mlir::OpBuilder &builder, mlir::Type type, mlir::ValueRange inputs, mlir::Location loc); diff --git a/src/ir/daphneir/DaphneOps.td b/src/ir/daphneir/DaphneOps.td index b2d91310b..94b6557af 100644 --- a/src/ir/daphneir/DaphneOps.td +++ b/src/ir/daphneir/DaphneOps.td @@ -83,6 +83,15 @@ def Daphne_ConvertMemRefToDenseMatrix : Daphne_Op<"convertMemRefToDenseMatrix"> let results = (outs MatrixOrU:$res); } +def Daphne_ConvertMemRefToCSRMatrix : Daphne_Op<"convertMemRefToCSRMatrix"> { + let summary = "Return a CSRMatrix."; + let description = [{ Constructs a CSRMatrix from the base pointers of its values, colIdxs, and rowOffsets buffers, together with maxNumRows, numCols, and maxNumNonZeros. }]; + + /* let arguments = (ins AnyMemRef:$arg); */ + let arguments = (ins Size:$baseValues, Size:$baseColIdxs, Size:$baseRowOffsets, Size:$maxNumRows, Size:$numCols, Size:$maxNumNonZeros); + let results = (outs MatrixOrU:$res); +} + def Daphne_ConvertDenseMatrixToMemRef : Daphne_Op<"convertDenseMatrixToMemRef", [Pure]> { let summary = "Given a DenseMatrix, return a StridedMemRefType."; let description = [{ Constructs a StridedMemRefType with rank 2 from a DenseMatrix* with already allocated memory. 
}]; diff --git a/src/ir/daphneir/Passes.h b/src/ir/daphneir/Passes.h index 464aeb577..9f20e7ed1 100644 --- a/src/ir/daphneir/Passes.h +++ b/src/ir/daphneir/Passes.h @@ -71,6 +71,11 @@ std::unique_ptr createSelectMatrixRepresentationsPass(const DaphneUserConf std::unique_ptr createSpecializeGenericFunctionsPass(const DaphneUserConfig &cfg); std::unique_ptr createTransposeOpLoweringPass(); std::unique_ptr createVectorizeComputationsPass(); + +std::unique_ptr createSliceRowOpLoweringPass(); +std::unique_ptr createSliceColOpLoweringPass(); +std::unique_ptr createSliceOpLoweringPass(); + #ifdef USE_CUDA std::unique_ptr createMarkCUDAOpsPass(const DaphneUserConfig &cfg); #endif diff --git a/src/ir/daphneir/Passes.td b/src/ir/daphneir/Passes.td index 0f5ab3144..9aad66829 100644 --- a/src/ir/daphneir/Passes.td +++ b/src/ir/daphneir/Passes.td @@ -259,5 +259,16 @@ def LowerEwOpPass: Pass<"lower-ew", "::mlir::func::FuncOp"> { def SparsityExploitationPass: Pass<"lower-sparse-exploit", "::mlir::func::FuncOp"> { let constructor = "mlir::daphne::createSparsityExploitationPass()"; } +def SliceRowOpLoweringPass: Pass<"lower-slice-row", "::mlir::func::FuncOp"> { + let constructor = "mlir::daphne::createSliceRowOpLoweringPass()"; +} + +def SliceColOpLoweringPass: Pass<"lower-slice-col", "::mlir::func::FuncOp"> { + let constructor = "mlir::daphne::createSliceColOpLoweringPass()"; +} + +def SliceOpLoweringPass: Pass<"lower-slice", "::mlir::func::FuncOp"> { + let constructor = "mlir::daphne::createSliceOpLoweringPass()"; +} #endif // SRC_IR_DAPHNEIR_PASSES_TD diff --git a/src/runtime/local/datastructures/CSRMatrix.h b/src/runtime/local/datastructures/CSRMatrix.h index 0298622bb..9c00e3499 100644 --- a/src/runtime/local/datastructures/CSRMatrix.h +++ b/src/runtime/local/datastructures/CSRMatrix.h @@ -129,6 +129,16 @@ template class CSRMatrix : public Matrix { rowOffsets = std::shared_ptr(src->rowOffsets, src->rowOffsets.get() + rowLowerIncl); } + CSRMatrix(size_t maxNumRows, size_t 
numCols, size_t maxNumNonZeros, + std::shared_ptr &values, std::shared_ptr &colIdxs, std::shared_ptr &rowOffsets) + : Matrix(maxNumRows, numCols), numRowsAllocated(maxNumRows), isRowAllocatedBefore(false), + maxNumNonZeros(maxNumNonZeros), lastAppendedRowIdx(0) { + + this->values = values; + this->colIdxs = colIdxs; + this->rowOffsets = rowOffsets; + } + virtual ~CSRMatrix() { // nothing to do } diff --git a/src/runtime/local/kernels/ConvertMemRefToCSRMatrix.h b/src/runtime/local/kernels/ConvertMemRefToCSRMatrix.h new file mode 100644 index 000000000..4cce7d5d5 --- /dev/null +++ b/src/runtime/local/kernels/ConvertMemRefToCSRMatrix.h @@ -0,0 +1,37 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "runtime/local/context/DaphneContext.h" +#include "runtime/local/datastructures/CSRMatrix.h" + +template +inline void convertMemRefToCSRMatrix(CSRMatrix *&result, + size_t baseValuesPtr, size_t baseColIdxsPtr, size_t baseRowOffsetsPtr, + size_t maxNumRows, size_t numCols, size_t maxNumNonZeros, DCTX(ctx)) +{ + auto no_op_deleter_1 = [](T *) {}; + auto no_op_deleter_2 = [](size_t *) {}; + T *valuePtr = reinterpret_cast(baseValuesPtr); + size_t *colIdxsPtr = reinterpret_cast(baseColIdxsPtr); + size_t *rowOffsetsPtr = reinterpret_cast(baseRowOffsetsPtr); + std::shared_ptr ptrValues(valuePtr, no_op_deleter_1); + std::shared_ptr ptrColIdxs(colIdxsPtr, no_op_deleter_2); + std::shared_ptr ptrRowOffsets(rowOffsetsPtr, no_op_deleter_2); + result = DataObjectFactory::create>( + maxNumRows, numCols, maxNumNonZeros, ptrValues, ptrColIdxs, ptrRowOffsets); +} diff --git a/src/runtime/local/kernels/EwBinaryMat.h b/src/runtime/local/kernels/EwBinaryMat.h index 5330c71d7..80934d48d 100644 --- a/src/runtime/local/kernels/EwBinaryMat.h +++ b/src/runtime/local/kernels/EwBinaryMat.h @@ -341,3 +341,67 @@ template struct EwBinaryMat, Matrix, Matrix> { res->finishAppend(); } }; + +// ---------------------------------------------------------------------------- +// DenseMatrix <- CSRMatrix, DenseMatrix +// ---------------------------------------------------------------------------- + +template struct EwBinaryMat, CSRMatrix, DenseMatrix> { + static void apply(BinaryOpCode opCode, DenseMatrix *&res, const CSRMatrix *lhs, const DenseMatrix *rhs, + DCTX(ctx)) { + const size_t numRows = lhs->getNumRows(); + const size_t numCols = lhs->getNumCols(); + // TODO: lhs broadcast + // if ((numRows != rhs->getNumRows() && rhs->getNumRows() != 1) || + // (numCols != rhs->getNumCols() && rhs->getNumCols() != 1)) + // throw std::runtime_error("EwBinaryMat(CSR) - lhs and rhs must have " + // "the same dimensions (or broadcast)"); + if (numRows != rhs->getNumRows() || numCols 
!= rhs->getNumCols()) + throw std::runtime_error("EwBinaryMat(CSR) - lhs and rhs must have " + "the same dimensions (or broadcast)"); + + size_t maxNnz; + switch (opCode) { + case BinaryOpCode::ADD: // merge + maxNnz = lhs->getNumNonZeros(); + break; + default: + throw std::runtime_error("EwBinaryMat(CSR) - unknown BinaryOpCode"); + } + + if (res == nullptr) + res = DataObjectFactory::create>(numRows, numCols, false); + + auto *valuesRes = res->getValues(); + auto *valuesRhs = rhs->getValues(); + + for (size_t r = 0; r < numRows; r++) + for (size_t c = 0; c < numCols; c++) + valuesRes[r * numCols + c] = valuesRhs[r * numCols + c]; + + EwBinaryScaFuncPtr func = getEwBinaryScaFuncPtr(opCode); + + switch (opCode) { + case BinaryOpCode::ADD: { // merge non-zero cells + for (size_t rowIdx = 0; rowIdx < numRows; rowIdx++) { + size_t nnzRowLhs = lhs->getNumNonZeros(rowIdx); + if (nnzRowLhs) { + // merge within row + const VT *valuesRowLhs = lhs->getValues(rowIdx); + const size_t *colIdxsRowLhs = lhs->getColIdxs(rowIdx); + for (size_t posLhs = 0; posLhs < nnzRowLhs; ++posLhs) { + auto rhsCol = colIdxsRowLhs[posLhs]; + valuesRes[rhsCol] = func(valuesRes[rhsCol], valuesRowLhs[posLhs], ctx); + } + } + valuesRes += res->getRowSkip(); + } + break; + } + default: + throw std::runtime_error("EwBinaryMat(CSR) - unknown BinaryOpCode"); + } + + // TODO Update number of non-zeros in result in the end. 
+ } +}; diff --git a/src/runtime/local/kernels/EwBinaryObjSca.h b/src/runtime/local/kernels/EwBinaryObjSca.h index 62d704ded..f83837195 100644 --- a/src/runtime/local/kernels/EwBinaryObjSca.h +++ b/src/runtime/local/kernels/EwBinaryObjSca.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -154,4 +155,45 @@ template struct EwBinaryObjSca { } }; +// ---------------------------------------------------------------------------- +// CSRMatrix <- CSRMatrix, scalar +// ---------------------------------------------------------------------------- + +template +struct EwBinaryObjSca, CSRMatrix, VTRhs> { + static void apply(BinaryOpCode opCode, CSRMatrix *&res, const CSRMatrix *lhs, VTRhs rhs, + DCTX(ctx)) { + + if (opCode != BinaryOpCode::MUL) + throw std::runtime_error("EwBinaryObjSca::apply: only support MUL for CSR Matrix"); + + const size_t numRows = lhs->getNumRows(); + const size_t numCols = lhs->getNumCols(); + const size_t maxNumNonZeros = lhs->getMaxNumNonZeros(); + const size_t numNonZeros = lhs->getNumNonZeros(); + + if (res == nullptr) + res = DataObjectFactory::create>(numRows, numCols, maxNumNonZeros, false); + + const VTLhs *valuesLhs = lhs->getValues(); + const size_t *colIdxsLhs = lhs->getColIdxs(); + const size_t *rowOffsetsLhs = lhs->getRowOffsets(); + VTRes *valuesRes = res->getValues(); + size_t *colIdxsRes = res->getColIdxs(); + size_t *rowOffsetsRes = res->getRowOffsets(); + + for (size_t i = 0; i < numNonZeros; i++) + colIdxsRes[i] = colIdxsLhs[i]; + + for (size_t i = 0; i < numRows + 1; i++) + rowOffsetsRes[i] = rowOffsetsLhs[i]; + + EwBinaryScaFuncPtr func = getEwBinaryScaFuncPtr(opCode); + + for (size_t i = 0; i < numNonZeros; i++) + valuesRes[i] = func(valuesLhs[i], rhs, ctx); + + } +}; + #endif // SRC_RUNTIME_LOCAL_KERNELS_EWBINARYOBJSCA_H diff --git a/src/runtime/local/kernels/EwUnaryMat.h b/src/runtime/local/kernels/EwUnaryMat.h index 1524587cb..ac34789b6 100644 --- a/src/runtime/local/kernels/EwUnaryMat.h +++ 
b/src/runtime/local/kernels/EwUnaryMat.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -94,4 +95,40 @@ template struct EwUnaryMat, Matrix> { } }; +// ---------------------------------------------------------------------------- +// CSRMatrix <- CSRMatrix +// ---------------------------------------------------------------------------- + +template struct EwUnaryMat, CSRMatrix> { + static void apply(UnaryOpCode opCode, CSRMatrix *&res, const CSRMatrix *arg, DCTX(ctx)) { + const size_t numRows = arg->getNumRows(); + const size_t numCols = arg->getNumCols(); + const size_t maxNumNonZeros = arg->getMaxNumNonZeros(); + const size_t numNonZeros = arg->getNumNonZeros(); + + if (res == nullptr) + res = DataObjectFactory::create>(numRows, numCols, maxNumNonZeros, false); + + const VT *valuesArg = arg->getValues(); + const size_t *colIdxsArg = arg->getColIdxs(); + const size_t *rowOffsetsArg = arg->getRowOffsets(); + + VT *valuesRes = res->getValues(); + size_t *colIdxsRes = res->getColIdxs(); + size_t *rowOffsetsRes = res->getRowOffsets(); + + for (size_t i = 0; i < numNonZeros; i++) + colIdxsRes[i] = colIdxsArg[i]; + + for (size_t i = 0; i < numRows + 1; i++) + rowOffsetsRes[i] = rowOffsetsArg[i]; + + EwUnaryScaFuncPtr func = getEwUnaryScaFuncPtr(opCode); + + for (size_t i = 0; i < numNonZeros; i++) + valuesRes[i] = func(valuesArg[i], ctx); + + } +}; + #endif // SRC_RUNTIME_LOCAL_KERNELS_EWUNARYMAT_H diff --git a/src/runtime/local/kernels/kernels.json b/src/runtime/local/kernels/kernels.json index 19597655e..90abc8cd8 100644 --- a/src/runtime/local/kernels/kernels.json +++ b/src/runtime/local/kernels/kernels.json @@ -1511,6 +1511,60 @@ ["double"] ] }, + { + "kernelTemplate": { + "header": "ConvertMemRefToCSRMatrix.h", + "opName": "convertMemRefToCSRMatrix", + "returnType": "void", + "templateParams": [ + { + "name": "VT", + "isDataType": false + } + ], + "runtimeParams": [ + { + "type": "CSRMatrix *&", + "name": "result" + }, + { + 
"type": "size_t", + "name": "baseValuesPtr" + }, + { + "type": "size_t", + "name": "baseColIdxsPtr" + }, + { + "type": "size_t", + "name": "baseRowOffsetsPtr" + }, + { + "type": "size_t", + "name": "maxNumRows" + }, + { + "type": "size_t", + "name": "numCols" + }, + { + "type": "size_t", + "name": "maxNumNonZeros" + } + ] + }, + "instantiations": [ + ["int64_t"], + ["int32_t"], + ["int8_t"], + ["size_t"], + ["uint64_t"], + ["uint32_t"], + ["uint8_t"], + ["float"], + ["double"] + ] + }, { "kernelTemplate": { "header": "ConvertDenseMatrixToMemRef.h", @@ -2068,6 +2122,16 @@ ["DenseMatrix", "int64_t"], ["DenseMatrix", "std::string"], ["DenseMatrix", "std::string"] + ], + [ + ["DenseMatrix", "double"], + ["CSRMatrix", "double"], + ["DenseMatrix", "double"] + ], + [ + ["DenseMatrix", "float"], + ["CSRMatrix", "float"], + ["DenseMatrix", "float"] ] ], "opCodes": [ @@ -2164,6 +2228,36 @@ ["DenseMatrix", "uint64_t"], ["DenseMatrix", "uint64_t"], "uint64_t" + ], + [ + ["CSRMatrix", "float"], + ["CSRMatrix", "float"], + "float" + ], + [ + ["CSRMatrix", "double"], + ["CSRMatrix", "double"], + "double" + ], + [ + ["CSRMatrix", "int64_t"], + ["CSRMatrix", "int64_t"], + "int64_t" + ], + [ + ["CSRMatrix", "int32_t"], + ["CSRMatrix", "int32_t"], + "int32_t" + ], + [ + ["CSRMatrix", "uint32_t"], + ["CSRMatrix", "uint32_t"], + "uint32_t" + ], + [ + ["CSRMatrix", "uint64_t"], + ["CSRMatrix", "uint64_t"], + "uint64_t" ] ], "opCodes": [ @@ -4537,6 +4631,18 @@ [ ["DenseMatrix", "std::string"], ["DenseMatrix", "std::string"] + ], + [ + ["CSRMatrix", "double"], + ["CSRMatrix", "double"] + ], + [ + ["CSRMatrix", "float"], + ["CSRMatrix", "float"] + ], + [ + ["CSRMatrix", "int64_t"], + ["CSRMatrix", "int64_t"] ] ], "opCodes": [ diff --git a/test/api/cli/codegen/SparsityLAOpsTest.cpp b/test/api/cli/codegen/SparsityLAOpsTest.cpp new file mode 100644 index 000000000..cc4d56796 --- /dev/null +++ b/test/api/cli/codegen/SparsityLAOpsTest.cpp @@ -0,0 +1,97 @@ +/* + * Copyright 2025 The DAPHNE 
Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include + +const std::string dirPath = "test/api/cli/codegen/"; + +TEST_CASE("ewUnary_abs, sparse", TAG_CODEGEN) { + std::string result = "CSRMatrix(5x5, double)\n" + "0 0 0 0 1\n" + "0 0 0 0 0\n" + "0 0 0 0 0\n" + "0 1 1 0 0\n" + "0 0 0 0 0\n"; + + compareDaphneToStr(result, dirPath + "ewunary_abs_sparse.daphne", "--mlir-codegen", "--no-obj-ref-mgnt", "--select-matrix-repr"); +} + +TEST_CASE("ewBinary_add, sparse", TAG_CODEGEN) { + std::string result = "DenseMatrix(8x8, double)\n" + "1 1 1 1 1 1 1 1\n" + "1 1 1 1 1 1 1 1\n" + "1 1 1 1 1 1 1 1\n" + "1 1 1 1 1 1 1 1\n" + "1 1 1 1 1 1 1 1\n" + "1 1 1 1 1 1 1 1\n" + "1 1 1 2 1 1 1 1\n" + "1 1 1 1 1 1 1 1\n" + "CSRMatrix(8x8, double)\n" + "0 0 0 0 0 0 0 0\n" + "0 0 0 0 0 0 0 0\n" + "0 0 0 0 0 0 0 0\n" + "0 0 0 0 0 0 0 0\n" + "0 0 0 0 0 0 0 1\n" + "0 0 0 0 0 0 0 0\n" + "0 0 0 1 0 0 0 0\n" + "0 0 0 0 0 0 0 0\n"; + + compareDaphneToStr(result, dirPath + "ewbinary_add_sparse.daphne", "--mlir-codegen", "--no-obj-ref-mgnt", "--select-matrix-repr"); +} + +TEST_CASE("ewBinary_mul, sparse", TAG_CODEGEN) { + std::string result = "CSRMatrix(5x5, double)\n" + "0 0 0 0 2\n" + "0 0 0 0 0\n" + "0 0 0 0 0\n" + "0 2 2 0 0\n" + "0 0 0 0 0\n" + "CSRMatrix(5x5, double)\n" + "0 0 0 0 1\n" + "0 0 0 0 0\n" + "0 0 0 0 0\n" + "0 1 1 0 0\n" + "0 0 0 0 0\n" + "CSRMatrix(5x5, double)\n" + "0 0 0 0 1\n" + "0 0 0 0 0\n" + "0 0 0 0 0\n" + "0 0 
0 0 0 0 0\n" + "0 0 0 0 0\n"; + + compareDaphneToStr(result, dirPath + "ewbinary_mul_sparse.daphne", "--mlir-codegen", "--no-obj-ref-mgnt", "--select-matrix-repr"); +} + +TEST_CASE("matmul, sparse-dense", TAG_CODEGEN) { + std::string result = "DenseMatrix(5x5, double)\n" + "1 1 1 1 1\n" + "0 0 0 0 0\n" + "0 0 0 0 0\n" + "2 2 2 2 2\n" + "0 0 0 0 0\n" + "CSRMatrix(5x5, double)\n" + "0 0 0 0 0\n" + "0 0 0 0 0\n" + "0 0 0 0 0\n" + "1 0 0 1 0\n" + "0 0 0 0 0\n"; + + compareDaphneToStr(result, dirPath + "matmul_sparse.daphne", "--mlir-codegen", "--no-obj-ref-mgnt", "--select-matrix-repr"); +} diff --git a/test/api/cli/codegen/ewbinary_add_sparse.daphne b/test/api/cli/codegen/ewbinary_add_sparse.daphne new file mode 100644 index 000000000..ec1c7885f --- /dev/null +++ b/test/api/cli/codegen/ewbinary_add_sparse.daphne @@ -0,0 +1,6 @@ +W = rand(8, 8, 1.0, 1.0, 0.01, 1); +V = rand(8, 8, 1.0, 1.0, 0.01, 2); +X = rand(8, 8, 1.0, 1.0, 1, 3); + +print(W + X); // sparse + dense +print(W + V); // sparse + sparse \ No newline at end of file diff --git a/test/api/cli/codegen/ewbinary_mul_sparse.daphne b/test/api/cli/codegen/ewbinary_mul_sparse.daphne new file mode 100644 index 000000000..f4d369953 --- /dev/null +++ b/test/api/cli/codegen/ewbinary_mul_sparse.daphne @@ -0,0 +1,7 @@ +W = rand(5, 5, 1.0, 1.0, 0.1, 1); +V = rand(5, 5, 1.0, 1.0, 0.1, 2); +X = rand(5, 5, 1.0, 1.0, 1, 3); + +print(W * 2); // sparse * scalar +print(W * X); // sparse * dense +print(W * V); // sparse * sparse \ No newline at end of file diff --git a/test/api/cli/codegen/ewunary_abs_sparse.daphne b/test/api/cli/codegen/ewunary_abs_sparse.daphne new file mode 100644 index 000000000..0b11e40d7 --- /dev/null +++ b/test/api/cli/codegen/ewunary_abs_sparse.daphne @@ -0,0 +1,2 @@ +W = rand(5, 5, -1.0, -1.0, 0.1, 1); +print(abs(W)); \ No newline at end of file diff --git a/test/api/cli/codegen/matmul_sparse.daphne b/test/api/cli/codegen/matmul_sparse.daphne new file mode 100644 index 000000000..3f3b19374 --- /dev/null +++ 
b/test/api/cli/codegen/matmul_sparse.daphne @@ -0,0 +1,6 @@ +W = rand(5, 5, 1.0, 1.0, 0.1, 1); +V = rand(5, 5, 1.0, 1.0, 0.1, 2); +X = rand(5, 5, 1.0, 1.0, 1, 3); + +print(W @ X); // sparse @ dense +print(W @ V); // sparse @ sparse diff --git a/test/api/cli/io/matrix_full.csv b/test/api/cli/io/matrix_full.csv new file mode 100644 index 000000000..f5f1b5258 --- /dev/null +++ b/test/api/cli/io/matrix_full.csv @@ -0,0 +1,3 @@ +1,2 +3,4 +5,6 diff --git a/test/api/cli/io/matrix_full.csv.meta b/test/api/cli/io/matrix_full.csv.meta new file mode 100644 index 000000000..24eb98984 --- /dev/null +++ b/test/api/cli/io/matrix_full.csv.meta @@ -0,0 +1 @@ +{"numCols":2,"numRows":3,"valueType":"si64"} \ No newline at end of file diff --git a/test/api/cli/io/matrix_view.csv b/test/api/cli/io/matrix_view.csv new file mode 100644 index 000000000..e2ba1efb1 --- /dev/null +++ b/test/api/cli/io/matrix_view.csv @@ -0,0 +1,3 @@ +2 +4 +6 diff --git a/test/api/cli/io/matrix_view.csv.meta b/test/api/cli/io/matrix_view.csv.meta new file mode 100644 index 000000000..0b9e3769a --- /dev/null +++ b/test/api/cli/io/matrix_view.csv.meta @@ -0,0 +1 @@ +{"numCols":1,"numRows":3,"valueType":"si64"} \ No newline at end of file