From 04fc40e11b9d0b903d9f59e5d51e506476783eb2 Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Wed, 4 Dec 2024 22:37:38 +0100 Subject: [PATCH 01/33] build daphne-opt --- containers/entrypoint-interactive.sh | 6 +++--- containers/run-docker-example.sh | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/containers/entrypoint-interactive.sh b/containers/entrypoint-interactive.sh index 5d63841ca..960ea916b 100755 --- a/containers/entrypoint-interactive.sh +++ b/containers/entrypoint-interactive.sh @@ -13,10 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +USER=ubuntu /usr/sbin/sshd -f /etc/ssh/sshd_config -/usr/sbin/groupadd -g "$GID" dockerusers -/usr/sbin/useradd -c 'Docker Container User' -u $UID -g "$GID" -G sudo -m -s /bin/bash -d /home/"$USER" "$USER" +# /usr/sbin/groupadd -g "$GID" dockerusers +# /usr/sbin/useradd -c 'Docker Container User' -u $UID -g "$GID" -G sudo -m -s /bin/bash -d /home/"$USER" "$USER" printf "${USER} ALL=(ALL:ALL) NOPASSWD:ALL" | sudo EDITOR="tee -a" visudo #>> /dev/null mkdir -p /home/"$USER"/.ssh chmod 700 /home/"$USER"/.ssh diff --git a/containers/run-docker-example.sh b/containers/run-docker-example.sh index 633e11d33..f9ee48117 100755 --- a/containers/run-docker-example.sh +++ b/containers/run-docker-example.sh @@ -68,6 +68,7 @@ fi $USE_SUDO docker run $DEBUG_FLAGS $DEVICE_FLAGS -it --rm --hostname daphne-container -w $DAPHNE_ROOT_CONTAINER \ -v "$DAPHNE_ROOT:$DAPHNE_ROOT_CONTAINER" -e GID=$GID -e TERM=screen-256color -e PATH -e LD_LIBRARY_PATH \ -e USER=$USERNAME -e UID=$UID \ + --entrypoint /daphne/containers/entrypoint-interactive.sh \ "$DOCKER_IMAGE:$DOCKER_TAG" $command # move this up to above the DOCKER_IMAGE line to override the entrypoint: From 03cb04ada1ecb99a1bde1010d25e018ee63c2fbd Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Sun, 
22 Dec 2024 00:11:10 +0100 Subject: [PATCH 02/33] add SliceRowOpLowering --- scripts/examples/slice.daph | 17 ++ src/compiler/execution/DaphneIrExecutor.cpp | 3 + src/compiler/lowering/CMakeLists.txt | 2 + src/compiler/lowering/SliceRowOpLowering.cpp | 178 +++++++++++++++++++ src/ir/daphneir/Passes.h | 3 + src/ir/daphneir/Passes.td | 3 + 6 files changed, 206 insertions(+) create mode 100644 scripts/examples/slice.daph create mode 100644 src/compiler/lowering/SliceRowOpLowering.cpp diff --git a/scripts/examples/slice.daph b/scripts/examples/slice.daph new file mode 100644 index 000000000..0770b5913 --- /dev/null +++ b/scripts/examples/slice.daph @@ -0,0 +1,17 @@ +X = [1, 2, 3, 4, 5, 6, 7, 8, 9](3, 3); +//Y = transpose(X); +i=2;j=2;k=2; +n=1*3; + +outSliceRow = X[:2, :]; +print(outSliceRow); + + +//outExtractRow = X[ [0, 2], :]; +//outSliceCol = X[:, :2]; +//outExtractCol = X[:, [0, 2] ]; + +//print(Y); +//print(outExtractRow); +//print(outSliceCol); +//print(outExtractCol); \ No newline at end of file diff --git a/src/compiler/execution/DaphneIrExecutor.cpp b/src/compiler/execution/DaphneIrExecutor.cpp index e8f69c5d5..aca3c9260 100644 --- a/src/compiler/execution/DaphneIrExecutor.cpp +++ b/src/compiler/execution/DaphneIrExecutor.cpp @@ -263,6 +263,9 @@ void DaphneIrExecutor::buildCodegenPipeline(mlir::PassManager &pm) { pm.addPass(mlir::daphne::createAggDimOpLoweringPass()); pm.addPass(mlir::daphne::createMapOpLoweringPass()); pm.addPass(mlir::daphne::createTransposeOpLoweringPass()); + + pm.addPass(mlir::daphne::createSliceRowOpLoweringPass()); + pm.addPass(mlir::createInlinerPass()); pm.addNestedPass(mlir::createLoopFusionPass()); diff --git a/src/compiler/lowering/CMakeLists.txt b/src/compiler/lowering/CMakeLists.txt index 18939f92e..680f7d798 100644 --- a/src/compiler/lowering/CMakeLists.txt +++ b/src/compiler/lowering/CMakeLists.txt @@ -35,6 +35,8 @@ add_mlir_dialect_library(MLIRDaphneTransforms AggDimOpLowering.cpp TransposeOpLowering.cpp + 
SliceRowOpLowering.cpp + DEPENDS MLIRDaphneOpsIncGen MLIRDaphneTransformsIncGen diff --git a/src/compiler/lowering/SliceRowOpLowering.cpp b/src/compiler/lowering/SliceRowOpLowering.cpp new file mode 100644 index 000000000..60c70a8e6 --- /dev/null +++ b/src/compiler/lowering/SliceRowOpLowering.cpp @@ -0,0 +1,178 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "compiler/utils/LoweringUtils.h" +#include "ir/daphneir/Daphne.h" +#include "ir/daphneir/Passes.h" + +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include 
"mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/UseDefLists.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" + + +using namespace mlir; +using namespace std; + +//template +class SliceRowOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + explicit SliceRowOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) + : mlir::OpConversionPattern(typeConverter, ctx, PatternBenefit(1)) { + this->setDebugName("SliceRowOpLowering"); + } + + /** + * @brief Replaces a Transpose operation with a Linalg TransposeOp if possible. + * + * @return mlir::success if Transpose has been replaced, else mlir::failure. + */ + LogicalResult matchAndRewrite(daphne::SliceRowOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + daphne::MatrixType matrixType = adaptor.getSource().getType().dyn_cast(); + if (!matrixType) { + return failure(); + } + + Location loc = op->getLoc(); + + Type matrixElementType = matrixType.getElementType(); + ssize_t numRows = matrixType.getNumRows(); + ssize_t numCols = matrixType.getNumCols(); + + if (numRows < 0 || numCols < 0) { + return rewriter.notifyMatchFailure( + op, "sliceRowOp codegen currently only works with matrix dimensions that are known at compile time"); + } + + Value argMemref = rewriter.create( + loc, MemRefType::get({numRows, numCols}, matrixElementType), adaptor.getSource()); + + auto lowerIncl = adaptor.getLowerIncl().getDefiningOp().getValue().dyn_cast().getSInt(); + auto upperExcl = adaptor.getUpperExcl().getDefiningOp().getValue().dyn_cast().getSInt(); + + Value resMemref = rewriter.create(loc, MemRefType::get({(upperExcl-lowerIncl), numCols}, matrixElementType)); + + DenseI64ArrayAttr offset = 
rewriter.getDenseI64ArrayAttr({lowerIncl, 0}); + DenseI64ArrayAttr sizes = rewriter.getDenseI64ArrayAttr({(upperExcl-lowerIncl), numCols}); + DenseI64ArrayAttr strides = rewriter.getDenseI64ArrayAttr({1, 1}); + + Value selMemref = rewriter.create(loc, argMemref, offset, sizes, strides); + + SmallVector indexMaps{AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()), + AffineMap::getMultiDimIdentityMap(2, rewriter.getContext())}; + + SmallVector iterTypes{utils::IteratorType::parallel, + utils::IteratorType::parallel}; + + rewriter.create(loc, TypeRange{}, ValueRange{selMemref}, ValueRange{resMemref}, + indexMaps, iterTypes, + [&](OpBuilder &OpBuilderNested, Location locNested, ValueRange arg) { + OpBuilderNested.create(locNested, arg[0]); + }); + + Value resDenseMatrix = convertMemRefToDenseMatrix(loc, rewriter, resMemref, op.getType()); + + rewriter.replaceOp(op, resDenseMatrix); + + return success(); + } +}; + +namespace { +/** + * @brief Lowers the daphne::Transpose operator to a Linalg TransposeOp. + * + * This rewrite may enable loop fusion on the affine loops TransposeOp is + * lowered to by running the loop fusion pass. 
+ */ +struct SliceRowLoweringPass : public mlir::PassWrapper> { + explicit SliceRowLoweringPass() {} + + StringRef getArgument() const final { return "lower-slice-row"; } + StringRef getDescription() const final { return "Lowers SliceRow operators to a Memref SubViewOp."; } + + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() final; +}; +} // end anonymous namespace + +void SliceRowLoweringPass::runOnOperation() { + mlir::ConversionTarget target(getContext()); + mlir::RewritePatternSet patterns(&getContext()); + LowerToLLVMOptions llvmOptions(&getContext()); + LLVMTypeConverter typeConverter(&getContext(), llvmOptions); + + typeConverter.addConversion(convertInteger); + typeConverter.addConversion(convertFloat); + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addArgumentMaterialization(materializeCastFromIllegal); + typeConverter.addSourceMaterialization(materializeCastToIllegal); + typeConverter.addTargetMaterialization(materializeCastFromIllegal); + + target.addLegalDialect(); + + target.addDynamicallyLegalOp([](Operation *op) { + Type operand = op->getOperand(0).getType(); + daphne::MatrixType matType = operand.dyn_cast(); + if (matType && matType.getRepresentation() == daphne::MatrixRepresentation::Dense) { + return false; + } + return true; + }); + + patterns.insert(typeConverter, &getContext()); + auto module = getOperation(); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + signalPassFailure(); + } +} + +std::unique_ptr daphne::createSliceRowOpLoweringPass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/src/ir/daphneir/Passes.h b/src/ir/daphneir/Passes.h index d90b7a9a7..f991e44c0 100644 --- a/src/ir/daphneir/Passes.h +++ b/src/ir/daphneir/Passes.h @@ -70,6 +70,9 @@ std::unique_ptr createSelectMatrixRepresentationsPass(const DaphneUserConf std::unique_ptr createSpecializeGenericFunctionsPass(const 
DaphneUserConfig &cfg); std::unique_ptr createTransposeOpLoweringPass(); std::unique_ptr createVectorizeComputationsPass(); + +std::unique_ptr createSliceRowOpLoweringPass(); + #ifdef USE_CUDA std::unique_ptr createMarkCUDAOpsPass(const DaphneUserConfig &cfg); #endif diff --git a/src/ir/daphneir/Passes.td b/src/ir/daphneir/Passes.td index 46fe5295a..7f1d32aa1 100644 --- a/src/ir/daphneir/Passes.td +++ b/src/ir/daphneir/Passes.td @@ -256,5 +256,8 @@ def LowerEwOpPass: Pass<"lower-ew", "::mlir::func::FuncOp"> { let constructor = "mlir::daphne::createEwOpLoweringPass()"; } +def SliceRowOpLoweringPass: Pass<"lower-slice-row", "::mlir::func::FuncOp"> { + let constructor = "mlir::daphne::createSliceRowOpLoweringPass()"; +} #endif // SRC_IR_DAPHNEIR_PASSES_TD From 537f655cb251c4807e9de411cebce9569a641098 Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Sun, 22 Dec 2024 23:06:44 +0100 Subject: [PATCH 03/33] add SliceColOpLowering --- src/compiler/execution/DaphneIrExecutor.cpp | 1 + src/compiler/lowering/CMakeLists.txt | 1 + src/compiler/lowering/SliceColOpLowering.cpp | 178 +++++++++++++++++++ src/ir/daphneir/Passes.h | 1 + src/ir/daphneir/Passes.td | 4 + 5 files changed, 185 insertions(+) create mode 100644 src/compiler/lowering/SliceColOpLowering.cpp diff --git a/src/compiler/execution/DaphneIrExecutor.cpp b/src/compiler/execution/DaphneIrExecutor.cpp index aca3c9260..604917862 100644 --- a/src/compiler/execution/DaphneIrExecutor.cpp +++ b/src/compiler/execution/DaphneIrExecutor.cpp @@ -265,6 +265,7 @@ void DaphneIrExecutor::buildCodegenPipeline(mlir::PassManager &pm) { pm.addPass(mlir::daphne::createTransposeOpLoweringPass()); pm.addPass(mlir::daphne::createSliceRowOpLoweringPass()); + pm.addPass(mlir::daphne::createSliceColOpLoweringPass()); pm.addPass(mlir::createInlinerPass()); diff --git a/src/compiler/lowering/CMakeLists.txt b/src/compiler/lowering/CMakeLists.txt index 680f7d798..4ac8f8f38 100644 --- a/src/compiler/lowering/CMakeLists.txt +++ 
b/src/compiler/lowering/CMakeLists.txt @@ -36,6 +36,7 @@ add_mlir_dialect_library(MLIRDaphneTransforms TransposeOpLowering.cpp SliceRowOpLowering.cpp + SliceColOpLowering.cpp DEPENDS MLIRDaphneOpsIncGen diff --git a/src/compiler/lowering/SliceColOpLowering.cpp b/src/compiler/lowering/SliceColOpLowering.cpp new file mode 100644 index 000000000..5d19ddcf3 --- /dev/null +++ b/src/compiler/lowering/SliceColOpLowering.cpp @@ -0,0 +1,178 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "compiler/utils/LoweringUtils.h" +#include "ir/daphneir/Daphne.h" +#include "ir/daphneir/Passes.h" + +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/UseDefLists.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" + + +using namespace mlir; +using namespace std; + +//template +class SliceColOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + explicit SliceColOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) + : mlir::OpConversionPattern(typeConverter, ctx, PatternBenefit(1)) { + this->setDebugName("SliceColOpLowering"); + } + + /** + * @brief Replaces a Transpose operation with a Linalg TransposeOp if possible. + * + * @return mlir::success if Transpose has been replaced, else mlir::failure. 
+ */ + LogicalResult matchAndRewrite(daphne::SliceColOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + daphne::MatrixType matrixType = adaptor.getSource().getType().dyn_cast(); + if (!matrixType) { + return failure(); + } + + Location loc = op->getLoc(); + + Type matrixElementType = matrixType.getElementType(); + ssize_t numRows = matrixType.getNumRows(); + ssize_t numCols = matrixType.getNumCols(); + + if (numRows < 0 || numCols < 0) { + return rewriter.notifyMatchFailure( + op, "sliceColOp codegen currently only works with matrix dimensions that are known at compile time"); + } + + Value argMemref = rewriter.create( + loc, MemRefType::get({numRows, numCols}, matrixElementType), adaptor.getSource()); + + auto lowerIncl = adaptor.getLowerIncl().getDefiningOp().getValue().dyn_cast().getSInt(); + auto upperExcl = adaptor.getUpperExcl().getDefiningOp().getValue().dyn_cast().getSInt(); + + Value resMemref = rewriter.create(loc, MemRefType::get({numRows, (upperExcl-lowerIncl)}, matrixElementType)); + + DenseI64ArrayAttr offset = rewriter.getDenseI64ArrayAttr({0, lowerIncl}); + DenseI64ArrayAttr sizes = rewriter.getDenseI64ArrayAttr({numRows, (upperExcl-lowerIncl)}); + DenseI64ArrayAttr strides = rewriter.getDenseI64ArrayAttr({1, 1}); + + Value selMemref = rewriter.create(loc, argMemref, offset, sizes, strides); + + SmallVector indexMaps{AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()), + AffineMap::getMultiDimIdentityMap(2, rewriter.getContext())}; + + SmallVector iterTypes{utils::IteratorType::parallel, + utils::IteratorType::parallel}; + + rewriter.create(loc, TypeRange{}, ValueRange{selMemref}, ValueRange{resMemref}, + indexMaps, iterTypes, + [&](OpBuilder &OpBuilderNested, Location locNested, ValueRange arg) { + OpBuilderNested.create(locNested, arg[0]); + }); + + Value resDenseMatrix = convertMemRefToDenseMatrix(loc, rewriter, resMemref, op.getType()); + + rewriter.replaceOp(op, resDenseMatrix); + + return success(); + 
} +}; + +namespace { +/** + * @brief Lowers the daphne::Transpose operator to a Linalg TransposeOp. + * + * This rewrite may enable loop fusion on the affine loops TransposeOp is + * lowered to by running the loop fusion pass. + */ +struct SliceColLoweringPass : public mlir::PassWrapper> { + explicit SliceColLoweringPass() {} + + StringRef getArgument() const final { return "lower-slice-col"; } + StringRef getDescription() const final { return "Lowers SliceCol operators to a Memref SubViewOp."; } + + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() final; +}; +} // end anonymous namespace + +void SliceColLoweringPass::runOnOperation() { + mlir::ConversionTarget target(getContext()); + mlir::RewritePatternSet patterns(&getContext()); + LowerToLLVMOptions llvmOptions(&getContext()); + LLVMTypeConverter typeConverter(&getContext(), llvmOptions); + + typeConverter.addConversion(convertInteger); + typeConverter.addConversion(convertFloat); + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addArgumentMaterialization(materializeCastFromIllegal); + typeConverter.addSourceMaterialization(materializeCastToIllegal); + typeConverter.addTargetMaterialization(materializeCastFromIllegal); + + target.addLegalDialect(); + + target.addDynamicallyLegalOp([](Operation *op) { + Type operand = op->getOperand(0).getType(); + daphne::MatrixType matType = operand.dyn_cast(); + if (matType && matType.getRepresentation() == daphne::MatrixRepresentation::Dense) { + return false; + } + return true; + }); + + patterns.insert(typeConverter, &getContext()); + auto module = getOperation(); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + signalPassFailure(); + } +} + +std::unique_ptr daphne::createSliceColOpLoweringPass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/src/ir/daphneir/Passes.h b/src/ir/daphneir/Passes.h index 
f991e44c0..848f3e4e6 100644 --- a/src/ir/daphneir/Passes.h +++ b/src/ir/daphneir/Passes.h @@ -72,6 +72,7 @@ std::unique_ptr createTransposeOpLoweringPass(); std::unique_ptr createVectorizeComputationsPass(); std::unique_ptr createSliceRowOpLoweringPass(); +std::unique_ptr createSliceColOpLoweringPass(); #ifdef USE_CUDA std::unique_ptr createMarkCUDAOpsPass(const DaphneUserConfig &cfg); diff --git a/src/ir/daphneir/Passes.td b/src/ir/daphneir/Passes.td index 7f1d32aa1..56dcc1d1d 100644 --- a/src/ir/daphneir/Passes.td +++ b/src/ir/daphneir/Passes.td @@ -260,4 +260,8 @@ def SliceRowOpLoweringPass: Pass<"lower-slice-row", "::mlir::func::FuncOp"> { let constructor = "mlir::daphne::createSliceRowOpLoweringPass()"; } +def SliceColOpLoweringPass: Pass<"lower-slice-col", "::mlir::func::FuncOp"> { + let constructor = "mlir::daphne::createSliceColOpLoweringPass()"; +} + #endif // SRC_IR_DAPHNEIR_PASSES_TD From 2621af4621be68bbef628bbd42f8405a00e723aa Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Sat, 28 Dec 2024 01:12:19 +0100 Subject: [PATCH 04/33] exclude linalg.generic op, add memreftollvm pass --- src/compiler/execution/DaphneIrExecutor.cpp | 6 +++++ src/compiler/lowering/SliceRowOpLowering.cpp | 23 ++++++++++---------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/src/compiler/execution/DaphneIrExecutor.cpp b/src/compiler/execution/DaphneIrExecutor.cpp index 604917862..f732f675d 100644 --- a/src/compiler/execution/DaphneIrExecutor.cpp +++ b/src/compiler/execution/DaphneIrExecutor.cpp @@ -32,6 +32,7 @@ #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -186,6 +187,9 @@ bool 
DaphneIrExecutor::runPasses(mlir::ModuleOp module) { if (userConfig_.explain_kernels) pm.addPass(mlir::daphne::createPrintIRPass("IR after kernel lowering:")); + pm.addPass(mlir::memref::createExpandStridedMetadataPass()); + pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); + pm.addPass(mlir::createConvertSCFToCFPass()); pm.addNestedPass(mlir::LLVM::createRequestCWrappersPass()); pm.addPass(mlir::daphne::createLowerToLLVMPass(userConfig_)); @@ -297,6 +301,8 @@ void DaphneIrExecutor::buildCodegenPipeline(mlir::PassManager &pm) { mlir::LowerVectorToLLVMOptions lowerVectorToLLVMOptions; pm.addPass(mlir::createConvertVectorToLLVMPass(lowerVectorToLLVMOptions)); + + if (userConfig_.explain_mlir_codegen) pm.addPass(mlir::daphne::createPrintIRPass("IR after codegen pipeline")); } diff --git a/src/compiler/lowering/SliceRowOpLowering.cpp b/src/compiler/lowering/SliceRowOpLowering.cpp index 60c70a8e6..b831c43b0 100644 --- a/src/compiler/lowering/SliceRowOpLowering.cpp +++ b/src/compiler/lowering/SliceRowOpLowering.cpp @@ -94,25 +94,26 @@ class SliceRowOpLowering : public OpConversionPattern { auto lowerIncl = adaptor.getLowerIncl().getDefiningOp().getValue().dyn_cast().getSInt(); auto upperExcl = adaptor.getUpperExcl().getDefiningOp().getValue().dyn_cast().getSInt(); - Value resMemref = rewriter.create(loc, MemRefType::get({(upperExcl-lowerIncl), numCols}, matrixElementType)); + // Value resMemref = rewriter.create(loc, MemRefType::get({(upperExcl-lowerIncl), numCols}, matrixElementType)); DenseI64ArrayAttr offset = rewriter.getDenseI64ArrayAttr({lowerIncl, 0}); DenseI64ArrayAttr sizes = rewriter.getDenseI64ArrayAttr({(upperExcl-lowerIncl), numCols}); DenseI64ArrayAttr strides = rewriter.getDenseI64ArrayAttr({1, 1}); - Value selMemref = rewriter.create(loc, argMemref, offset, sizes, strides); + // Value selMemref = rewriter.create(loc, argMemref, offset, sizes, strides); + Value resMemref = rewriter.create(loc, argMemref, offset, sizes, strides); - SmallVector 
indexMaps{AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()), - AffineMap::getMultiDimIdentityMap(2, rewriter.getContext())}; + // SmallVector indexMaps{AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()), + // AffineMap::getMultiDimIdentityMap(2, rewriter.getContext())}; - SmallVector iterTypes{utils::IteratorType::parallel, - utils::IteratorType::parallel}; + // SmallVector iterTypes{utils::IteratorType::parallel, + // utils::IteratorType::parallel}; - rewriter.create(loc, TypeRange{}, ValueRange{selMemref}, ValueRange{resMemref}, - indexMaps, iterTypes, - [&](OpBuilder &OpBuilderNested, Location locNested, ValueRange arg) { - OpBuilderNested.create(locNested, arg[0]); - }); + // rewriter.create(loc, TypeRange{}, ValueRange{selMemref}, ValueRange{resMemref}, + // indexMaps, iterTypes, + // [&](OpBuilder &OpBuilderNested, Location locNested, ValueRange arg) { + // OpBuilderNested.create(locNested, arg[0]); + // }); Value resDenseMatrix = convertMemRefToDenseMatrix(loc, rewriter, resMemref, op.getType()); From 3a9eafeca82e27fa98bcd624e1df619766428265 Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Sat, 28 Dec 2024 01:18:16 +0100 Subject: [PATCH 05/33] Revert "build daphne-opt" This reverts commit 04fc40e11b9d0b903d9f59e5d51e506476783eb2. --- containers/entrypoint-interactive.sh | 6 +++--- containers/run-docker-example.sh | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/containers/entrypoint-interactive.sh b/containers/entrypoint-interactive.sh index 960ea916b..5d63841ca 100755 --- a/containers/entrypoint-interactive.sh +++ b/containers/entrypoint-interactive.sh @@ -13,10 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-USER=ubuntu + /usr/sbin/sshd -f /etc/ssh/sshd_config -# /usr/sbin/groupadd -g "$GID" dockerusers -# /usr/sbin/useradd -c 'Docker Container User' -u $UID -g "$GID" -G sudo -m -s /bin/bash -d /home/"$USER" "$USER" +/usr/sbin/groupadd -g "$GID" dockerusers +/usr/sbin/useradd -c 'Docker Container User' -u $UID -g "$GID" -G sudo -m -s /bin/bash -d /home/"$USER" "$USER" printf "${USER} ALL=(ALL:ALL) NOPASSWD:ALL" | sudo EDITOR="tee -a" visudo #>> /dev/null mkdir -p /home/"$USER"/.ssh chmod 700 /home/"$USER"/.ssh diff --git a/containers/run-docker-example.sh b/containers/run-docker-example.sh index f9ee48117..633e11d33 100755 --- a/containers/run-docker-example.sh +++ b/containers/run-docker-example.sh @@ -68,7 +68,6 @@ fi $USE_SUDO docker run $DEBUG_FLAGS $DEVICE_FLAGS -it --rm --hostname daphne-container -w $DAPHNE_ROOT_CONTAINER \ -v "$DAPHNE_ROOT:$DAPHNE_ROOT_CONTAINER" -e GID=$GID -e TERM=screen-256color -e PATH -e LD_LIBRARY_PATH \ -e USER=$USERNAME -e UID=$UID \ - --entrypoint /daphne/containers/entrypoint-interactive.sh \ "$DOCKER_IMAGE:$DOCKER_TAG" $command # move this up to above the DOCKER_IMAGE line to override the entrypoint: From 55d8a76624abf2f3b28eb8f0571e5ff5f4a311af Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Sat, 28 Dec 2024 01:25:01 +0100 Subject: [PATCH 06/33] remove improper test case --- scripts/examples/slice.daph | 17 ----------------- 1 file changed, 17 deletions(-) delete mode 100644 scripts/examples/slice.daph diff --git a/scripts/examples/slice.daph b/scripts/examples/slice.daph deleted file mode 100644 index 0770b5913..000000000 --- a/scripts/examples/slice.daph +++ /dev/null @@ -1,17 +0,0 @@ -X = [1, 2, 3, 4, 5, 6, 7, 8, 9](3, 3); -//Y = transpose(X); -i=2;j=2;k=2; -n=1*3; - -outSliceRow = X[:2, :]; -print(outSliceRow); - - -//outExtractRow = X[ [0, 2], :]; -//outSliceCol = X[:, :2]; -//outExtractCol = X[:, [0, 2] ]; - -//print(Y); -//print(outExtractRow); -//print(outSliceCol); -//print(outExtractCol); \ No newline at 
end of file From 204a5110142c9f69cbe3b13627f6eafa8a82c162 Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Mon, 20 Jan 2025 11:26:27 +0100 Subject: [PATCH 07/33] combine slice row and column --- src/compiler/execution/DaphneIrExecutor.cpp | 5 +- src/compiler/lowering/CMakeLists.txt | 1 + src/compiler/lowering/SliceOpLowering.cpp | 186 ++++++++++++++++++++ src/ir/daphneir/Passes.h | 1 + src/ir/daphneir/Passes.td | 4 + 5 files changed, 195 insertions(+), 2 deletions(-) create mode 100644 src/compiler/lowering/SliceOpLowering.cpp diff --git a/src/compiler/execution/DaphneIrExecutor.cpp b/src/compiler/execution/DaphneIrExecutor.cpp index f732f675d..34c0589b9 100644 --- a/src/compiler/execution/DaphneIrExecutor.cpp +++ b/src/compiler/execution/DaphneIrExecutor.cpp @@ -268,8 +268,9 @@ void DaphneIrExecutor::buildCodegenPipeline(mlir::PassManager &pm) { pm.addPass(mlir::daphne::createMapOpLoweringPass()); pm.addPass(mlir::daphne::createTransposeOpLoweringPass()); - pm.addPass(mlir::daphne::createSliceRowOpLoweringPass()); - pm.addPass(mlir::daphne::createSliceColOpLoweringPass()); + //pm.addPass(mlir::daphne::createSliceRowOpLoweringPass()); + //pm.addPass(mlir::daphne::createSliceColOpLoweringPass()); + pm.addPass(mlir::daphne::createSliceOpLoweringPass()); pm.addPass(mlir::createInlinerPass()); diff --git a/src/compiler/lowering/CMakeLists.txt b/src/compiler/lowering/CMakeLists.txt index 4ac8f8f38..36e8685b5 100644 --- a/src/compiler/lowering/CMakeLists.txt +++ b/src/compiler/lowering/CMakeLists.txt @@ -37,6 +37,7 @@ add_mlir_dialect_library(MLIRDaphneTransforms SliceRowOpLowering.cpp SliceColOpLowering.cpp + SliceOpLowering.cpp DEPENDS MLIRDaphneOpsIncGen diff --git a/src/compiler/lowering/SliceOpLowering.cpp b/src/compiler/lowering/SliceOpLowering.cpp new file mode 100644 index 000000000..e92fd2b3c --- /dev/null +++ b/src/compiler/lowering/SliceOpLowering.cpp @@ -0,0 +1,186 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the 
Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "compiler/utils/LoweringUtils.h" +#include "ir/daphneir/Daphne.h" +#include "ir/daphneir/Passes.h" + +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/UseDefLists.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" + + +using namespace mlir; +using 
namespace std; + +static constexpr size_t ROW = 0; +static constexpr size_t COL = 1; + +template +class SliceOpLowering : public OpConversionPattern { + public: + //using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename OpConversionPattern::OpAdaptor; + + explicit SliceOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) + : OpConversionPattern(typeConverter, ctx, PatternBenefit(1)) { + this->setDebugName("SliceOpLowering"); + } + + /** + * @brief Replaces a Transpose operation with a Linalg TransposeOp if possible. + * + * @return mlir::success if Transpose has been replaced, else mlir::failure. + */ + LogicalResult matchAndRewrite(SliceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + daphne::MatrixType matrixType = adaptor.getSource().getType().template dyn_cast(); + if (!matrixType) { + return failure(); + } + + Location loc = op->getLoc(); + + Type matrixElementType = matrixType.getElementType(); + ssize_t numRows = matrixType.getNumRows(); + ssize_t numCols = matrixType.getNumCols(); + + if (numRows < 0 || numCols < 0) { + return rewriter.notifyMatchFailure( + op, "sliceOp codegen currently only works with matrix dimensions that are known at compile time"); + } + + Value argMemref = rewriter.create( + loc, MemRefType::get({numRows, numCols}, matrixElementType), adaptor.getSource()); + + auto lowerIncl = adaptor.getLowerIncl().template getDefiningOp().getValue().template dyn_cast().getSInt(); + auto upperExcl = adaptor.getUpperExcl().template getDefiningOp().getValue().template dyn_cast().getSInt(); + + DenseI64ArrayAttr offset = sliceAlongDim == ROW ? rewriter.getDenseI64ArrayAttr({lowerIncl, 0}) + : rewriter.getDenseI64ArrayAttr({0, lowerIncl}); + + DenseI64ArrayAttr sizes = sliceAlongDim == ROW ? 
rewriter.getDenseI64ArrayAttr({(upperExcl-lowerIncl), numCols}) + : rewriter.getDenseI64ArrayAttr({numRows, (upperExcl-lowerIncl)}); + + // if (sliceAlongDim == ROW) + // { + // DenseI64ArrayAttr offset = rewriter.getDenseI64ArrayAttr({lowerIncl, 0}); + // DenseI64ArrayAttr sizes = rewriter.getDenseI64ArrayAttr({(upperExcl-lowerIncl), numCols}); + // } + // else + // { + // DenseI64ArrayAttr offset = rewriter.getDenseI64ArrayAttr({0, lowerIncl}); + // DenseI64ArrayAttr sizes = rewriter.getDenseI64ArrayAttr({numRows, (upperExcl-lowerIncl)}); + // } + + DenseI64ArrayAttr strides = rewriter.getDenseI64ArrayAttr({1, 1}); + + Value resMemref = rewriter.create(loc, argMemref, offset, sizes, strides); + + Value resDenseMatrix = convertMemRefToDenseMatrix(loc, rewriter, resMemref, op.getType()); + + rewriter.replaceOp(op, resDenseMatrix); + + return success(); + } +}; + +using SliceRowOpLowering = SliceOpLowering; +using SliceColOpLowering = SliceOpLowering; + +namespace { +/** + * @brief Lowers the daphne::Transpose operator to a Linalg TransposeOp. + * + * This rewrite may enable loop fusion on the affine loops TransposeOp is + * lowered to by running the loop fusion pass. 
+ */ +struct SliceLoweringPass : public mlir::PassWrapper> { + explicit SliceLoweringPass() {} + + StringRef getArgument() const final { return "lower-slice"; } + StringRef getDescription() const final { return "Lowers SliceRow/SliceCol operators to a Memref SubViewOp."; } + + void getDependentDialects(mlir::DialectRegistry &registry) const override { + registry.insert(); + } + void runOnOperation() final; +}; +} // end anonymous namespace + +void SliceLoweringPass::runOnOperation() { + mlir::ConversionTarget target(getContext()); + mlir::RewritePatternSet patterns(&getContext()); + LowerToLLVMOptions llvmOptions(&getContext()); + LLVMTypeConverter typeConverter(&getContext(), llvmOptions); + + typeConverter.addConversion(convertInteger); + typeConverter.addConversion(convertFloat); + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addArgumentMaterialization(materializeCastFromIllegal); + typeConverter.addSourceMaterialization(materializeCastToIllegal); + typeConverter.addTargetMaterialization(materializeCastFromIllegal); + + target.addLegalDialect(); + + target.addDynamicallyLegalOp([](Operation *op) { + Type operand = op->getOperand(0).getType(); + daphne::MatrixType matType = operand.dyn_cast(); + if (matType && matType.getRepresentation() == daphne::MatrixRepresentation::Dense) { + return false; + } + return true; + }); + + patterns.insert(typeConverter, &getContext()); + auto module = getOperation(); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + signalPassFailure(); + } +} + +std::unique_ptr daphne::createSliceOpLoweringPass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/src/ir/daphneir/Passes.h b/src/ir/daphneir/Passes.h index 848f3e4e6..1820c3ca8 100644 --- a/src/ir/daphneir/Passes.h +++ b/src/ir/daphneir/Passes.h @@ -73,6 +73,7 @@ std::unique_ptr createVectorizeComputationsPass(); std::unique_ptr createSliceRowOpLoweringPass(); std::unique_ptr
createSliceColOpLoweringPass(); +std::unique_ptr createSliceOpLoweringPass(); #ifdef USE_CUDA std::unique_ptr createMarkCUDAOpsPass(const DaphneUserConfig &cfg); diff --git a/src/ir/daphneir/Passes.td b/src/ir/daphneir/Passes.td index 56dcc1d1d..d1676dd95 100644 --- a/src/ir/daphneir/Passes.td +++ b/src/ir/daphneir/Passes.td @@ -264,4 +264,8 @@ def SliceColOpLoweringPass: Pass<"lower-slice-col", "::mlir::func::FuncOp"> { let constructor = "mlir::daphne::createSliceColOpLoweringPass()"; } +def SliceOpLoweringPass: Pass<"lower-slice", "::mlir::func::FuncOp"> { + let constructor = "mlir::daphne::createSliceOpLoweringPass()"; +} + #endif // SRC_IR_DAPHNEIR_PASSES_TD From f218f306b49d9a32c92276453c3e9c7345597c7f Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Mon, 20 Jan 2025 22:57:53 +0100 Subject: [PATCH 08/33] add ExtractOpLowering --- src/compiler/lowering/ExtractOpLowering.cpp | 205 ++++++++++++++++++++ 1 file changed, 205 insertions(+) create mode 100644 src/compiler/lowering/ExtractOpLowering.cpp diff --git a/src/compiler/lowering/ExtractOpLowering.cpp b/src/compiler/lowering/ExtractOpLowering.cpp new file mode 100644 index 000000000..3d6dd5977 --- /dev/null +++ b/src/compiler/lowering/ExtractOpLowering.cpp @@ -0,0 +1,205 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "compiler/utils/LoweringUtils.h" +#include "ir/daphneir/Daphne.h" +#include "ir/daphneir/Passes.h" + +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/UseDefLists.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" + + +using namespace mlir; +using namespace std; + +static constexpr size_t ROW = 0; +static constexpr size_t COL = 1; + +template +class ExtractOpLowering : public OpConversionPattern { + public: + //using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename OpConversionPattern::OpAdaptor; + + explicit ExtractOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) + : OpConversionPattern(typeConverter, ctx, PatternBenefit(1)) { + this->setDebugName("ExtractOpLowering"); + } + + /** + * @brief Replaces a Transpose operation with a Linalg 
TransposeOp if possible. + * + * @return mlir::success if Transpose has been replaced, else mlir::failure. + */ + LogicalResult matchAndRewrite(ExtractOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + daphne::MatrixType matrixType = adaptor.getSource().getType().template dyn_cast(); + if (!matrixType) { + return failure(); + } + + Location loc = op->getLoc(); + + Type matrixElementType = matrixType.getElementType(); + ssize_t numRows = matrixType.getNumRows(); + ssize_t numCols = matrixType.getNumCols(); + + if (numRows < 0 || numCols < 0) { + return rewriter.notifyMatchFailure( + op, "extractOp codegen currently only works with matrix dimensions that are known at compile time"); + } + + Value argMemref = rewriter.create( + loc, MemRefType::get({numRows, numCols}, matrixElementType), adaptor.getSource()); + + + daphne::MatrixType selectionType = adaptor.getSelectedRows().getType().template dyn_cast(); + if (!matrixType) { + return failure(); + } + + + + Type selectionElementType = selectionType.getElementType(); + ssize_t numSelectedRows = selectionType.getNumRows(); + + Value selectionMemref = rewriter.create( + loc, MemRefType::get({numSelectedRows, 1}, matrixElementType), adaptor.getSelectedRows()); + + Value resMemref = rewriter.create(loc, MemRefType::get({numSelectedRows, numCols}, matrixElementType)); + + for (ssize_t i = 0; i < numSelectedRows; i++) + { + + Value des = rewriter.create(loc, resMemref, + rewriter.getDenseI64ArrayAttr({i, 0}), + rewriter.getDenseI64ArrayAttr({1, numCols}), + rewriter.getDenseI64ArrayAttr({1, 1})); + + Value select = rewriter.create(loc, selectionMemref, + ValueRange{rewriter.create(loc, i), + rewriter.create(loc, 0)}); + + Value zero = rewriter.create(loc, rewriter.getI32IntegerAttr(0)); + + ValueRange offsets = {select, zero}; + ValueRange sizes = {rewriter.create(loc, 1), + rewriter.create(loc, numCols)}; + ValueRange strides = {rewriter.create(loc, 1), + rewriter.create(loc, 1)}; + + 
Value src = rewriter.create(loc, argMemref, offsets, sizes, strides); + + rewriter.create(loc, src, des); + + } + + + Value resDenseMatrix = convertMemRefToDenseMatrix(loc, rewriter, resMemref, op.getType()); + + rewriter.replaceOp(op, resDenseMatrix); + + return success(); + } +}; + +using ExtractRowOpLowering = ExtractOpLowering; +//using ExtractColOpLowering = ExtractOpLowering; + +namespace { +/** + * @brief Lowers the daphne::Transpose operator to a Linalg TransposeOp. + * + * This rewrite may enable loop fusion on the affine loops TransposeOp is + * lowered to by running the loop fusion pass. + */ +struct ExtractLoweringPass : public mlir::PassWrapper> { + explicit ExtractLoweringPass() {} + + StringRef getArgument() const final { return "lower-extract"; } + StringRef getDescription() const final { return "Lowers ExtractRow/ExtractCol operators to a Memref SubViewOp."; } + + void getDependentDialects(mlir::DialectRegistry &registry) const override { + registry.insert(); + } + void runOnOperation() final; +}; +} // end anonymous namespace + +void ExtractLoweringPass::runOnOperation() { + mlir::ConversionTarget target(getContext()); + mlir::RewritePatternSet patterns(&getContext()); + LowerToLLVMOptions llvmOptions(&getContext()); + LLVMTypeConverter typeConverter(&getContext(), llvmOptions); + + typeConverter.addConversion(convertInteger); + typeConverter.addConversion(convertFloat); + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addArgumentMaterialization(materializeCastFromIllegal); + typeConverter.addSourceMaterialization(materializeCastToIllegal); + typeConverter.addTargetMaterialization(materializeCastFromIllegal); + + target.addLegalDialect(); + + target.addDynamicallyLegalOp([](Operation *op) { + Type operand = op->getOperand(0).getType(); + daphne::MatrixType matType = operand.dyn_cast(); + if (matType && matType.getRepresentation() == daphne::MatrixRepresentation::Dense) { + return false; + } + return true; + }); + +
patterns.insert(typeConverter, &getContext()); + auto module = getOperation(); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + signalPassFailure(); + } +} + +std::unique_ptr daphne::createExtractOpLoweringPass() { + return std::make_unique(); +} \ No newline at end of file From 99125d416fa5c4f20ae923943d462a35885b4693 Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Tue, 21 Jan 2025 10:01:06 +0100 Subject: [PATCH 09/33] add unfinished ExtractOpLowering --- src/compiler/execution/DaphneIrExecutor.cpp | 1 + src/compiler/lowering/CMakeLists.txt | 1 + src/ir/daphneir/Passes.h | 1 + 3 files changed, 3 insertions(+) diff --git a/src/compiler/execution/DaphneIrExecutor.cpp b/src/compiler/execution/DaphneIrExecutor.cpp index 34c0589b9..7d3a4fecb 100644 --- a/src/compiler/execution/DaphneIrExecutor.cpp +++ b/src/compiler/execution/DaphneIrExecutor.cpp @@ -271,6 +271,7 @@ void DaphneIrExecutor::buildCodegenPipeline(mlir::PassManager &pm) { //pm.addPass(mlir::daphne::createSliceRowOpLoweringPass()); //pm.addPass(mlir::daphne::createSliceColOpLoweringPass()); pm.addPass(mlir::daphne::createSliceOpLoweringPass()); + pm.addPass(mlir::daphne::createExtractOpLoweringPass()); pm.addPass(mlir::createInlinerPass()); diff --git a/src/compiler/lowering/CMakeLists.txt b/src/compiler/lowering/CMakeLists.txt index 36e8685b5..d9386808e 100644 --- a/src/compiler/lowering/CMakeLists.txt +++ b/src/compiler/lowering/CMakeLists.txt @@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRDaphneTransforms SliceRowOpLowering.cpp SliceColOpLowering.cpp SliceOpLowering.cpp + ExtractOpLowering.cpp DEPENDS MLIRDaphneOpsIncGen diff --git a/src/ir/daphneir/Passes.h b/src/ir/daphneir/Passes.h index 1820c3ca8..ad11c5e67 100644 --- a/src/ir/daphneir/Passes.h +++ b/src/ir/daphneir/Passes.h @@ -74,6 +74,7 @@ std::unique_ptr createVectorizeComputationsPass(); std::unique_ptr createSliceRowOpLoweringPass(); std::unique_ptr createSliceColOpLoweringPass(); std::unique_ptr 
createSliceOpLoweringPass(); +std::unique_ptr createExtractOpLoweringPass(); #ifdef USE_CUDA std::unique_ptr createMarkCUDAOpsPass(const DaphneUserConfig &cfg); From f2345673c28fef57cacb6be09ef94dac9e2b16b1 Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Sat, 25 Jan 2025 23:53:11 +0100 Subject: [PATCH 10/33] add EwUnaryOpsLowering for Sparse Matrix --- src/compiler/lowering/EwOpsLowering.cpp | 51 ++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/src/compiler/lowering/EwOpsLowering.cpp b/src/compiler/lowering/EwOpsLowering.cpp index 256795529..45fc307e6 100644 --- a/src/compiler/lowering/EwOpsLowering.cpp +++ b/src/compiler/lowering/EwOpsLowering.cpp @@ -77,6 +77,51 @@ template struct UnaryOpLowering : publi return mlir::success(); } + LogicalResult matchAndRewriteSparseMat(UnaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + + Location loc = op->getLoc(); + auto sparseMatType = adaptor.getArg().getType().template dyn_cast(); + Type matrixElementType = sparseMatType.getElementType(); + ssize_t numRows = sparseMatType.getNumRows(); + ssize_t numCols = sparseMatType.getNumCols(); + + if (numRows < 0 || numCols < 0) { + throw ErrorHandler::compilerError( + loc, "EwOpsLowering (BinaryOp)", + "ewOps codegen currently only works with matrix dimensions that are known at compile time"); + } + + MemRefType sparseValuesMemRefType = + //MemRefType::get({ShapedType::kDynamic}, matrixElementType); + MemRefType::get({numRows*numCols}, matrixElementType); + + Value argValuesMemref = rewriter.create( + loc, sparseValuesMemRefType, adaptor.getArg()); + + Value resMemref = rewriter.create( + loc, sparseValuesMemRefType); + + SmallVector indexMaps = {AffineMap::getMultiDimIdentityMap(1, rewriter.getContext()), + AffineMap::getMultiDimIdentityMap(1, rewriter.getContext())}; + SmallVector iterTypes = {utils::IteratorType::parallel}; + + rewriter.create( + loc, TypeRange{}, ValueRange{argValuesMemref}, 
ValueRange{resMemref}, indexMaps, iterTypes, + [&](OpBuilder &OpBuilderNested, Location locNested, ValueRange arg) { + Value resValue = unaryFunc(OpBuilderNested, locNested, this->typeConverter, arg[0]); + OpBuilderNested.create(locNested, resValue); + }); + + + rewriter.replaceOp(op, resMemref); + + //auto resDenseMatrix = convertMemRefToDenseMatrix(loc, rewriter, resMemref, op.getType()); + + //rewriter.replaceOp(op, resDenseMatrix); + + return mlir::success(); + } + LogicalResult matchAndRewrite(UnaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); @@ -87,6 +132,10 @@ template struct UnaryOpLowering : publi return matchAndRewriteScalarVal(op, adaptor, rewriter); } + if (matrixType.getRepresentation() == daphne::MatrixRepresentation::Sparse) { + return matchAndRewriteSparseMat(op, adaptor, rewriter); + } + Type matrixElementType = matrixType.getElementType(); ssize_t numRows = matrixType.getNumRows(); ssize_t numCols = matrixType.getNumCols(); @@ -541,7 +590,7 @@ void EwOpLoweringPass::runOnOperation() { return false; } auto matType = operand.dyn_cast(); - if (matType && matType.getRepresentation() == daphne::MatrixRepresentation::Dense) { + if (matType && (matType.getRepresentation() == daphne::MatrixRepresentation::Dense || matType.getRepresentation() == daphne::MatrixRepresentation::Sparse)) { return false; } return true; From e68919a20f7f3e61cbabc935454a9b0a46738619 Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Mon, 27 Jan 2025 22:26:55 +0100 Subject: [PATCH 11/33] comment out extract op lowering --- src/compiler/execution/DaphneIrExecutor.cpp | 2 +- test/api/cli/io/matrix_full.csv | 3 +++ test/api/cli/io/matrix_full.csv.meta | 1 + test/api/cli/io/matrix_view.csv | 3 +++ test/api/cli/io/matrix_view.csv.meta | 1 + 5 files changed, 9 insertions(+), 1 deletion(-) create mode 100644 test/api/cli/io/matrix_full.csv create mode 100644 test/api/cli/io/matrix_full.csv.meta create mode 
100644 test/api/cli/io/matrix_view.csv create mode 100644 test/api/cli/io/matrix_view.csv.meta diff --git a/src/compiler/execution/DaphneIrExecutor.cpp b/src/compiler/execution/DaphneIrExecutor.cpp index 0123923cc..d35e78d1b 100644 --- a/src/compiler/execution/DaphneIrExecutor.cpp +++ b/src/compiler/execution/DaphneIrExecutor.cpp @@ -279,7 +279,7 @@ void DaphneIrExecutor::buildCodegenPipeline(mlir::PassManager &pm) { //pm.addPass(mlir::daphne::createSliceRowOpLoweringPass()); //pm.addPass(mlir::daphne::createSliceColOpLoweringPass()); pm.addPass(mlir::daphne::createSliceOpLoweringPass()); - pm.addPass(mlir::daphne::createExtractOpLoweringPass()); + //pm.addPass(mlir::daphne::createExtractOpLoweringPass()); pm.addPass(mlir::createInlinerPass()); diff --git a/test/api/cli/io/matrix_full.csv b/test/api/cli/io/matrix_full.csv new file mode 100644 index 000000000..f5f1b5258 --- /dev/null +++ b/test/api/cli/io/matrix_full.csv @@ -0,0 +1,3 @@ +1,2 +3,4 +5,6 diff --git a/test/api/cli/io/matrix_full.csv.meta b/test/api/cli/io/matrix_full.csv.meta new file mode 100644 index 000000000..24eb98984 --- /dev/null +++ b/test/api/cli/io/matrix_full.csv.meta @@ -0,0 +1 @@ +{"numCols":2,"numRows":3,"valueType":"si64"} \ No newline at end of file diff --git a/test/api/cli/io/matrix_view.csv b/test/api/cli/io/matrix_view.csv new file mode 100644 index 000000000..e2ba1efb1 --- /dev/null +++ b/test/api/cli/io/matrix_view.csv @@ -0,0 +1,3 @@ +2 +4 +6 diff --git a/test/api/cli/io/matrix_view.csv.meta b/test/api/cli/io/matrix_view.csv.meta new file mode 100644 index 000000000..0b9e3769a --- /dev/null +++ b/test/api/cli/io/matrix_view.csv.meta @@ -0,0 +1 @@ +{"numCols":1,"numRows":3,"valueType":"si64"} \ No newline at end of file From 7cb9c3f1c2f3392ccef711b9d24fd8a10e277ff6 Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Tue, 28 Jan 2025 10:28:54 +0100 Subject: [PATCH 12/33] add untested kernel for converting memref to CSR --- src/compiler/utils/LoweringUtils.cpp | 18 
+++++++++ .../local/kernels/ConvertMemRefToCSRMatrix.h | 38 +++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 src/runtime/local/kernels/ConvertMemRefToCSRMatrix.h diff --git a/src/compiler/utils/LoweringUtils.cpp b/src/compiler/utils/LoweringUtils.cpp index 74fba85ab..68fe613be 100644 --- a/src/compiler/utils/LoweringUtils.cpp +++ b/src/compiler/utils/LoweringUtils.cpp @@ -85,6 +85,24 @@ mlir::Value convertMemRefToDenseMatrix(mlir::Location loc, mlir::ConversionPatte strides[0], strides[1]); } +mlir::Value convertMemRefToCSRMatrix(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, + mlir::Value valuesMemRef, mlir::Value colIdxsMemRef, mlir::Value rowOffsetsMemRef, + size_t maxNumRows, size_t numCols, size_t maxNumNonZeros, mlir::Type type) +{ + //auto extractStridedMetadataOp = rewriter.create(loc, memRef); + // aligned ptr (memref.data) + mlir::Value alignedValuesPtr = rewriter.create + (loc, valuesMemRef); + mlir::Value alignedColIdxsPtr = rewriter.create + (loc, colIdxsMemRef); + mlir::Value alignedRowOffsetsPtr = rewriter.create + (loc, rowOffsetsMemRef); + + return rewriter.create(loc, type, + alignedValuesPtr, alignedColIdxsPtr, alignedRowOffsetsPtr, + size_t maxNumRows, size_t numCols, size_t maxNumNonZeros); +} + mlir::Type convertFloat(mlir::FloatType floatType) { return mlir::IntegerType::get(floatType.getContext(), floatType.getIntOrFloatBitWidth()); } diff --git a/src/runtime/local/kernels/ConvertMemRefToCSRMatrix.h b/src/runtime/local/kernels/ConvertMemRefToCSRMatrix.h new file mode 100644 index 000000000..fa9b4c539 --- /dev/null +++ b/src/runtime/local/kernels/ConvertMemRefToCSRMatrix.h @@ -0,0 +1,38 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "runtime/local/context/DaphneContext.h" +#include "runtime/local/datastructures/CSRMatrix.h" + +template +inline void convertMemRefToCSRMatrix(CSRMatrix *&result, + size_t baseValuesPtr, size_t baseColIdxsPtr, size_t baseRowOffsetsPtr, + size_t maxNumRows, size_t numCols, size_t maxNumNonZeros, DCTX(ctx)) +{ + auto no_op_deleter = [](T *) {}; + T *valuePtr = reinterpret_cast(baseValuesPtr); + std::shared_ptr ptrValues(valuePtr, no_op_deleter); + std::shared_ptr ptrColIdxs(baseColIdxsPtr, no_op_deleter); + std::shared_ptr ptrRowOffsets(baseRowOffsetsPtr, no_op_deleter); + result = DataObjectFactory::create>(maxNumRows, numCols, maxNumNonZeros, false); + + result.getValuesSharedPtr() = ptrValues; + result.getColIdxsSharedPtr() = ptrColIdxs; + result.getRowOffsetsSharedPtr() = ptrRowOffsets; + +} From edc5546404cd972d94ad8110771d47d7e02e5b36 Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Thu, 30 Jan 2025 21:47:46 +0100 Subject: [PATCH 13/33] add EwUnaryMat for CSR Matrix --- src/compiler/lowering/EwOpsLowering.cpp | 27 ++++++-- src/compiler/utils/LoweringUtils.cpp | 4 +- src/compiler/utils/LoweringUtils.h | 4 ++ src/ir/daphneir/DaphneOps.td | 9 +++ .../local/kernels/ConvertMemRefToCSRMatrix.h | 17 ++--- src/runtime/local/kernels/EwUnaryMat.h | 29 +++++++++ src/runtime/local/kernels/kernels.json | 62 +++++++++++++++++++ 7 files changed, 139 insertions(+), 13 deletions(-) diff --git a/src/compiler/lowering/EwOpsLowering.cpp b/src/compiler/lowering/EwOpsLowering.cpp index 45fc307e6..07f897251 100644 --- 
a/src/compiler/lowering/EwOpsLowering.cpp +++ b/src/compiler/lowering/EwOpsLowering.cpp @@ -93,7 +93,7 @@ template struct UnaryOpLowering : publi MemRefType sparseValuesMemRefType = //MemRefType::get({ShapedType::kDynamic}, matrixElementType); - MemRefType::get({numRows*numCols}, matrixElementType); + MemRefType::get({ShapedType::kDynamic}, matrixElementType); Value argValuesMemref = rewriter.create( loc, sparseValuesMemRefType, adaptor.getArg()); @@ -113,11 +113,25 @@ template struct UnaryOpLowering : publi }); - rewriter.replaceOp(op, resMemref); + //rewriter.replaceOp(op, resMemref); + MemRefType sparseColIdxsMemRefType = MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); + MemRefType sparseRowOffsetsMemRefType = MemRefType::get({numRows + 1}, rewriter.getIndexType()); + + Value argColIdxsMemref = rewriter.create( + loc, sparseColIdxsMemRefType, adaptor.getArg()); + Value argRowOffsetsMemref = rewriter.create( + loc, sparseRowOffsetsMemRefType, adaptor.getArg()); + + Value maxNumRowsValue = rewriter.create(loc, numRows); + Value numColsValue = rewriter.create(loc, numCols); + Value maxNumNonZerosValue = rewriter.create(loc, numCols * numRows); + //auto resCSRMatrix = convertMemRefToCSRMatrix(loc, rewriter, resMemref, op.getType()); - //auto resDenseMatrix = convertMemRefToDenseMatrix(loc, rewriter, resMemref, op.getType()); + auto resCSRMatrix = convertMemRefToCSRMatrix(loc, rewriter, + resMemref, argColIdxsMemref, argRowOffsetsMemref, + maxNumRowsValue, numColsValue, maxNumNonZerosValue, op.getType()); - //rewriter.replaceOp(op, resDenseMatrix); + rewriter.replaceOp(op, resCSRMatrix); return mlir::success(); } @@ -265,6 +279,11 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { Type matrixElementType = lhsMatrixType.getElementType(); + if (lhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Dense) + { + MemRefType argMemRefType = MemRefType::get({lhsRows, lhsCols}, matrixElementType); + auto lhsMemref = 
rewriter.create(loc, argMemRefType, lhs); + } MemRefType argMemRefType = MemRefType::get({lhsRows, lhsCols}, matrixElementType); auto lhsMemref = rewriter.create(loc, argMemRefType, lhs); diff --git a/src/compiler/utils/LoweringUtils.cpp b/src/compiler/utils/LoweringUtils.cpp index 68fe613be..089c58ac6 100644 --- a/src/compiler/utils/LoweringUtils.cpp +++ b/src/compiler/utils/LoweringUtils.cpp @@ -87,7 +87,7 @@ mlir::Value convertMemRefToDenseMatrix(mlir::Location loc, mlir::ConversionPatte mlir::Value convertMemRefToCSRMatrix(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, mlir::Value valuesMemRef, mlir::Value colIdxsMemRef, mlir::Value rowOffsetsMemRef, - size_t maxNumRows, size_t numCols, size_t maxNumNonZeros, mlir::Type type) + mlir::Value maxNumRows, mlir::Value numCols, mlir::Value maxNumNonZeros, mlir::Type type) { //auto extractStridedMetadataOp = rewriter.create(loc, memRef); // aligned ptr (memref.data) @@ -100,7 +100,7 @@ mlir::Value convertMemRefToCSRMatrix(mlir::Location loc, mlir::ConversionPattern return rewriter.create(loc, type, alignedValuesPtr, alignedColIdxsPtr, alignedRowOffsetsPtr, - size_t maxNumRows, size_t numCols, size_t maxNumNonZeros); + maxNumRows, numCols, maxNumNonZeros); } mlir::Type convertFloat(mlir::FloatType floatType) { diff --git a/src/compiler/utils/LoweringUtils.h b/src/compiler/utils/LoweringUtils.h index a723492f9..80137a264 100644 --- a/src/compiler/utils/LoweringUtils.h +++ b/src/compiler/utils/LoweringUtils.h @@ -43,6 +43,10 @@ void affineFillMemRef(mlir::Value value, mlir::ConversionPatternRewriter &rewrit mlir::Value convertMemRefToDenseMatrix(mlir::Location, mlir::ConversionPatternRewriter &, mlir::Value memRef, mlir::Type); +mlir::Value convertMemRefToCSRMatrix(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, + mlir::Value valuesMemRef, mlir::Value colIdxsMemRef, mlir::Value rowOffsetsMemRef, + mlir::Value maxNumRows, mlir::Value numCols, mlir::Value maxNumNonZeros, mlir::Type type); + 
llvm::Optional materializeCastFromIllegal(mlir::OpBuilder &builder, mlir::Type type, mlir::ValueRange inputs, mlir::Location loc); diff --git a/src/ir/daphneir/DaphneOps.td b/src/ir/daphneir/DaphneOps.td index b2d91310b..94b6557af 100644 --- a/src/ir/daphneir/DaphneOps.td +++ b/src/ir/daphneir/DaphneOps.td @@ -83,6 +83,15 @@ def Daphne_ConvertMemRefToDenseMatrix : Daphne_Op<"convertMemRefToDenseMatrix"> let results = (outs MatrixOrU:$res); } +def Daphne_ConvertMemRefToCSRMatrix : Daphne_Op<"convertMemRefToCSRMatrix"> { + let summary = "Return a CSRMatrix."; + let description = [{ Constructs a DenseMatrix given 3 rank 2 StridedMemRefType. }]; + + /* let arguments = (ins AnyMemRef:$arg); */ + let arguments = (ins Size:$baseValues, Size:$baseColIdxs, Size:$baseRowOffsets, Size:$maxNumRows, Size:$numCols, Size:$maxNumNonZeros); + let results = (outs MatrixOrU:$res); +} + def Daphne_ConvertDenseMatrixToMemRef : Daphne_Op<"convertDenseMatrixToMemRef", [Pure]> { let summary = "Given a DenseMatrix, return a StridedMemRefType."; let description = [{ Constructs a StridedMemRefType with rank 2 from a DenseMatrix* with already allocated memory. 
}]; diff --git a/src/runtime/local/kernels/ConvertMemRefToCSRMatrix.h b/src/runtime/local/kernels/ConvertMemRefToCSRMatrix.h index fa9b4c539..0cd8de8d2 100644 --- a/src/runtime/local/kernels/ConvertMemRefToCSRMatrix.h +++ b/src/runtime/local/kernels/ConvertMemRefToCSRMatrix.h @@ -24,15 +24,18 @@ inline void convertMemRefToCSRMatrix(CSRMatrix *&result, size_t baseValuesPtr, size_t baseColIdxsPtr, size_t baseRowOffsetsPtr, size_t maxNumRows, size_t numCols, size_t maxNumNonZeros, DCTX(ctx)) { - auto no_op_deleter = [](T *) {}; + auto no_op_deleter_1 = [](T *) {}; + auto no_op_deleter_2 = [](long unsigned int *) {}; T *valuePtr = reinterpret_cast(baseValuesPtr); - std::shared_ptr ptrValues(valuePtr, no_op_deleter); - std::shared_ptr ptrColIdxs(baseColIdxsPtr, no_op_deleter); - std::shared_ptr ptrRowOffsets(baseRowOffsetsPtr, no_op_deleter); + long unsigned int *colIdxsPtr = reinterpret_cast(baseColIdxsPtr); + long unsigned int *rowOffsetsPtr = reinterpret_cast(baseRowOffsetsPtr); + std::shared_ptr ptrValues(valuePtr, no_op_deleter_1); + std::shared_ptr ptrColIdxs(colIdxsPtr, no_op_deleter_2); + std::shared_ptr ptrRowOffsets(rowOffsetsPtr, no_op_deleter_2); result = DataObjectFactory::create>(maxNumRows, numCols, maxNumNonZeros, false); - result.getValuesSharedPtr() = ptrValues; - result.getColIdxsSharedPtr() = ptrColIdxs; - result.getRowOffsetsSharedPtr() = ptrRowOffsets; + result->getValuesSharedPtr() = ptrValues; + result->getColIdxsSharedPtr() = ptrColIdxs; + result->getRowOffsetsSharedPtr() = ptrRowOffsets; } diff --git a/src/runtime/local/kernels/EwUnaryMat.h b/src/runtime/local/kernels/EwUnaryMat.h index 1524587cb..d438a2ba5 100644 --- a/src/runtime/local/kernels/EwUnaryMat.h +++ b/src/runtime/local/kernels/EwUnaryMat.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -94,4 +95,32 @@ template struct EwUnaryMat, Matrix> { } }; +// ---------------------------------------------------------------------------- +// CSRMatrix <- 
CSRMatrix +// ---------------------------------------------------------------------------- + +template struct EwUnaryMat, CSRMatrix> { + static void apply(UnaryOpCode opCode, CSRMatrix *&res, const CSRMatrix *arg, DCTX(ctx)) { + const size_t numRows = arg->getNumRows(); + const size_t numCols = arg->getNumCols(); + const size_t maxNumNonZeros = arg->getMaxNumNonZeros(); + const size_t numNonZeros = arg->getNumNonZeros(); + + if (res == nullptr) + res = DataObjectFactory::create>(numRows, numCols, maxNumNonZeros, false); + + const VT *valuesArg = arg->getValues(); + VT *valuesRes = res->getValues(); + + res->getColIdxsSharedPtr() = arg->getColIdxsSharedPtr(); + res->getRowOffsetsSharedPtr() = arg->getRowOffsetsSharedPtr(); + + EwUnaryScaFuncPtr func = getEwUnaryScaFuncPtr(opCode); + + for (size_t i = 0; i < numNonZeros; i++) + valuesRes[i] = func(valuesArg[i], ctx); + + } +}; + #endif // SRC_RUNTIME_LOCAL_KERNELS_EWUNARYMAT_H diff --git a/src/runtime/local/kernels/kernels.json b/src/runtime/local/kernels/kernels.json index 19597655e..a7526c0b3 100644 --- a/src/runtime/local/kernels/kernels.json +++ b/src/runtime/local/kernels/kernels.json @@ -1511,6 +1511,60 @@ ["double"] ] }, + { + "kernelTemplate": { + "header": "ConvertMemRefToCSRMatrix.h", + "opName": "convertMemRefToCSRMatrix", + "returnType": "void", + "templateParams": [ + { + "name": "VT", + "isDataType": false + } + ], + "runtimeParams": [ + { + "type": "CSRMatrix *&", + "name": "result" + }, + { + "type": "size_t", + "name": "baseValuesPtr" + }, + { + "type": "size_t", + "name": "baseColIdxsPtr" + }, + { + "type": "size_t", + "name": "baseRowOffsetsPtr" + }, + { + "type": "size_t", + "name": "maxNumRows" + }, + { + "type": "size_t", + "name": "numCols" + }, + { + "type": "size_t", + "name": "maxNumNonZeros" + } + ] + }, + "instantiations": [ + ["int64_t"], + ["int32_t"], + ["int8_t"], + ["size_t"], + ["uint64_t"], + ["uint32_t"], + ["uint8_t"], + ["float"], + ["double"] + ] + }, { "kernelTemplate": { 
"header": "ConvertDenseMatrixToMemRef.h", @@ -4537,6 +4591,14 @@ [ ["DenseMatrix", "std::string"], ["DenseMatrix", "std::string"] + ], + [ + ["CSRMatrix", "double"], + ["CSRMatrix", "double"] + ], + [ + ["CSRMatrix", "int64_t"], + ["CSRMatrix", "int64_t"] ] ], "opCodes": [ From 196b8e120a68d60d5d7a38134c449a429627f99a Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Mon, 3 Feb 2025 17:43:04 +0100 Subject: [PATCH 14/33] add a new constructor for CSRMatrix --- src/runtime/local/datastructures/CSRMatrix.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/runtime/local/datastructures/CSRMatrix.h b/src/runtime/local/datastructures/CSRMatrix.h index 0298622bb..9c00e3499 100644 --- a/src/runtime/local/datastructures/CSRMatrix.h +++ b/src/runtime/local/datastructures/CSRMatrix.h @@ -129,6 +129,16 @@ template class CSRMatrix : public Matrix { rowOffsets = std::shared_ptr(src->rowOffsets, src->rowOffsets.get() + rowLowerIncl); } + CSRMatrix(size_t maxNumRows, size_t numCols, size_t maxNumNonZeros, + std::shared_ptr &values, std::shared_ptr &colIdxs, std::shared_ptr &rowOffsets) + : Matrix(maxNumRows, numCols), numRowsAllocated(maxNumRows), isRowAllocatedBefore(false), + maxNumNonZeros(maxNumNonZeros), lastAppendedRowIdx(0) { + + this->values = values; + this->colIdxs = colIdxs; + this->rowOffsets = rowOffsets; + } + virtual ~CSRMatrix() { // nothing to do } From 6212dc6f00ad6c178491e2d49a2ff30a4a497b1e Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Mon, 3 Feb 2025 17:45:55 +0100 Subject: [PATCH 15/33] adapt MemRefToCSR kernel with new constructor --- .../local/kernels/ConvertMemRefToCSRMatrix.h | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/runtime/local/kernels/ConvertMemRefToCSRMatrix.h b/src/runtime/local/kernels/ConvertMemRefToCSRMatrix.h index 0cd8de8d2..4cce7d5d5 100644 --- a/src/runtime/local/kernels/ConvertMemRefToCSRMatrix.h +++ 
b/src/runtime/local/kernels/ConvertMemRefToCSRMatrix.h @@ -25,17 +25,13 @@ inline void convertMemRefToCSRMatrix(CSRMatrix *&result, size_t maxNumRows, size_t numCols, size_t maxNumNonZeros, DCTX(ctx)) { auto no_op_deleter_1 = [](T *) {}; - auto no_op_deleter_2 = [](long unsigned int *) {}; + auto no_op_deleter_2 = [](size_t *) {}; T *valuePtr = reinterpret_cast(baseValuesPtr); - long unsigned int *colIdxsPtr = reinterpret_cast(baseColIdxsPtr); - long unsigned int *rowOffsetsPtr = reinterpret_cast(baseRowOffsetsPtr); + size_t *colIdxsPtr = reinterpret_cast(baseColIdxsPtr); + size_t *rowOffsetsPtr = reinterpret_cast(baseRowOffsetsPtr); std::shared_ptr ptrValues(valuePtr, no_op_deleter_1); - std::shared_ptr ptrColIdxs(colIdxsPtr, no_op_deleter_2); - std::shared_ptr ptrRowOffsets(rowOffsetsPtr, no_op_deleter_2); - result = DataObjectFactory::create>(maxNumRows, numCols, maxNumNonZeros, false); - - result->getValuesSharedPtr() = ptrValues; - result->getColIdxsSharedPtr() = ptrColIdxs; - result->getRowOffsetsSharedPtr() = ptrRowOffsets; - + std::shared_ptr ptrColIdxs(colIdxsPtr, no_op_deleter_2); + std::shared_ptr ptrRowOffsets(rowOffsetsPtr, no_op_deleter_2); + result = DataObjectFactory::create>( + maxNumRows, numCols, maxNumNonZeros, ptrValues, ptrColIdxs, ptrRowOffsets); } From 5679eeaeba76adf31595dba04bb0f59dd07f7ea9 Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Mon, 3 Feb 2025 17:49:02 +0100 Subject: [PATCH 16/33] add CSR support to EwBinaryObjSca --- src/runtime/local/kernels/EwBinaryObjSca.h | 42 ++++++++++++++++++++++ src/runtime/local/kernels/kernels.json | 34 ++++++++++++++++++ 2 files changed, 76 insertions(+) diff --git a/src/runtime/local/kernels/EwBinaryObjSca.h b/src/runtime/local/kernels/EwBinaryObjSca.h index 62d704ded..f83837195 100644 --- a/src/runtime/local/kernels/EwBinaryObjSca.h +++ b/src/runtime/local/kernels/EwBinaryObjSca.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -154,4 +155,45 @@ 
template struct EwBinaryObjSca { } }; +// ---------------------------------------------------------------------------- +// CSRMatrix <- CSRMatrix, scalar +// ---------------------------------------------------------------------------- + +template +struct EwBinaryObjSca, CSRMatrix, VTRhs> { + static void apply(BinaryOpCode opCode, CSRMatrix *&res, const CSRMatrix *lhs, VTRhs rhs, + DCTX(ctx)) { + + if (opCode != BinaryOpCode::MUL) + throw std::runtime_error("EwBinaryObjSca::apply: only support MUL for CSR Matrix"); + + const size_t numRows = lhs->getNumRows(); + const size_t numCols = lhs->getNumCols(); + const size_t maxNumNonZeros = lhs->getMaxNumNonZeros(); + const size_t numNonZeros = lhs->getNumNonZeros(); + + if (res == nullptr) + res = DataObjectFactory::create>(numRows, numCols, maxNumNonZeros, false); + + const VTLhs *valuesLhs = lhs->getValues(); + const size_t *colIdxsLhs = lhs->getColIdxs(); + const size_t *rowOffsetsLhs = lhs->getRowOffsets(); + VTRes *valuesRes = res->getValues(); + size_t *colIdxsRes = res->getColIdxs(); + size_t *rowOffsetsRes = res->getRowOffsets(); + + for (size_t i = 0; i < numNonZeros; i++) + colIdxsRes[i] = colIdxsLhs[i]; + + for (size_t i = 0; i < numRows + 1; i++) + rowOffsetsRes[i] = rowOffsetsLhs[i]; + + EwBinaryScaFuncPtr func = getEwBinaryScaFuncPtr(opCode); + + for (size_t i = 0; i < numNonZeros; i++) + valuesRes[i] = func(valuesLhs[i], rhs, ctx); + + } +}; + #endif // SRC_RUNTIME_LOCAL_KERNELS_EWBINARYOBJSCA_H diff --git a/src/runtime/local/kernels/kernels.json b/src/runtime/local/kernels/kernels.json index a7526c0b3..0847b4052 100644 --- a/src/runtime/local/kernels/kernels.json +++ b/src/runtime/local/kernels/kernels.json @@ -2218,6 +2218,36 @@ ["DenseMatrix", "uint64_t"], ["DenseMatrix", "uint64_t"], "uint64_t" + ], + [ + ["CSRMatrix", "float"], + ["CSRMatrix", "float"], + "float" + ], + [ + ["CSRMatrix", "double"], + ["CSRMatrix", "double"], + "double" + ], + [ + ["CSRMatrix", "int64_t"], + ["CSRMatrix", "int64_t"], + 
"int64_t" + ], + [ + ["CSRMatrix", "int32_t"], + ["CSRMatrix", "int32_t"], + "int32_t" + ], + [ + ["CSRMatrix", "uint32_t"], + ["CSRMatrix", "uint32_t"], + "uint32_t" + ], + [ + ["CSRMatrix", "uint64_t"], + ["CSRMatrix", "uint64_t"], + "uint64_t" ] ], "opCodes": [ @@ -4596,6 +4626,10 @@ ["CSRMatrix", "double"], ["CSRMatrix", "double"] ], + [ + ["CSRMatrix", "float"], + ["CSRMatrix", "float"] + ], [ ["CSRMatrix", "int64_t"], ["CSRMatrix", "int64_t"] From a2f9c906f172da244a67f26541e679c72dbafb35 Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Thu, 6 Feb 2025 12:58:48 +0100 Subject: [PATCH 17/33] add EwOpsLowering for Op between CSR and Dense --- src/compiler/lowering/EwOpsLowering.cpp | 209 +++++++++++++++++++++++- 1 file changed, 203 insertions(+), 6 deletions(-) diff --git a/src/compiler/lowering/EwOpsLowering.cpp b/src/compiler/lowering/EwOpsLowering.cpp index 07f897251..bdd8e6af3 100644 --- a/src/compiler/lowering/EwOpsLowering.cpp +++ b/src/compiler/lowering/EwOpsLowering.cpp @@ -31,6 +31,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinDialect.h" @@ -46,6 +47,7 @@ #include "mlir/Transforms/DialectConversion.h" using namespace mlir; +using namespace std; // **************************************************************************** // Rewriter Templates (Elemwise Unary, Elemwise Binary) @@ -86,6 +88,7 @@ template struct UnaryOpLowering : publi ssize_t numCols = sparseMatType.getNumCols(); if (numRows < 0 || numCols < 0) { + std::cout<<"here 5"< struct UnaryOpLowering : publi Value argValuesMemref = rewriter.create( loc, sparseValuesMemRefType, adaptor.getArg()); + Value one = rewriter.create(loc, 1); Value resMemref = rewriter.create( - loc, sparseValuesMemRefType); + loc, sparseValuesMemRefType, ValueRange{one}); SmallVector indexMaps = 
{AffineMap::getMultiDimIdentityMap(1, rewriter.getContext()), AffineMap::getMultiDimIdentityMap(1, rewriter.getContext())}; @@ -130,6 +134,7 @@ template struct UnaryOpLowering : publi auto resCSRMatrix = convertMemRefToCSRMatrix(loc, rewriter, resMemref, argColIdxsMemref, argRowOffsetsMemref, maxNumRowsValue, numColsValue, maxNumNonZerosValue, op.getType()); + //maxNumRowsValue, numColsValue, maxNumNonZerosValue, adaptor.getArg().getType()); rewriter.replaceOp(op, resCSRMatrix); @@ -155,6 +160,7 @@ template struct UnaryOpLowering : publi ssize_t numCols = matrixType.getNumCols(); if (numRows < 0 || numCols < 0) { + std::cout<<"here 6"< { if (lhsRows != 1 && rhsRows == 1) { // rhs is a row vector, broadcast along columns if (lhsCols != rhsCols) { + std::cout<<"here 7"< { } else if (lhsCols != 1 && rhsCols == 1) { // rhs is a column vector, broadcast along rows if (lhsRows != rhsRows) { + std::cout<<"here 8"< { } else { // rhs is not broadcasted, return identity mapping if (lhsRows != rhsRows || lhsCols != rhsCols) { + std::cout<<"here 9"< { Type matrixElementType = lhsMatrixType.getElementType(); - if (lhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Dense) + if (lhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Sparse) { - MemRefType argMemRefType = MemRefType::get({lhsRows, lhsCols}, matrixElementType); - auto lhsMemref = rewriter.create(loc, argMemRefType, lhs); + MemRefType valuesMemRefType = MemRefType::get({ShapedType::kDynamic}, matrixElementType); + MemRefType colIdxsMemRefType = MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); + MemRefType rowOffsetsMemRefType = MemRefType::get({lhsRows + 1}, rewriter.getIndexType()); + + auto lhsValuesMemref = rewriter.create(loc, valuesMemRefType, lhs); + auto lhsColIdxsMemref = rewriter.create(loc, colIdxsMemRefType, lhs); + auto lhsRowOffsetsMemref = rewriter.create(loc, rowOffsetsMemRefType, lhs); + + Value one = rewriter.create(loc, 1); + Value resMemref = 
rewriter.create(loc, valuesMemRefType, ValueRange{one}); + + SmallVector indexMaps = {AffineMap::getMultiDimIdentityMap(1, rewriter.getContext()), + AffineMap::getMultiDimIdentityMap(1, rewriter.getContext())}; + SmallVector iterTypes = {utils::IteratorType::parallel}; + + rewriter.create( + loc, TypeRange{}, ValueRange{lhsValuesMemref}, ValueRange{resMemref}, indexMaps, iterTypes, + [&](OpBuilder &OpBuilderNested, Location locNested, ValueRange arg) { + Value resValue = binaryFunc(OpBuilderNested, locNested, this->typeConverter, arg[0], rhs); + OpBuilderNested.create(locNested, resValue); + }); + + Value maxNumRowsValue = rewriter.create(loc, lhsRows); + Value numColsValue = rewriter.create(loc, lhsCols); + Value maxNumNonZerosValue = rewriter.create(loc, lhsCols * lhsRows); + + auto resCSRMatrix = convertMemRefToCSRMatrix(loc, rewriter, + resMemref, lhsColIdxsMemref, lhsRowOffsetsMemref, + maxNumRowsValue, numColsValue, maxNumNonZerosValue, op.getType()); + + rewriter.replaceOp(op, resCSRMatrix); + + return mlir::success(); + } + MemRefType argMemRefType = MemRefType::get({lhsRows, lhsCols}, matrixElementType); auto lhsMemref = rewriter.create(loc, argMemRefType, lhs); @@ -306,6 +348,145 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { return mlir::success(); } + LogicalResult matchAndRewriteSparseDenseMat(BinaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + + auto sparseLhsMatrixType = lhs.getType().template dyn_cast(); + auto denseRhsMatrixType = rhs.getType().template dyn_cast(); + + ssize_t sparseLhsRows = sparseLhsMatrixType.getNumRows(); + ssize_t sparseLhsCols = sparseLhsMatrixType.getNumCols(); + ssize_t denseRhsRows = denseRhsMatrixType.getNumRows(); + ssize_t denseRhsCols = denseRhsMatrixType.getNumCols(); + + MemRefType sparseLhsValuesMemRefType = + MemRefType::get({ShapedType::kDynamic}, 
sparseLhsMatrixType.getElementType()); + MemRefType sparseLhsColIdxsMemRefType = + MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); + MemRefType sparseLhsRowOffsetsMemRefType = + MemRefType::get({sparseLhsRows + 1}, rewriter.getIndexType()); + MemRefType denseRhsMemRefType = + MemRefType::get({denseRhsRows, denseRhsCols}, denseRhsMatrixType.getElementType()); + + auto sparseLhsValuesMemRef = + rewriter.create(loc, sparseLhsValuesMemRefType, lhs); + auto sparseLhsColIdxsMemRef = + rewriter.create(loc, sparseLhsColIdxsMemRefType, lhs); + auto sparseLhsRowOffsetsMemRef = + rewriter.create(loc, sparseLhsRowOffsetsMemRefType, lhs); + auto denseRhsMemRef = + rewriter.create(loc, denseRhsMemRefType, rhs); + + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + auto numSparseLhsRowsValue = rewriter.create(loc, sparseLhsRows); + + auto resDenseMemRef = rewriter.create(loc, denseRhsMemRefType); + rewriter.create(loc, denseRhsMemRef, resDenseMemRef); + auto resSparseMemRef = rewriter.create(loc, sparseLhsValuesMemRefType, ValueRange{one}); + + rewriter.create( + // loc, rowPtr, nextRowPtr, rewriter.create(loc, 1), + loc, zero, numSparseLhsRowsValue, one, ValueRange{}, + // [&](OpBuilder &OpBuilderNested, Location locNested, Value loopIdx) + [&](OpBuilder &OpBuilderNested, Location locNested, Value loopIdx, ValueRange loopInvariants) + { + auto rowPtr = loopIdx; + auto nextRowPtr = OpBuilderNested.create(locNested, rowPtr, one); + + auto colIdxLowerIncl = OpBuilderNested.create( + locNested, sparseLhsRowOffsetsMemRef, ValueRange{rowPtr}); + auto colIdxUpperExcl = OpBuilderNested.create( + locNested, sparseLhsRowOffsetsMemRef, ValueRange{nextRowPtr}); + + OpBuilderNested.create( + // locNested, colIdxLowerIncl, colIdxUpperExcl, one, ValueRange{rowPtr}, + locNested, colIdxLowerIncl, colIdxUpperExcl, one, ValueRange{}, + // [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested, Value loopIdxNested, ValueRange loopInvariantsNested) 
+ [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested, Value loopIdxNested, ValueRange loopInvariants) + { + // auto rowIdx = loopInvariantsNested[0]; + auto rowIdx = rowPtr; + auto colIdx = OpBuilderTwiceNested.create( + locTwiceNested, sparseLhsColIdxsMemRef, ValueRange{loopIdxNested}); + + auto sparseLhsValue = OpBuilderTwiceNested.create( + locTwiceNested, sparseLhsValuesMemRef, ValueRange{loopIdxNested}); + + auto denseRhsValue = OpBuilderTwiceNested.create( + locTwiceNested, denseRhsMemRef, ValueRange{rowIdx, colIdx}); + + Value resValue = binaryFunc( + OpBuilderTwiceNested, locTwiceNested, this->typeConverter, sparseLhsValue, denseRhsValue); + + //Value store; + + if (llvm::isa(op)) + { + // auto store = OpBuilderTwiceNested.create( + OpBuilderTwiceNested.create( + locTwiceNested, resValue, resDenseMemRef, ValueRange{rowIdx, colIdx}); + } + else if (llvm::isa(op)) + { + // auto store = OpBuilderTwiceNested.create( + OpBuilderTwiceNested.create( + locTwiceNested, resValue, resSparseMemRef, ValueRange{loopIdxNested}); + } + else + { + std::cout<<"here 10"<(locTwiceNested, resValue); + // OpBuilderTwiceNested.create(locTwiceNested); + } + ); + + // OpBuilderNested.create(locNested, resValue); + // auto resValue = colLoop.getResult(0); + OpBuilderNested.create(locNested); + } + ); + + if (llvm::isa(op)) + { + Value resDenseMatrix = convertMemRefToDenseMatrix(loc, rewriter, resDenseMemRef, op.getType()); + std::cout<<"here 1"<(op)) + { + llvm::errs()<(loc, sparseLhsRows); + Value numColsValue = rewriter.create(loc, sparseLhsCols); + Value maxNumNonZerosValue = rewriter.create(loc, sparseLhsCols * sparseLhsRows); + + Value resCSRMatrix = convertMemRefToCSRMatrix(loc, rewriter, + resSparseMemRef, sparseLhsColIdxsMemRef, sparseLhsRowOffsetsMemRef, + maxNumRowsValue, numColsValue, maxNumNonZerosValue, op.getType()); + + if (!resCSRMatrix) { + llvm::errs() << "Error: resCSRMatrix is null!\n"; + } + std::cout<<"here 2"<getLoc(); Value lhs = adaptor.getLhs(); 
@@ -323,6 +504,10 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { return matchAndRewriteBroadcastScalarRhs(op, adaptor, rewriter, rhs); } + if (lhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Sparse && + rhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Dense) + return matchAndRewriteSparseDenseMat(op, adaptor, rewriter); + Type matrixElementType = lhsMatrixType.getElementType(); ssize_t lhsRows = lhsMatrixType.getNumRows(); @@ -331,6 +516,7 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { ssize_t rhsCols = rhsMatrixType.getNumCols(); if (lhsRows < 0 || lhsCols < 0 || rhsRows < 0 || rhsCols < 0) { + std::cout<<"here 4"<(); + daphne::DaphneDialect, mlir::math::MathDialect, mlir::arith::ArithDialect, mlir::scf::SCFDialect>(); } void runOnOperation() final; @@ -596,7 +782,7 @@ void EwOpLoweringPass::runOnOperation() { target.addLegalDialect(); + mlir::math::MathDialect, mlir::linalg::LinalgDialect, mlir::scf::SCFDialect>(); // UnaryOps target.addDynamicallyLegalOp(rhs) || llvm::isa(rhs)) && + (lhsMatType && lhsMatType.getRepresentation() == daphne::MatrixRepresentation::Sparse)) { + return false; + } + + if ((lhsMatType && lhsMatType.getRepresentation() == daphne::MatrixRepresentation::Sparse) && + (rhsMatType && rhsMatType.getRepresentation() == daphne::MatrixRepresentation::Dense)) { + return false; + } + return true; }); From 90dabd6e8e39d6d6fb013d47404d2787b57656f6 Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Sat, 8 Feb 2025 00:03:55 +0100 Subject: [PATCH 18/33] add EwBinaryMat kernel for Dense <- (CSR + Dense) --- src/runtime/local/kernels/EwBinaryMat.h | 64 +++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/src/runtime/local/kernels/EwBinaryMat.h b/src/runtime/local/kernels/EwBinaryMat.h index 5330c71d7..80934d48d 100644 --- a/src/runtime/local/kernels/EwBinaryMat.h +++ b/src/runtime/local/kernels/EwBinaryMat.h @@ -341,3 +341,67 @@ template 
struct EwBinaryMat, Matrix, Matrix> { res->finishAppend(); } }; + +// ---------------------------------------------------------------------------- +// DenseMatrix <- CSRMatrix, DenseMatrix +// ---------------------------------------------------------------------------- + +template struct EwBinaryMat, CSRMatrix, DenseMatrix> { + static void apply(BinaryOpCode opCode, DenseMatrix *&res, const CSRMatrix *lhs, const DenseMatrix *rhs, + DCTX(ctx)) { + const size_t numRows = lhs->getNumRows(); + const size_t numCols = lhs->getNumCols(); + // TODO: lhs broadcast + // if ((numRows != rhs->getNumRows() && rhs->getNumRows() != 1) || + // (numCols != rhs->getNumCols() && rhs->getNumCols() != 1)) + // throw std::runtime_error("EwBinaryMat(CSR) - lhs and rhs must have " + // "the same dimensions (or broadcast)"); + if (numRows != rhs->getNumRows() || numCols != rhs->getNumCols()) + throw std::runtime_error("EwBinaryMat(CSR) - lhs and rhs must have " + "the same dimensions (or broadcast)"); + + size_t maxNnz; + switch (opCode) { + case BinaryOpCode::ADD: // merge + maxNnz = lhs->getNumNonZeros(); + break; + default: + throw std::runtime_error("EwBinaryMat(CSR) - unknown BinaryOpCode"); + } + + if (res == nullptr) + res = DataObjectFactory::create>(numRows, numCols, false); + + auto *valuesRes = res->getValues(); + auto *valuesRhs = rhs->getValues(); + + for (size_t r = 0; r < numRows; r++) + for (size_t c = 0; c < numCols; c++) + valuesRes[r * numCols + c] = valuesRhs[r * numCols + c]; + + EwBinaryScaFuncPtr func = getEwBinaryScaFuncPtr(opCode); + + switch (opCode) { + case BinaryOpCode::ADD: { // merge non-zero cells + for (size_t rowIdx = 0; rowIdx < numRows; rowIdx++) { + size_t nnzRowLhs = lhs->getNumNonZeros(rowIdx); + if (nnzRowLhs) { + // merge within row + const VT *valuesRowLhs = lhs->getValues(rowIdx); + const size_t *colIdxsRowLhs = lhs->getColIdxs(rowIdx); + for (size_t posLhs = 0; posLhs < nnzRowLhs; ++posLhs) { + auto rhsCol = colIdxsRowLhs[posLhs]; + 
valuesRes[rhsCol] = func(valuesRes[rhsCol], valuesRowLhs[posLhs], ctx); + } + } + valuesRes += res->getRowSkip(); + } + break; + } + default: + throw std::runtime_error("EwBinaryMat(CSR) - unknown BinaryOpCode"); + } + + // TODO Update number of non-zeros in result in the end. + } +}; From b1713922f1f70661338e0c9956fa0791a48f4a88 Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Sat, 8 Feb 2025 00:06:06 +0100 Subject: [PATCH 19/33] update kernels.json --- src/runtime/local/kernels/kernels.json | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/runtime/local/kernels/kernels.json b/src/runtime/local/kernels/kernels.json index 0847b4052..90abc8cd8 100644 --- a/src/runtime/local/kernels/kernels.json +++ b/src/runtime/local/kernels/kernels.json @@ -2122,6 +2122,16 @@ ["DenseMatrix", "int64_t"], ["DenseMatrix", "std::string"], ["DenseMatrix", "std::string"] + ], + [ + ["DenseMatrix", "double"], + ["CSRMatrix", "double"], + ["DenseMatrix", "double"] + ], + [ + ["DenseMatrix", "float"], + ["CSRMatrix", "float"], + ["DenseMatrix", "float"] ] ], "opCodes": [ From 960a4c28f9601c70e224d2250906f5582d299739 Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Mon, 10 Feb 2025 17:43:53 +0100 Subject: [PATCH 20/33] add CSR +/* CSR --- src/compiler/lowering/EwOpsLowering.cpp | 587 ++++++++++++++++++++++++ 1 file changed, 587 insertions(+) diff --git a/src/compiler/lowering/EwOpsLowering.cpp b/src/compiler/lowering/EwOpsLowering.cpp index bdd8e6af3..38948dd80 100644 --- a/src/compiler/lowering/EwOpsLowering.cpp +++ b/src/compiler/lowering/EwOpsLowering.cpp @@ -487,6 +487,584 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { } } + LogicalResult matchAndRewriteSparseSparseMat(BinaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + + auto lhsMatrixType = lhs.getType().template dyn_cast(); + auto 
rhsMatrixType = rhs.getType().template dyn_cast(); + + ssize_t lhsRows = lhsMatrixType.getNumRows(); + ssize_t lhsCols = lhsMatrixType.getNumCols(); + ssize_t rhsRows = rhsMatrixType.getNumRows(); + ssize_t rhsCols = rhsMatrixType.getNumCols(); + + if (lhsRows != rhsRows || lhsCols != rhsCols) + throw ErrorHandler::compilerError( + loc, "EwOpsLowering (BinaryOp Sparse Sparse)", "lhs and rhs must have the same dimensions."); + + auto numRows = lhsRows; + auto numCols = lhsCols; + + MemRefType lhsValuesMemRefType = + MemRefType::get({ShapedType::kDynamic}, lhsMatrixType.getElementType()); + MemRefType rhsValuesMemRefType = + MemRefType::get({ShapedType::kDynamic}, rhsMatrixType.getElementType()); + MemRefType colIdxsMemRefType = + MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); + MemRefType rowOffsetsMemRefType = + MemRefType::get({numRows + 1}, rewriter.getIndexType()); + + auto lhsValuesMemRef = + rewriter.create(loc, lhsValuesMemRefType, lhs); + auto lhsColIdxsMemRef = + rewriter.create(loc, colIdxsMemRefType, lhs); + auto lhsRowOffsetsMemRef = + rewriter.create(loc, rowOffsetsMemRefType, lhs); + auto rhsValuesMemRef = + rewriter.create(loc, rhsValuesMemRefType, rhs); + auto rhsColIdxsMemRef = + rewriter.create(loc, colIdxsMemRefType, rhs); + auto rhsRowOffsetsMemRef = + rewriter.create(loc, rowOffsetsMemRefType, rhs); + + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + auto numRowsValue = rewriter.create(loc, numRows); + + auto resValuesMemRef = rewriter.create(loc, lhsValuesMemRefType, ValueRange{one}); + auto resColIdxsMemRef = rewriter.create(loc, colIdxsMemRefType, ValueRange{one}); + auto resRowOffsetsMemRef = rewriter.create(loc, rowOffsetsMemRefType); + rewriter.create(loc, zero, resRowOffsetsMemRef, ValueRange{zero}); + // size_t rowOffset = 0; + // size_t resValuesPtr = 0; + + rewriter.create( + loc, zero, numRowsValue, one, ValueRange{zero}, + [&](OpBuilder &OpBuilderNested, Location locNested, Value 
loopIdx, ValueRange loopIterArgs) + { + auto rowPtr = loopIdx; + auto nextRowPtr = OpBuilderNested.create(locNested, rowPtr, one); + + auto resValuesPtr = loopIterArgs[0]; + + auto lhsColIdxLowerIncl = OpBuilderNested.create( + locNested, lhsRowOffsetsMemRef, ValueRange{rowPtr}); + auto lhsColIdxUpperExcl = OpBuilderNested.create( + locNested, lhsRowOffsetsMemRef, ValueRange{nextRowPtr}); + auto rhsColIdxLowerIncl = OpBuilderNested.create( + locNested, rhsRowOffsetsMemRef, ValueRange{rowPtr}); + auto rhsColIdxUpperExcl = OpBuilderNested.create( + locNested, rhsRowOffsetsMemRef, ValueRange{nextRowPtr}); + + auto lhsValuePtr = OpBuilderNested.create( + locNested, lhsRowOffsetsMemRef, ValueRange{rowPtr}); + auto rhsValuePtr = OpBuilderNested.create( + locNested, rhsRowOffsetsMemRef, ValueRange{rowPtr}); + + // auto cmpRes1 = OpBuilderNested.create( + // locNested, arith::CmpIPredicate::ult, lhsColIdxUpperExcl, rhsColIdxLowerIncl); + // auto cmpRes2 = OpBuilderNested.create( + // locNested, arith::CmpIPredicate::ult, rhsColIdxUpperExcl, lhsColIdxLowerIncl); + // auto cmpRes = OpBuilderNested.create(locNested, cmpRes1, cmpRes2); + + // auto lhsColUpperExcl = OpBuilderNested.create( + // locNested, lhsColIdxsMemRef, ValueRange{lhsColIdxUpperExcl}); + + auto lhsColIdxUpperIncl = OpBuilderNested.create( + locNested, lhsColIdxUpperExcl, one); + auto lhsColUpper = OpBuilderNested.create( + locNested, lhsColIdxsMemRef, ValueRange{lhsColIdxUpperIncl}); + auto rhsColIdxUpperIncl = OpBuilderNested.create( + locNested, rhsColIdxUpperExcl, one); + auto rhsColUpper = OpBuilderNested.create( + locNested, rhsColIdxsMemRef, ValueRange{rhsColIdxUpperIncl}); + + // auto rhsColUpperExcl = OpBuilderNested.create( + // locNested, lhsColIdxsMemRef, ValueRange{rhsColIdxUpperExcl}); + // auto lhsEndFirst = OpBuilderNested.create( + // locNested, arith::CmpIPredicate::ult, lhsColUpperExcl, rhsColUpperExcl); + // auto rhsEndFirst = OpBuilderNested.create( + // locNested, 
arith::CmpIPredicate::uge, lhsColUpperExcl, rhsColUpperExcl); + + auto lhsEndFirst = OpBuilderNested.create( + locNested, arith::CmpIPredicate::ult, lhsColIdxLowerIncl, lhsColIdxUpperExcl); + auto rhsEndFirst = OpBuilderNested.create( + locNested, arith::CmpIPredicate::uge, rhsColIdxLowerIncl, rhsColIdxUpperExcl); + + //auto restValuesPtr = rewriter.create(loc, 0); + //Value restValuesPtr; + //Value newArg0; + //Value newArg1; + auto lhsAllZero = OpBuilderNested.create( + locNested, arith::CmpIPredicate::eq, lhsColUpper, rhsColUpper); + auto rhsAllZero = OpBuilderNested.create( + locNested, arith::CmpIPredicate::eq, lhsColUpper, rhsColUpper); + + auto operation = OpBuilderNested.create( + locNested, lhsAllZero, + [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested) + { + auto thenRegion = OpBuilderTwiceNested.create( + locTwiceNested, rhsAllZero, + [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) + { + auto newResValuesPtr = resValuesPtr; + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{newResValuesPtr}); + }, + [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) + { + auto forLoop = OpBuilderThreetimesNested.create( + locThreetimesNested, rhsColIdxLowerIncl, rhsColIdxUpperExcl, one, ValueRange{resValuesPtr}, + [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested, Value loopIdx, ValueRange loopIterArgs) + { + auto resValue = OpBuilderFourtimesNested.create( + locFourtimesNested, lhsValuesMemRef, ValueRange{loopIdx}); + auto resCol = OpBuilderFourtimesNested.create( + locFourtimesNested, lhsColIdxsMemRef, ValueRange{loopIdx}); + // auto resIndex = OpBuilderFourtimesNested.create( + // locFourtimesNested, resValuesPtr); + auto resIndex = loopIterArgs[0]; + OpBuilderFourtimesNested.create( + locFourtimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderFourtimesNested.create( + locFourtimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); + // 
resValuesPtr++; + auto newResValuesPtr = OpBuilderFourtimesNested.create( + locFourtimesNested, resIndex, one); + OpBuilderFourtimesNested.create( + locFourtimesNested, ValueRange{newResValuesPtr}); + } + ); + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{forLoop.getResult(0)}); + } + ); + OpBuilderTwiceNested.create(locTwiceNested, ValueRange{thenRegion.getResult(0)}); + }, + [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested) + { + auto elseRegion = OpBuilderTwiceNested.create( + locNested, rhsAllZero, + [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) + { + auto forLoop = OpBuilderThreetimesNested.create( + locThreetimesNested, lhsColIdxLowerIncl, lhsColIdxUpperExcl, one, ValueRange{resValuesPtr}, + [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested, Value loopIdx, ValueRange loopIterArgs) + { + auto resValue = OpBuilderFourtimesNested.create( + locFourtimesNested, lhsValuesMemRef, ValueRange{loopIdx}); + auto resCol = OpBuilderFourtimesNested.create( + locFourtimesNested, lhsColIdxsMemRef, ValueRange{loopIdx}); + // auto resIndex = OpBuilderFourtimesNested.create( + // locFourtimesNested, resValuesPtr); + auto resIndex = loopIterArgs[0]; + OpBuilderFourtimesNested.create( + locFourtimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderFourtimesNested.create( + locFourtimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); + // resValuesPtr++; + auto newResValuesPtr = OpBuilderFourtimesNested.create( + locFourtimesNested, resIndex, one); + OpBuilderFourtimesNested.create( + locFourtimesNested, ValueRange{newResValuesPtr}); + } + ); + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{forLoop.getResult(0)}); + }, + [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) + { + auto whileLoop = OpBuilderThreetimesNested.create( + locThreetimesNested, + TypeRange{ + OpBuilderThreetimesNested.getIndexType(), + 
OpBuilderThreetimesNested.getIndexType(), + OpBuilderThreetimesNested.getIndexType()}, + ValueRange{lhsColIdxLowerIncl, rhsColIdxLowerIncl, resValuesPtr}, + [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested, ValueRange args) + { + auto cond1 = OpBuilderFourtimesNested.create( + locFourtimesNested, arith::CmpIPredicate::ult, args[0], lhsColIdxUpperExcl); + auto cond2 = OpBuilderFourtimesNested.create( + locFourtimesNested, arith::CmpIPredicate::ult, args[1], rhsColIdxUpperExcl); + auto cond = OpBuilderFourtimesNested.create(locNested, cond1, cond2); + OpBuilderFourtimesNested.create(locFourtimesNested, cond, args); + }, + [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested, ValueRange args) + { + auto lhsCol = OpBuilderFourtimesNested.create( + locFourtimesNested, lhsColIdxsMemRef, ValueRange{args[0]}); + auto rhsCol = OpBuilderFourtimesNested.create( + locFourtimesNested, rhsColIdxsMemRef, ValueRange{args[1]}); + + auto case1 = OpBuilderFourtimesNested.create( + locFourtimesNested, arith::CmpIPredicate::ult, lhsCol, rhsCol); + auto case2 = OpBuilderFourtimesNested.create( + locFourtimesNested, arith::CmpIPredicate::ult, rhsCol, lhsCol); + auto case3 = OpBuilderFourtimesNested.create( + locFourtimesNested, arith::CmpIPredicate::eq, lhsCol, rhsCol); + + auto newArg = OpBuilderFourtimesNested.create( + locFourtimesNested, case1, + [&](OpBuilder &OpBuilderFivetimesNested, Location locFivetimesNested) + { + auto newResValuesPtr = args[2]; + if (llvm::isa(op)) + { + auto resValue = OpBuilderFivetimesNested.create( + locFivetimesNested, lhsValuesMemRef, ValueRange{args[0]}); + // auto resIndex = OpBuilderFivetimesNested.create( + // locFivetimesNested, resValuesPtr); + auto resIndex = args[2]; + OpBuilderFivetimesNested.create( + locFivetimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderFivetimesNested.create( + locFivetimesNested, lhsCol, resColIdxsMemRef, ValueRange{resIndex}); + // resValuesPtr++; + 
newResValuesPtr = OpBuilderFivetimesNested.create( + locFivetimesNested, resIndex, one); + } + + auto newArg0 = OpBuilderFivetimesNested.create(locFivetimesNested, args[0], one); + auto newArg1 = args[1]; + OpBuilderFivetimesNested.create( + locFivetimesNested, + ValueRange{newArg0, newArg1, newResValuesPtr}); + }, + [&](OpBuilder &OpBuilderFivetimesNested, Location locFivetimesNested) + { + auto case2Region = OpBuilderFivetimesNested.create( + locFivetimesNested, case2, + [&](OpBuilder &OpBuilderSixtimesNested, Location locSixtimesNested) + { + auto newResValuesPtr = args[2]; + if (llvm::isa(op)) + { + auto resValue = OpBuilderSixtimesNested.create( + locSixtimesNested, rhsValuesMemRef, ValueRange{args[1]}); + // auto resIndex = OpBuilderSixtimesNested.create( + // locSixtimesNested, resValuesPtr); + auto resIndex = args[2]; + OpBuilderSixtimesNested.create( + locSixtimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderSixtimesNested.create( + locSixtimesNested, rhsCol, resColIdxsMemRef, ValueRange{resIndex}); + // resValuesPtr++; + newResValuesPtr = OpBuilderSixtimesNested.create( + locSixtimesNested, resIndex, one); + } + auto newArg0 = args[0]; + auto newArg1 = OpBuilderSixtimesNested.create(locSixtimesNested, args[1], one); + OpBuilderSixtimesNested.create( + locSixtimesNested, ValueRange{newArg0, newArg1, newResValuesPtr}); + }, + [&](OpBuilder &OpBuilderSixtimesNested, Location locSixtimesNested) + { + auto lhsValue = OpBuilderSixtimesNested.create( + locSixtimesNested, lhsValuesMemRef, ValueRange{args[0]}); + auto rhsValue = OpBuilderSixtimesNested.create( + locSixtimesNested, rhsValuesMemRef, ValueRange{args[1]}); + auto resValue = binaryFunc( + OpBuilderSixtimesNested, locSixtimesNested, this->typeConverter, lhsValue, rhsValue); + // auto resIndex = rewriter.create(locSixtimesNested, resValuesPtr); + auto resIndex = args[2]; + OpBuilderSixtimesNested.create( + locSixtimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + 
OpBuilderSixtimesNested.create( + locSixtimesNested, rhsCol, resColIdxsMemRef, ValueRange{resIndex}); + // resValuesPtr++; + auto newResValuesPtr = OpBuilderSixtimesNested.create( + locSixtimesNested, resIndex, one); + auto newArg0 = OpBuilderSixtimesNested.create(locSixtimesNested, args[0], one); + auto newArg1 = OpBuilderSixtimesNested.create(locSixtimesNested, args[1], one); + OpBuilderSixtimesNested.create( + locSixtimesNested, ValueRange{newArg0, newArg1, newResValuesPtr}); + } + ); + OpBuilderFivetimesNested.create( + locFivetimesNested, + ValueRange{case2Region.getResult(0), case2Region.getResult(1), case2Region.getResult(2)}); + } + ); + auto newArg0 = newArg.getResult(0); + auto newArg1 = newArg.getResult(1); + auto newArg2 = newArg.getResult(2); + + OpBuilderFourtimesNested.create( + locFourtimesNested, ValueRange{newArg0, newArg1, newArg2}); + } + ); + auto rest = OpBuilderThreetimesNested.create( + locThreetimesNested, lhsEndFirst, + [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested) + { + auto rhsRest = OpBuilderFourtimesNested.create( + locFourtimesNested, whileLoop.getResult(1), rhsColIdxUpperExcl, one, ValueRange{whileLoop.getResult(2)}, + [&](OpBuilder &OpBuilderFivetimesNested, Location locFivetimesNested, Value loopIdx, ValueRange loopIterArgs) + { + auto resValue = OpBuilderFivetimesNested.create( + locFivetimesNested, rhsValuesMemRef, ValueRange{loopIdx}); + auto resCol = OpBuilderFivetimesNested.create( + locFivetimesNested, rhsColIdxsMemRef, ValueRange{loopIdx}); + // auto resIndex = OpBuilderFivetimesNested.create(locFivetimesNested, resValuesPtr); + auto resIndex = loopIterArgs[0]; + OpBuilderFivetimesNested.create( + locFivetimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderFivetimesNested.create( + locFivetimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); + // resValuesPtr++; + auto newResValuesPtr = OpBuilderFivetimesNested.create( + locFivetimesNested, resIndex, one); + 
OpBuilderFivetimesNested.create(locFivetimesNested, ValueRange{newResValuesPtr}); + } + ); + OpBuilderFourtimesNested.create(locFourtimesNested, ValueRange{rhsRest.getResult(0)}); + }, + [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested) + { + auto lhsRest = OpBuilderFourtimesNested.create( + locFourtimesNested, whileLoop.getResult(0), lhsColIdxUpperExcl, one, ValueRange{whileLoop.getResult(2)}, + [&](OpBuilder &OpBuilderFivetimesNested, Location locFivetimesNested, Value loopIdx, ValueRange loopIterArgs) + { + auto resValue = OpBuilderFivetimesNested.create( + locFivetimesNested, lhsValuesMemRef, ValueRange{loopIdx}); + auto resCol = OpBuilderFivetimesNested.create( + locFivetimesNested, lhsColIdxsMemRef, ValueRange{loopIdx}); + // auto resIndex = OpBuilderFivetimesNested.create(locFivetimesNested, resValuesPtr); + auto resIndex = loopIterArgs[0]; + OpBuilderFivetimesNested.create( + locFivetimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderFivetimesNested.create( + locFivetimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); + // resValuesPtr++; + auto newResValuesPtr = OpBuilderFivetimesNested.create( + locFivetimesNested, resIndex, one); + OpBuilderFivetimesNested.create(locFivetimesNested, ValueRange{newResValuesPtr}); + } + ); + OpBuilderFourtimesNested.create(locFourtimesNested, ValueRange{lhsRest.getResult(0)}); + } + ); + /* OpBuilderThreetimesNested.create( + locThreetimesNested, rhsEndFirst, + [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested) + { + OpBuilderFourtimesNested.create( + locFourtimesNested, whileLoop.getResult(0), lhsColIdxUpperExcl, one, ValueRange{}, + [&](OpBuilder &OpBuilderFivetimesNested, Location locFivetimesNested, Value loopIdx, ValueRange loopInvariants) + { + auto resValue = OpBuilderFivetimesNested.create( + locFivetimesNested, lhsValuesMemRef, ValueRange{loopIdx}); + auto resCol = OpBuilderFivetimesNested.create( + locFivetimesNested, lhsColIdxsMemRef, 
ValueRange{loopIdx}); + auto resIndex = OpBuilderFivetimesNested.create(locFivetimesNested, resValuesPtr); + OpBuilderFivetimesNested.create( + locFivetimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderFivetimesNested.create( + locFivetimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); + resValuesPtr++; + OpBuilderFivetimesNested.create(locFivetimesNested); + } + ); + OpBuilderFourtimesNested.create(locFourtimesNested); + } + ); */ + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{rest.getResult(0)}); + } + ); + OpBuilderTwiceNested.create(locTwiceNested, ValueRange{elseRegion.getResult(0)}); + } + ); + + /* auto whileLoop = OpBuilderNested.create( + locNested, TypeRange{OpBuilderNested.getIndexType(), OpBuilderNested.getIndexType()}, ValueRange{lhsColIdxLowerIncl, rhsColIdxLowerIncl}, + [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested, ValueRange args) + { + + auto cond1 = OpBuilderTwiceNested.create( + locTwiceNested, arith::CmpIPredicate::ult, args[0], lhsColIdxUpperExcl); + // OpBuilderTwiceNested.create( + // locTwiceNested, cond1, + // [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested){ + // // restValuesPtr = args[1]; + // restValuesPtr = OpBuilderThreetimesNested.create( + // locThreetimesNested, one, args[1]); + // }); + auto cond2 = OpBuilderTwiceNested.create( + locTwiceNested, arith::CmpIPredicate::ult, args[1], rhsColIdxUpperExcl); + // OpBuilderTwiceNested.create( + // locTwiceNested, cond2, + // [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested){ + // // restValuesPtr = args[0]; + // restValuesPtr = OpBuilderThreetimesNested.create( + // locThreetimesNested, one, args[0]); + // }); + auto cond = OpBuilderTwiceNested.create(locNested, cond1, cond2); + OpBuilderTwiceNested.create(locTwiceNested, cond, args); + }, + [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested, ValueRange args) + { + auto lhsCol = OpBuilderTwiceNested.create( + 
locTwiceNested, lhsColIdxsMemRef, ValueRange{args[0]}); + auto rhsCol = OpBuilderTwiceNested.create( + locTwiceNested, rhsColIdxsMemRef, ValueRange{args[1]}); + + auto case1 = OpBuilderTwiceNested.create( + locTwiceNested, arith::CmpIPredicate::ult, lhsCol, rhsCol); + auto case2 = OpBuilderTwiceNested.create( + locTwiceNested, arith::CmpIPredicate::ult, rhsCol, lhsCol); + auto case3 = OpBuilderTwiceNested.create( + locTwiceNested, arith::CmpIPredicate::eq, lhsCol, rhsCol); + + auto newArg = OpBuilderTwiceNested.create( + locTwiceNested, case1, + [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) + { + if (llvm::isa(op)) + { + auto resValue = OpBuilderThreetimesNested.create( + locThreetimesNested, lhsValuesMemRef, ValueRange{args[0]}); + auto resIndex = OpBuilderThreetimesNested.create(loc, resValuesPtr); + OpBuilderThreetimesNested.create( + locThreetimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderThreetimesNested.create( + locThreetimesNested, lhsCol, resColIdxsMemRef, ValueRange{resIndex}); + resValuesPtr++; + } + auto newArg0 = OpBuilderThreetimesNested.create(locThreetimesNested, args[0], one); + auto newArg1 = args[1]; + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{newArg0, newArg1}); + }, + [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) + { + auto then = OpBuilderThreetimesNested.create( + locThreetimesNested, case2, + [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested) + { + if (llvm::isa(op)) + { + auto resValue = OpBuilderFourtimesNested.create( + locFourtimesNested, rhsValuesMemRef, ValueRange{args[1]}); + auto resIndex = OpBuilderFourtimesNested.create(locFourtimesNested, resValuesPtr); + OpBuilderFourtimesNested.create( + locFourtimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderFourtimesNested.create( + locFourtimesNested, rhsCol, resColIdxsMemRef, ValueRange{resIndex}); + resValuesPtr++; + } + auto newArg0 = args[0]; + 
auto newArg1 = OpBuilderFourtimesNested.create(locFourtimesNested, args[1], one); + OpBuilderFourtimesNested.create(locThreetimesNested, ValueRange{newArg0, newArg1}); + }, + [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested) + { + auto lhsValue = OpBuilderFourtimesNested.create( + locFourtimesNested, lhsValuesMemRef, ValueRange{args[0]}); + auto rhsValue = OpBuilderFourtimesNested.create( + locFourtimesNested, rhsValuesMemRef, ValueRange{args[1]}); + auto resValue = binaryFunc( + OpBuilderFourtimesNested, locFourtimesNested, this->typeConverter, lhsValue, rhsValue); + auto resIndex = rewriter.create(locFourtimesNested, resValuesPtr); + OpBuilderFourtimesNested.create( + locFourtimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderFourtimesNested.create( + locFourtimesNested, rhsCol, resColIdxsMemRef, ValueRange{resIndex}); + resValuesPtr++; + auto newArg0 = OpBuilderFourtimesNested.create(locFourtimesNested, args[0], one); + auto newArg1 = OpBuilderFourtimesNested.create(locFourtimesNested, args[1], one); + OpBuilderFourtimesNested.create(locFourtimesNested, ValueRange{newArg0, newArg1}); + } + ); + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{then.getResult(0), then.getResult(1)}); + } + ); + auto newArg0 = newArg.getResult(0); + auto newArg1 = newArg.getResult(1); + + OpBuilderTwiceNested.create(locTwiceNested, ValueRange{newArg0, newArg1}); + } + ); + OpBuilderNested.create( + // locNested, TypeRange{}, lhsEndFirst, true, + locNested, lhsEndFirst, + [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested) + { + OpBuilderTwiceNested.create( + //locTwiceNested, restValuesPtr, rhsColIdxUpperExcl, one, ValueRange{}, + locTwiceNested, whileLoop.getResult(1), rhsColIdxUpperExcl, one, ValueRange{}, + [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested, Value loopIdx, ValueRange loopInvariants) + { + auto resValue = OpBuilderThreetimesNested.create( + locThreetimesNested, 
rhsValuesMemRef, ValueRange{loopIdx}); + auto resCol = OpBuilderThreetimesNested.create( + locThreetimesNested, rhsColIdxsMemRef, ValueRange{loopIdx}); + auto resIndex = OpBuilderThreetimesNested.create(loc, resValuesPtr); + OpBuilderThreetimesNested.create( + locThreetimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderThreetimesNested.create( + locThreetimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); + resValuesPtr++; + OpBuilderThreetimesNested.create(locThreetimesNested); + } + ); + OpBuilderTwiceNested.create(locTwiceNested); + } + ); + + OpBuilderNested.create( + locNested, rhsEndFirst, + [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested) + { + OpBuilderTwiceNested.create( + // locTwiceNested, restValuesPtr, lhsColIdxUpperExcl, one, ValueRange{}, + locTwiceNested, whileLoop.getResult(0), lhsColIdxUpperExcl, one, ValueRange{}, + [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested, Value loopIdx, ValueRange loopInvariants) + { + auto resValue = OpBuilderThreetimesNested.create( + locThreetimesNested, lhsValuesMemRef, ValueRange{loopIdx}); + auto resCol = OpBuilderThreetimesNested.create( + locThreetimesNested, lhsColIdxsMemRef, ValueRange{loopIdx}); + auto resIndex = OpBuilderThreetimesNested.create(loc, resValuesPtr); + OpBuilderThreetimesNested.create( + locThreetimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderThreetimesNested.create( + locThreetimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); + resValuesPtr++; + OpBuilderThreetimesNested.create(locThreetimesNested); + } + ); + OpBuilderTwiceNested.create(locTwiceNested); + } + );*/ + + OpBuilderNested.create( + locNested, + // OpBuilderNested.create(loc, resValuesPtr), + operation.getResult(0), + resRowOffsetsMemRef, + ValueRange{nextRowPtr}); + + OpBuilderNested.create(locNested, ValueRange{operation.getResult(0)}); + } + ); + + Value maxNumRowsValue = rewriter.create(loc, numRows); + Value numColsValue = 
rewriter.create(loc, numCols); + Value maxNumNonZerosValue = rewriter.create(loc, numCols * numRows); + + Value resCSRMatrix = convertMemRefToCSRMatrix(loc, rewriter, + resValuesMemRef, resColIdxsMemRef, resRowOffsetsMemRef, + maxNumRowsValue, numColsValue, maxNumNonZerosValue, op.getType()); + + if (!resCSRMatrix) { + llvm::errs() << "Error: resCSRMatrix is null!\n"; + } + + rewriter.replaceOp(op, resCSRMatrix); + + return mlir::success(); + } + LogicalResult matchAndRewrite(BinaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); Value lhs = adaptor.getLhs(); @@ -508,6 +1086,10 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { rhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Dense) return matchAndRewriteSparseDenseMat(op, adaptor, rewriter); + if (lhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Sparse && + rhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Sparse) + return matchAndRewriteSparseSparseMat(op, adaptor, rewriter); + Type matrixElementType = lhsMatrixType.getElementType(); ssize_t lhsRows = lhsMatrixType.getNumRows(); @@ -832,6 +1414,11 @@ void EwOpLoweringPass::runOnOperation() { return false; } + if ((lhsMatType && lhsMatType.getRepresentation() == daphne::MatrixRepresentation::Sparse) && + (rhsMatType && rhsMatType.getRepresentation() == daphne::MatrixRepresentation::Sparse)) { + return false; + } + return true; }); From 4752fdb036d196c48d6022db498333f2bac22106 Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Mon, 10 Feb 2025 18:00:06 +0100 Subject: [PATCH 21/33] add EwBinaryMat CSR <- (CSR, CSR) --- src/compiler/lowering/EwOpsLowering.cpp | 235 ------------------------ 1 file changed, 235 deletions(-) diff --git a/src/compiler/lowering/EwOpsLowering.cpp b/src/compiler/lowering/EwOpsLowering.cpp index 38948dd80..ae90a79a6 100644 --- a/src/compiler/lowering/EwOpsLowering.cpp +++ 
b/src/compiler/lowering/EwOpsLowering.cpp @@ -537,8 +537,6 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { auto resColIdxsMemRef = rewriter.create(loc, colIdxsMemRefType, ValueRange{one}); auto resRowOffsetsMemRef = rewriter.create(loc, rowOffsetsMemRefType); rewriter.create(loc, zero, resRowOffsetsMemRef, ValueRange{zero}); - // size_t rowOffset = 0; - // size_t resValuesPtr = 0; rewriter.create( loc, zero, numRowsValue, one, ValueRange{zero}, @@ -558,20 +556,6 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { auto rhsColIdxUpperExcl = OpBuilderNested.create( locNested, rhsRowOffsetsMemRef, ValueRange{nextRowPtr}); - auto lhsValuePtr = OpBuilderNested.create( - locNested, lhsRowOffsetsMemRef, ValueRange{rowPtr}); - auto rhsValuePtr = OpBuilderNested.create( - locNested, rhsRowOffsetsMemRef, ValueRange{rowPtr}); - - // auto cmpRes1 = OpBuilderNested.create( - // locNested, arith::CmpIPredicate::ult, lhsColIdxUpperExcl, rhsColIdxLowerIncl); - // auto cmpRes2 = OpBuilderNested.create( - // locNested, arith::CmpIPredicate::ult, rhsColIdxUpperExcl, lhsColIdxLowerIncl); - // auto cmpRes = OpBuilderNested.create(locNested, cmpRes1, cmpRes2); - - // auto lhsColUpperExcl = OpBuilderNested.create( - // locNested, lhsColIdxsMemRef, ValueRange{lhsColIdxUpperExcl}); - auto lhsColIdxUpperIncl = OpBuilderNested.create( locNested, lhsColIdxUpperExcl, one); auto lhsColUpper = OpBuilderNested.create( @@ -581,22 +565,9 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { auto rhsColUpper = OpBuilderNested.create( locNested, rhsColIdxsMemRef, ValueRange{rhsColIdxUpperIncl}); - // auto rhsColUpperExcl = OpBuilderNested.create( - // locNested, lhsColIdxsMemRef, ValueRange{rhsColIdxUpperExcl}); - // auto lhsEndFirst = OpBuilderNested.create( - // locNested, arith::CmpIPredicate::ult, lhsColUpperExcl, rhsColUpperExcl); - // auto rhsEndFirst = OpBuilderNested.create( - // locNested, arith::CmpIPredicate::uge, lhsColUpperExcl, 
rhsColUpperExcl); - auto lhsEndFirst = OpBuilderNested.create( locNested, arith::CmpIPredicate::ult, lhsColIdxLowerIncl, lhsColIdxUpperExcl); - auto rhsEndFirst = OpBuilderNested.create( - locNested, arith::CmpIPredicate::uge, rhsColIdxLowerIncl, rhsColIdxUpperExcl); - //auto restValuesPtr = rewriter.create(loc, 0); - //Value restValuesPtr; - //Value newArg0; - //Value newArg1; auto lhsAllZero = OpBuilderNested.create( locNested, arith::CmpIPredicate::eq, lhsColUpper, rhsColUpper); auto rhsAllZero = OpBuilderNested.create( @@ -623,14 +594,11 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { locFourtimesNested, lhsValuesMemRef, ValueRange{loopIdx}); auto resCol = OpBuilderFourtimesNested.create( locFourtimesNested, lhsColIdxsMemRef, ValueRange{loopIdx}); - // auto resIndex = OpBuilderFourtimesNested.create( - // locFourtimesNested, resValuesPtr); auto resIndex = loopIterArgs[0]; OpBuilderFourtimesNested.create( locFourtimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); OpBuilderFourtimesNested.create( locFourtimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); - // resValuesPtr++; auto newResValuesPtr = OpBuilderFourtimesNested.create( locFourtimesNested, resIndex, one); OpBuilderFourtimesNested.create( @@ -656,14 +624,11 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { locFourtimesNested, lhsValuesMemRef, ValueRange{loopIdx}); auto resCol = OpBuilderFourtimesNested.create( locFourtimesNested, lhsColIdxsMemRef, ValueRange{loopIdx}); - // auto resIndex = OpBuilderFourtimesNested.create( - // locFourtimesNested, resValuesPtr); auto resIndex = loopIterArgs[0]; OpBuilderFourtimesNested.create( locFourtimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); OpBuilderFourtimesNested.create( locFourtimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); - // resValuesPtr++; auto newResValuesPtr = OpBuilderFourtimesNested.create( locFourtimesNested, resIndex, one); OpBuilderFourtimesNested.create( @@ 
-701,8 +666,6 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { locFourtimesNested, arith::CmpIPredicate::ult, lhsCol, rhsCol); auto case2 = OpBuilderFourtimesNested.create( locFourtimesNested, arith::CmpIPredicate::ult, rhsCol, lhsCol); - auto case3 = OpBuilderFourtimesNested.create( - locFourtimesNested, arith::CmpIPredicate::eq, lhsCol, rhsCol); auto newArg = OpBuilderFourtimesNested.create( locFourtimesNested, case1, @@ -713,14 +676,11 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { { auto resValue = OpBuilderFivetimesNested.create( locFivetimesNested, lhsValuesMemRef, ValueRange{args[0]}); - // auto resIndex = OpBuilderFivetimesNested.create( - // locFivetimesNested, resValuesPtr); auto resIndex = args[2]; OpBuilderFivetimesNested.create( locFivetimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); OpBuilderFivetimesNested.create( locFivetimesNested, lhsCol, resColIdxsMemRef, ValueRange{resIndex}); - // resValuesPtr++; newResValuesPtr = OpBuilderFivetimesNested.create( locFivetimesNested, resIndex, one); } @@ -742,14 +702,11 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { { auto resValue = OpBuilderSixtimesNested.create( locSixtimesNested, rhsValuesMemRef, ValueRange{args[1]}); - // auto resIndex = OpBuilderSixtimesNested.create( - // locSixtimesNested, resValuesPtr); auto resIndex = args[2]; OpBuilderSixtimesNested.create( locSixtimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); OpBuilderSixtimesNested.create( locSixtimesNested, rhsCol, resColIdxsMemRef, ValueRange{resIndex}); - // resValuesPtr++; newResValuesPtr = OpBuilderSixtimesNested.create( locSixtimesNested, resIndex, one); } @@ -766,13 +723,11 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { locSixtimesNested, rhsValuesMemRef, ValueRange{args[1]}); auto resValue = binaryFunc( OpBuilderSixtimesNested, locSixtimesNested, this->typeConverter, lhsValue, rhsValue); - // auto resIndex = 
rewriter.create(locSixtimesNested, resValuesPtr); auto resIndex = args[2]; OpBuilderSixtimesNested.create( locSixtimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); OpBuilderSixtimesNested.create( locSixtimesNested, rhsCol, resColIdxsMemRef, ValueRange{resIndex}); - // resValuesPtr++; auto newResValuesPtr = OpBuilderSixtimesNested.create( locSixtimesNested, resIndex, one); auto newArg0 = OpBuilderSixtimesNested.create(locSixtimesNested, args[0], one); @@ -806,13 +761,11 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { locFivetimesNested, rhsValuesMemRef, ValueRange{loopIdx}); auto resCol = OpBuilderFivetimesNested.create( locFivetimesNested, rhsColIdxsMemRef, ValueRange{loopIdx}); - // auto resIndex = OpBuilderFivetimesNested.create(locFivetimesNested, resValuesPtr); auto resIndex = loopIterArgs[0]; OpBuilderFivetimesNested.create( locFivetimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); OpBuilderFivetimesNested.create( locFivetimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); - // resValuesPtr++; auto newResValuesPtr = OpBuilderFivetimesNested.create( locFivetimesNested, resIndex, one); OpBuilderFivetimesNested.create(locFivetimesNested, ValueRange{newResValuesPtr}); @@ -830,13 +783,11 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { locFivetimesNested, lhsValuesMemRef, ValueRange{loopIdx}); auto resCol = OpBuilderFivetimesNested.create( locFivetimesNested, lhsColIdxsMemRef, ValueRange{loopIdx}); - // auto resIndex = OpBuilderFivetimesNested.create(locFivetimesNested, resValuesPtr); auto resIndex = loopIterArgs[0]; OpBuilderFivetimesNested.create( locFivetimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); OpBuilderFivetimesNested.create( locFivetimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); - // resValuesPtr++; auto newResValuesPtr = OpBuilderFivetimesNested.create( locFivetimesNested, resIndex, one); OpBuilderFivetimesNested.create(locFivetimesNested, 
ValueRange{newResValuesPtr}); @@ -845,30 +796,6 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { OpBuilderFourtimesNested.create(locFourtimesNested, ValueRange{lhsRest.getResult(0)}); } ); - /* OpBuilderThreetimesNested.create( - locThreetimesNested, rhsEndFirst, - [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested) - { - OpBuilderFourtimesNested.create( - locFourtimesNested, whileLoop.getResult(0), lhsColIdxUpperExcl, one, ValueRange{}, - [&](OpBuilder &OpBuilderFivetimesNested, Location locFivetimesNested, Value loopIdx, ValueRange loopInvariants) - { - auto resValue = OpBuilderFivetimesNested.create( - locFivetimesNested, lhsValuesMemRef, ValueRange{loopIdx}); - auto resCol = OpBuilderFivetimesNested.create( - locFivetimesNested, lhsColIdxsMemRef, ValueRange{loopIdx}); - auto resIndex = OpBuilderFivetimesNested.create(locFivetimesNested, resValuesPtr); - OpBuilderFivetimesNested.create( - locFivetimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); - OpBuilderFivetimesNested.create( - locFivetimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); - resValuesPtr++; - OpBuilderFivetimesNested.create(locFivetimesNested); - } - ); - OpBuilderFourtimesNested.create(locFourtimesNested); - } - ); */ OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{rest.getResult(0)}); } ); @@ -876,170 +803,8 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { } ); - /* auto whileLoop = OpBuilderNested.create( - locNested, TypeRange{OpBuilderNested.getIndexType(), OpBuilderNested.getIndexType()}, ValueRange{lhsColIdxLowerIncl, rhsColIdxLowerIncl}, - [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested, ValueRange args) - { - - auto cond1 = OpBuilderTwiceNested.create( - locTwiceNested, arith::CmpIPredicate::ult, args[0], lhsColIdxUpperExcl); - // OpBuilderTwiceNested.create( - // locTwiceNested, cond1, - // [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested){ - // // 
restValuesPtr = args[1]; - // restValuesPtr = OpBuilderThreetimesNested.create( - // locThreetimesNested, one, args[1]); - // }); - auto cond2 = OpBuilderTwiceNested.create( - locTwiceNested, arith::CmpIPredicate::ult, args[1], rhsColIdxUpperExcl); - // OpBuilderTwiceNested.create( - // locTwiceNested, cond2, - // [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested){ - // // restValuesPtr = args[0]; - // restValuesPtr = OpBuilderThreetimesNested.create( - // locThreetimesNested, one, args[0]); - // }); - auto cond = OpBuilderTwiceNested.create(locNested, cond1, cond2); - OpBuilderTwiceNested.create(locTwiceNested, cond, args); - }, - [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested, ValueRange args) - { - auto lhsCol = OpBuilderTwiceNested.create( - locTwiceNested, lhsColIdxsMemRef, ValueRange{args[0]}); - auto rhsCol = OpBuilderTwiceNested.create( - locTwiceNested, rhsColIdxsMemRef, ValueRange{args[1]}); - - auto case1 = OpBuilderTwiceNested.create( - locTwiceNested, arith::CmpIPredicate::ult, lhsCol, rhsCol); - auto case2 = OpBuilderTwiceNested.create( - locTwiceNested, arith::CmpIPredicate::ult, rhsCol, lhsCol); - auto case3 = OpBuilderTwiceNested.create( - locTwiceNested, arith::CmpIPredicate::eq, lhsCol, rhsCol); - - auto newArg = OpBuilderTwiceNested.create( - locTwiceNested, case1, - [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) - { - if (llvm::isa(op)) - { - auto resValue = OpBuilderThreetimesNested.create( - locThreetimesNested, lhsValuesMemRef, ValueRange{args[0]}); - auto resIndex = OpBuilderThreetimesNested.create(loc, resValuesPtr); - OpBuilderThreetimesNested.create( - locThreetimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); - OpBuilderThreetimesNested.create( - locThreetimesNested, lhsCol, resColIdxsMemRef, ValueRange{resIndex}); - resValuesPtr++; - } - auto newArg0 = OpBuilderThreetimesNested.create(locThreetimesNested, args[0], one); - auto newArg1 = args[1]; - 
OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{newArg0, newArg1}); - }, - [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) - { - auto then = OpBuilderThreetimesNested.create( - locThreetimesNested, case2, - [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested) - { - if (llvm::isa(op)) - { - auto resValue = OpBuilderFourtimesNested.create( - locFourtimesNested, rhsValuesMemRef, ValueRange{args[1]}); - auto resIndex = OpBuilderFourtimesNested.create(locFourtimesNested, resValuesPtr); - OpBuilderFourtimesNested.create( - locFourtimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); - OpBuilderFourtimesNested.create( - locFourtimesNested, rhsCol, resColIdxsMemRef, ValueRange{resIndex}); - resValuesPtr++; - } - auto newArg0 = args[0]; - auto newArg1 = OpBuilderFourtimesNested.create(locFourtimesNested, args[1], one); - OpBuilderFourtimesNested.create(locThreetimesNested, ValueRange{newArg0, newArg1}); - }, - [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested) - { - auto lhsValue = OpBuilderFourtimesNested.create( - locFourtimesNested, lhsValuesMemRef, ValueRange{args[0]}); - auto rhsValue = OpBuilderFourtimesNested.create( - locFourtimesNested, rhsValuesMemRef, ValueRange{args[1]}); - auto resValue = binaryFunc( - OpBuilderFourtimesNested, locFourtimesNested, this->typeConverter, lhsValue, rhsValue); - auto resIndex = rewriter.create(locFourtimesNested, resValuesPtr); - OpBuilderFourtimesNested.create( - locFourtimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); - OpBuilderFourtimesNested.create( - locFourtimesNested, rhsCol, resColIdxsMemRef, ValueRange{resIndex}); - resValuesPtr++; - auto newArg0 = OpBuilderFourtimesNested.create(locFourtimesNested, args[0], one); - auto newArg1 = OpBuilderFourtimesNested.create(locFourtimesNested, args[1], one); - OpBuilderFourtimesNested.create(locFourtimesNested, ValueRange{newArg0, newArg1}); - } - ); - 
OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{then.getResult(0), then.getResult(1)}); - } - ); - auto newArg0 = newArg.getResult(0); - auto newArg1 = newArg.getResult(1); - - OpBuilderTwiceNested.create(locTwiceNested, ValueRange{newArg0, newArg1}); - } - ); - OpBuilderNested.create( - // locNested, TypeRange{}, lhsEndFirst, true, - locNested, lhsEndFirst, - [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested) - { - OpBuilderTwiceNested.create( - //locTwiceNested, restValuesPtr, rhsColIdxUpperExcl, one, ValueRange{}, - locTwiceNested, whileLoop.getResult(1), rhsColIdxUpperExcl, one, ValueRange{}, - [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested, Value loopIdx, ValueRange loopInvariants) - { - auto resValue = OpBuilderThreetimesNested.create( - locThreetimesNested, rhsValuesMemRef, ValueRange{loopIdx}); - auto resCol = OpBuilderThreetimesNested.create( - locThreetimesNested, rhsColIdxsMemRef, ValueRange{loopIdx}); - auto resIndex = OpBuilderThreetimesNested.create(loc, resValuesPtr); - OpBuilderThreetimesNested.create( - locThreetimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); - OpBuilderThreetimesNested.create( - locThreetimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); - resValuesPtr++; - OpBuilderThreetimesNested.create(locThreetimesNested); - } - ); - OpBuilderTwiceNested.create(locTwiceNested); - } - ); - - OpBuilderNested.create( - locNested, rhsEndFirst, - [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested) - { - OpBuilderTwiceNested.create( - // locTwiceNested, restValuesPtr, lhsColIdxUpperExcl, one, ValueRange{}, - locTwiceNested, whileLoop.getResult(0), lhsColIdxUpperExcl, one, ValueRange{}, - [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested, Value loopIdx, ValueRange loopInvariants) - { - auto resValue = OpBuilderThreetimesNested.create( - locThreetimesNested, lhsValuesMemRef, ValueRange{loopIdx}); - auto resCol = OpBuilderThreetimesNested.create( 
- locThreetimesNested, lhsColIdxsMemRef, ValueRange{loopIdx}); - auto resIndex = OpBuilderThreetimesNested.create(loc, resValuesPtr); - OpBuilderThreetimesNested.create( - locThreetimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); - OpBuilderThreetimesNested.create( - locThreetimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); - resValuesPtr++; - OpBuilderThreetimesNested.create(locThreetimesNested); - } - ); - OpBuilderTwiceNested.create(locTwiceNested); - } - );*/ - OpBuilderNested.create( locNested, - // OpBuilderNested.create(loc, resValuesPtr), operation.getResult(0), resRowOffsetsMemRef, ValueRange{nextRowPtr}); From c18516c72b91b7b8e843d27cbc52d6cf059b69f8 Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Mon, 10 Feb 2025 18:58:47 +0100 Subject: [PATCH 22/33] fix bug --- src/compiler/lowering/EwOpsLowering.cpp | 92 ++++++++++++++----------- 1 file changed, 53 insertions(+), 39 deletions(-) diff --git a/src/compiler/lowering/EwOpsLowering.cpp b/src/compiler/lowering/EwOpsLowering.cpp index ae90a79a6..3d2cf29ce 100644 --- a/src/compiler/lowering/EwOpsLowering.cpp +++ b/src/compiler/lowering/EwOpsLowering.cpp @@ -586,26 +586,33 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { }, [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) { - auto forLoop = OpBuilderThreetimesNested.create( - locThreetimesNested, rhsColIdxLowerIncl, rhsColIdxUpperExcl, one, ValueRange{resValuesPtr}, - [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested, Value loopIdx, ValueRange loopIterArgs) - { - auto resValue = OpBuilderFourtimesNested.create( - locFourtimesNested, lhsValuesMemRef, ValueRange{loopIdx}); - auto resCol = OpBuilderFourtimesNested.create( - locFourtimesNested, lhsColIdxsMemRef, ValueRange{loopIdx}); - auto resIndex = loopIterArgs[0]; - OpBuilderFourtimesNested.create( - locFourtimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); - OpBuilderFourtimesNested.create( - 
locFourtimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); - auto newResValuesPtr = OpBuilderFourtimesNested.create( - locFourtimesNested, resIndex, one); - OpBuilderFourtimesNested.create( - locFourtimesNested, ValueRange{newResValuesPtr}); - } - ); - OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{forLoop.getResult(0)}); + if (llvm::isa(op)){ + auto forLoop = OpBuilderThreetimesNested.create( + locThreetimesNested, rhsColIdxLowerIncl, rhsColIdxUpperExcl, one, ValueRange{resValuesPtr}, + [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested, Value loopIdx, ValueRange loopIterArgs) + { + auto resValue = OpBuilderFourtimesNested.create( + locFourtimesNested, lhsValuesMemRef, ValueRange{loopIdx}); + auto resCol = OpBuilderFourtimesNested.create( + locFourtimesNested, lhsColIdxsMemRef, ValueRange{loopIdx}); + auto resIndex = loopIterArgs[0]; + OpBuilderFourtimesNested.create( + locFourtimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderFourtimesNested.create( + locFourtimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); + auto newResValuesPtr = OpBuilderFourtimesNested.create( + locFourtimesNested, resIndex, one); + OpBuilderFourtimesNested.create( + locFourtimesNested, ValueRange{newResValuesPtr}); + } + ); + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{forLoop.getResult(0)}); + } + else + { + auto newResValuesPtr = resValuesPtr; + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{newResValuesPtr}); + } } ); OpBuilderTwiceNested.create(locTwiceNested, ValueRange{thenRegion.getResult(0)}); @@ -616,26 +623,33 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { locNested, rhsAllZero, [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) { - auto forLoop = OpBuilderThreetimesNested.create( + if (llvm::isa(op)){ + auto forLoop = OpBuilderThreetimesNested.create( locThreetimesNested, lhsColIdxLowerIncl, lhsColIdxUpperExcl, one, 
ValueRange{resValuesPtr}, - [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested, Value loopIdx, ValueRange loopIterArgs) - { - auto resValue = OpBuilderFourtimesNested.create( - locFourtimesNested, lhsValuesMemRef, ValueRange{loopIdx}); - auto resCol = OpBuilderFourtimesNested.create( - locFourtimesNested, lhsColIdxsMemRef, ValueRange{loopIdx}); - auto resIndex = loopIterArgs[0]; - OpBuilderFourtimesNested.create( - locFourtimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); - OpBuilderFourtimesNested.create( - locFourtimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); - auto newResValuesPtr = OpBuilderFourtimesNested.create( - locFourtimesNested, resIndex, one); - OpBuilderFourtimesNested.create( - locFourtimesNested, ValueRange{newResValuesPtr}); - } - ); - OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{forLoop.getResult(0)}); + [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested, Value loopIdx, ValueRange loopIterArgs) + { + auto resValue = OpBuilderFourtimesNested.create( + locFourtimesNested, lhsValuesMemRef, ValueRange{loopIdx}); + auto resCol = OpBuilderFourtimesNested.create( + locFourtimesNested, lhsColIdxsMemRef, ValueRange{loopIdx}); + auto resIndex = loopIterArgs[0]; + OpBuilderFourtimesNested.create( + locFourtimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderFourtimesNested.create( + locFourtimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); + auto newResValuesPtr = OpBuilderFourtimesNested.create( + locFourtimesNested, resIndex, one); + OpBuilderFourtimesNested.create( + locFourtimesNested, ValueRange{newResValuesPtr}); + } + ); + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{forLoop.getResult(0)}); + } + else + { + auto newResValuesPtr = resValuesPtr; + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{newResValuesPtr}); + } }, [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) { From 
a8c3268cdeacdf4488dc518e01e4c3d183ad198c Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Mon, 10 Feb 2025 20:11:21 +0100 Subject: [PATCH 23/33] fix bug --- src/compiler/lowering/EwOpsLowering.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/compiler/lowering/EwOpsLowering.cpp b/src/compiler/lowering/EwOpsLowering.cpp index 3d2cf29ce..2f51856fd 100644 --- a/src/compiler/lowering/EwOpsLowering.cpp +++ b/src/compiler/lowering/EwOpsLowering.cpp @@ -565,13 +565,17 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { auto rhsColUpper = OpBuilderNested.create( locNested, rhsColIdxsMemRef, ValueRange{rhsColIdxUpperIncl}); + auto lhsEndFirst = OpBuilderNested.create( - locNested, arith::CmpIPredicate::ult, lhsColIdxLowerIncl, lhsColIdxUpperExcl); + // locNested, arith::CmpIPredicate::ult, lhsColIdxLowerIncl, lhsColIdxUpperExcl); + locNested, arith::CmpIPredicate::ult, lhsColUpper, rhsColUpper); auto lhsAllZero = OpBuilderNested.create( - locNested, arith::CmpIPredicate::eq, lhsColUpper, rhsColUpper); + // locNested, arith::CmpIPredicate::eq, lhsColUpper, rhsColUpper); + locNested, arith::CmpIPredicate::eq, lhsColIdxLowerIncl, lhsColIdxUpperExcl); auto rhsAllZero = OpBuilderNested.create( - locNested, arith::CmpIPredicate::eq, lhsColUpper, rhsColUpper); + // locNested, arith::CmpIPredicate::eq, lhsColUpper, rhsColUpper); + locNested, arith::CmpIPredicate::eq, rhsColIdxLowerIncl, rhsColIdxUpperExcl); auto operation = OpBuilderNested.create( locNested, lhsAllZero, @@ -592,9 +596,9 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested, Value loopIdx, ValueRange loopIterArgs) { auto resValue = OpBuilderFourtimesNested.create( - locFourtimesNested, lhsValuesMemRef, ValueRange{loopIdx}); + locFourtimesNested, rhsValuesMemRef, ValueRange{loopIdx}); auto resCol = OpBuilderFourtimesNested.create( - locFourtimesNested, 
lhsColIdxsMemRef, ValueRange{loopIdx}); + locFourtimesNested, rhsColIdxsMemRef, ValueRange{loopIdx}); auto resIndex = loopIterArgs[0]; OpBuilderFourtimesNested.create( locFourtimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); From 4a8fd6a385aeb37de6cc7b5f21d0d3dcf043664e Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Tue, 11 Feb 2025 23:22:19 +0100 Subject: [PATCH 24/33] add Matmul Lowering for (CSR, Dense) --- src/compiler/lowering/MatMulOpLowering.cpp | 112 +++++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/src/compiler/lowering/MatMulOpLowering.cpp b/src/compiler/lowering/MatMulOpLowering.cpp index c0aef46db..03f0ca9a0 100644 --- a/src/compiler/lowering/MatMulOpLowering.cpp +++ b/src/compiler/lowering/MatMulOpLowering.cpp @@ -301,12 +301,124 @@ class MatMulLowering : public OpConversionPattern { return loops; } + template + Value binaryWithConversionFunc(OpBuilder &rewriter, Location loc, TypeConverter *typeConverter, Value lhs, Value rhs) const { + Type resType = lhs.getType(); + Value res{}; + if (llvm::isa(resType)) { + lhs = convertToSignlessInt(rewriter, loc, typeConverter, lhs, resType); + rhs = convertToSignlessInt(rewriter, loc, typeConverter, rhs, resType); + res = rewriter.create(loc, lhs, rhs).getResult(); + res = typeConverter->materializeTargetConversion(rewriter, loc, resType, res); + } else { + res = rewriter.create(loc, lhs, rhs).getResult(); + } + return res; + } + + LogicalResult matchAndRewriteSparseDenseMat(daphne::MatMulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + + mlir::daphne::MatrixType lhsMatrixType = lhs.getType().dyn_cast(); + mlir::daphne::MatrixType rhsMatrixType = rhs.getType().dyn_cast(); + + auto lhsRows = lhsMatrixType.getNumRows(); + auto lhsCols = lhsMatrixType.getNumCols(); + + auto rhsRows = rhsMatrixType.getNumRows(); + auto rhsCols = 
rhsMatrixType.getNumCols(); + + auto matrixElementType = lhsMatrixType.getElementType(); + + auto lhsValuesMemRefType = + MemRefType::get({ShapedType::kDynamic}, matrixElementType); + auto lhsColIdxsMemRefType = + MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); + auto lhsRowOffsetsMemRefType = + MemRefType::get({lhsRows + 1}, rewriter.getIndexType()); + auto rhsMemRefType = mlir::MemRefType::get({rhsRows, rhsCols}, matrixElementType); + auto resMemRefType = mlir::MemRefType::get({lhsRows, rhsCols}, matrixElementType); + + auto lhsValuesMemRef = + rewriter.create(loc, lhsValuesMemRefType, lhs); + auto lhsColIdxsMemRef = + rewriter.create(loc, lhsColIdxsMemRefType, lhs); + auto lhsRowOffsetsMemRef = + rewriter.create(loc, lhsRowOffsetsMemRefType, lhs); + auto rhsMemRef = + rewriter.create(loc, rhsMemRefType, rhs); + auto resMemRef = rewriter.create(loc, resMemRefType); + + auto zeroElement = rewriter.create(loc, rewriter.getZeroAttr(matrixElementType)); + rewriter.create(loc, ValueRange{zeroElement}, ValueRange{resMemRef}); + + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + auto numLhsRowsValue = rewriter.create(loc, lhsRows); + auto numRhsColsValue = rewriter.create(loc, rhsCols); + + auto lhsRowLoop = rewriter.create( + loc, zero, numLhsRowsValue, one, ValueRange{}, + [&](OpBuilder &OpBuilderNested, Location locNested, Value loopIdx, ValueRange loopIterArgs) + { + auto rowPtr = loopIdx; + auto nextRowPtr = OpBuilderNested.create(locNested, rowPtr, one); + auto rhsColLoop = OpBuilderNested.create( + locNested, zero, numRhsColsValue, one, ValueRange{}, + [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested, Value loopIdx, ValueRange loopIterArgs) + { + auto rhsCol = loopIdx; + auto lhsColIdxsLowerIncl = OpBuilderTwiceNested.create( + locTwiceNested, lhsRowOffsetsMemRef, ValueRange{rowPtr}); + auto lhsColIdxsUpperExcl = OpBuilderTwiceNested.create( + locTwiceNested, lhsRowOffsetsMemRef, ValueRange{nextRowPtr}); + 
+ auto resValueLoop = OpBuilderTwiceNested.create( + locTwiceNested, lhsColIdxsLowerIncl, lhsColIdxsUpperExcl, one, ValueRange{zeroElement}, + [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested, Value loopIdx, ValueRange loopIterArgs) + { + auto lhsValue = OpBuilderThreetimesNested.create( + locThreetimesNested, lhsValuesMemRef, ValueRange{loopIdx}); + auto rhsRow = OpBuilderThreetimesNested.create( + locThreetimesNested, lhsColIdxsMemRef, ValueRange{loopIdx}); + auto rhsValue = OpBuilderThreetimesNested.create( + locThreetimesNested, rhsMemRef, ValueRange{rhsRow, rhsCol}); + + auto resValue = binaryWithConversionFunc( + OpBuilderThreetimesNested, locThreetimesNested, this->typeConverter, lhsValue, rhsValue); + + auto accResValue = binaryWithConversionFunc( + OpBuilderThreetimesNested, locThreetimesNested, this->typeConverter, loopIterArgs[0], resValue); + + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{accResValue}); + } + ); + OpBuilderTwiceNested.create( + locTwiceNested, resValueLoop.getResult(0), resMemRef, ValueRange{rowPtr, rhsCol}); + OpBuilderTwiceNested.create(locTwiceNested, ValueRange{}); + } + ); + OpBuilderNested.create(locNested, ValueRange{}); + } + ); + // Value resDenseMatrix = convertMemRefToDenseMatrix(loc, rewriter, resMemRef, op.getType()); + Value resDenseMatrix = convertMemRefToDenseMatrix(loc, rewriter, resMemRef, rhs.getType()); + rewriter.replaceOp(op, resDenseMatrix); + return mlir::success(); + } + LogicalResult matchAndRewrite(daphne::MatMulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); mlir::daphne::MatrixType lhsMatrixType = adaptor.getLhs().getType().dyn_cast(); mlir::daphne::MatrixType rhsMatrixType = adaptor.getRhs().getType().dyn_cast(); + if (lhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Sparse && + rhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Dense) + return 
matchAndRewriteSparseDenseMat(op, adaptor, rewriter); + auto lhsRows = lhsMatrixType.getNumRows(); auto lhsCols = lhsMatrixType.getNumCols(); From a3bbd9ebb0b0e2e3f581e500c37f5e0434a45a61 Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Mon, 17 Feb 2025 16:07:50 +0100 Subject: [PATCH 25/33] add Matmul for (CSR, CSR), correct EwOpsLowering --- src/compiler/lowering/EwOpsLowering.cpp | 6 +- src/compiler/lowering/MatMulOpLowering.cpp | 215 +++++++++++++++++++++ 2 files changed, 218 insertions(+), 3 deletions(-) diff --git a/src/compiler/lowering/EwOpsLowering.cpp b/src/compiler/lowering/EwOpsLowering.cpp index 2f51856fd..207f54bb3 100644 --- a/src/compiler/lowering/EwOpsLowering.cpp +++ b/src/compiler/lowering/EwOpsLowering.cpp @@ -610,7 +610,7 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { locFourtimesNested, ValueRange{newResValuesPtr}); } ); - OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{forLoop.getResult(0)}); + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{forLoop.getResult(0)}); } else { @@ -647,7 +647,7 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { locFourtimesNested, ValueRange{newResValuesPtr}); } ); - OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{forLoop.getResult(0)}); + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{forLoop.getResult(0)}); } else { @@ -670,7 +670,7 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { locFourtimesNested, arith::CmpIPredicate::ult, args[0], lhsColIdxUpperExcl); auto cond2 = OpBuilderFourtimesNested.create( locFourtimesNested, arith::CmpIPredicate::ult, args[1], rhsColIdxUpperExcl); - auto cond = OpBuilderFourtimesNested.create(locNested, cond1, cond2); + auto cond = OpBuilderFourtimesNested.create(locFourtimesNested, cond1, cond2); OpBuilderFourtimesNested.create(locFourtimesNested, cond, args); }, [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested, 
ValueRange args) diff --git a/src/compiler/lowering/MatMulOpLowering.cpp b/src/compiler/lowering/MatMulOpLowering.cpp index 03f0ca9a0..cae582aa2 100644 --- a/src/compiler/lowering/MatMulOpLowering.cpp +++ b/src/compiler/lowering/MatMulOpLowering.cpp @@ -315,6 +315,217 @@ class MatMulLowering : public OpConversionPattern { } return res; } + + template + Value cmpWithConversionFunc(OpBuilder &rewriter, Location loc, TypeConverter *typeConverter, Value lhs, Value rhs) const { + Type resType = lhs.getType(); + Value res{}; + if (llvm::isa(resType)) { + lhs = convertToSignlessInt(rewriter, loc, typeConverter, lhs, resType); + rhs = convertToSignlessInt(rewriter, loc, typeConverter, rhs, resType); + res = rewriter.create(loc, cmpIPredicate, lhs, rhs).getResult(); + res = typeConverter->materializeTargetConversion(rewriter, loc, resType, res); + } else { + res = rewriter.create(loc, cmpFPredicate, lhs, rhs).getResult(); + } + return res; + } + + Value csrIndex(OpBuilder &rewriter, Location loc, + Value valuesMemRef, Value colIdxsMemRef, Value rowOffsetsMemRef, Value row, Value col, Type type) const + { + auto zeroElem = rewriter.create(loc, rewriter.getZeroAttr(type)); + auto one = rewriter.create(loc, 1); + auto rowPtr = row; + auto nextRowPtr = rewriter.create(loc, row, one); + auto colIdxLowerIncl = rewriter.create( + loc, rowOffsetsMemRef, ValueRange{rowPtr}); + auto colIdxUpperExcl = rewriter.create( + loc, rowOffsetsMemRef, ValueRange{nextRowPtr}); + auto search = rewriter.create( + loc, colIdxLowerIncl, colIdxUpperExcl, one, ValueRange{zeroElem}, + [&](OpBuilder &OpBuilderNested, Location locNested, Value loopIdx, ValueRange loopIterArgs) + { + auto getCol = OpBuilderNested.create(locNested, colIdxsMemRef, ValueRange{loopIdx}); + auto getValue = OpBuilderNested.create(locNested, valuesMemRef, ValueRange{loopIdx}); + auto cond = OpBuilderNested.create(locNested, arith::CmpIPredicate::eq, getCol, col); + auto res = OpBuilderNested.create( + locNested, cond, + 
[&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested) + { + OpBuilderTwiceNested.create(locTwiceNested, ValueRange{getValue}); + }, + [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested) + { + OpBuilderTwiceNested.create(locTwiceNested, ValueRange{zeroElem}); + } + ); + OpBuilderNested.create(locNested, res.getResult(0)); + } + ); + return search.getResult(0); + } + + LogicalResult matchAndRewriteSparseSparseMat(daphne::MatMulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + + mlir::daphne::MatrixType lhsMatrixType = lhs.getType().dyn_cast(); + mlir::daphne::MatrixType rhsMatrixType = rhs.getType().dyn_cast(); + + auto lhsRows = lhsMatrixType.getNumRows(); + auto lhsCols = lhsMatrixType.getNumCols(); + + auto rhsRows = rhsMatrixType.getNumRows(); + auto rhsCols = rhsMatrixType.getNumCols(); + + auto matrixElementType = lhsMatrixType.getElementType(); + + auto lhsValuesMemRefType = + MemRefType::get({ShapedType::kDynamic}, matrixElementType); + auto lhsColIdxsMemRefType = + MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); + auto lhsRowOffsetsMemRefType = + MemRefType::get({lhsRows + 1}, rewriter.getIndexType()); + auto rhsValuesMemRefType = + MemRefType::get({ShapedType::kDynamic}, matrixElementType); + auto rhsColIdxsMemRefType = + MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); + auto rhsRowOffsetsMemRefType = + MemRefType::get({rhsRows + 1}, rewriter.getIndexType()); + auto resValuesMemRefType = + MemRefType::get({ShapedType::kDynamic}, matrixElementType); + auto resColIdxsMemRefType = + MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); + auto resRowOffsetsMemRefType = + MemRefType::get({lhsRows + 1}, rewriter.getIndexType()); + + auto lhsValuesMemRef = + rewriter.create(loc, lhsValuesMemRefType, lhs); + auto lhsColIdxsMemRef = + rewriter.create(loc, lhsColIdxsMemRefType, lhs); + 
auto lhsRowOffsetsMemRef = + rewriter.create(loc, lhsRowOffsetsMemRefType, lhs); + auto rhsValuesMemRef = + rewriter.create(loc, lhsValuesMemRefType, rhs); + auto rhsColIdxsMemRef = + rewriter.create(loc, lhsColIdxsMemRefType, rhs); + auto rhsRowOffsetsMemRef = + rewriter.create(loc, lhsRowOffsetsMemRefType, rhs); + + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + auto zeroElement = rewriter.create(loc, rewriter.getZeroAttr(matrixElementType)); + auto numLhsRowsValue = rewriter.create(loc, lhsRows); + auto numRhsRowsValue = rewriter.create(loc, rhsRows); + auto numLhsColsValue = rewriter.create(loc, lhsCols); + auto numRhsColsValue = rewriter.create(loc, rhsCols); + + auto resValuesMemRef = rewriter.create(loc, resValuesMemRefType, ValueRange{one}); + auto resColIdxsMemRef = rewriter.create(loc, resColIdxsMemRefType, ValueRange{one}); + auto resRowOffsetsMemRef = rewriter.create(loc, resRowOffsetsMemRefType); + rewriter.create(loc, zero, resRowOffsetsMemRef, ValueRange{zero}); + + auto lhsRowLoop = rewriter.create( + loc, zero, numLhsRowsValue, one, ValueRange{zero}, + [&](OpBuilder &OpBuilderNested, Location locNested, Value loopIdx, ValueRange loopIterArgs) + { + auto rowPtr = loopIdx; + auto lhsRow = loopIdx; + auto nextRowPtr = OpBuilderNested.create(locNested, loopIdx, one); + auto resValuesPtr = loopIterArgs[0]; + + auto lhsColIdxLowerIncl = OpBuilderNested.create( + locNested, lhsRowOffsetsMemRef, ValueRange{rowPtr}); + auto lhsColIdxUpperExcl = OpBuilderNested.create( + locNested, lhsRowOffsetsMemRef, ValueRange{nextRowPtr}); + + auto rhsColLoop = OpBuilderNested.create( + locNested, zero, numRhsColsValue, one, ValueRange{resValuesPtr}, + [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested, Value loopIdx, ValueRange loopIterArgs) + { + auto resValuesPtr = loopIterArgs[0]; + auto rhsCol = loopIdx; + auto lhsColLoop = OpBuilderTwiceNested.create( + locTwiceNested, lhsColIdxLowerIncl, lhsColIdxUpperExcl, one, 
ValueRange{zeroElement}, + [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested, Value loopIdx, ValueRange loopIterArgs) + { + auto acc = loopIterArgs[0]; + + auto lhsElemRow = lhsRow; + auto lhsElemCol = OpBuilderThreetimesNested.create( + locThreetimesNested, lhsColIdxsMemRef, ValueRange{loopIdx}); + auto lhsElemValue = OpBuilderThreetimesNested.create( + locThreetimesNested, lhsValuesMemRef, ValueRange{loopIdx}); + + auto rhsElemRow = lhsElemCol; + auto rhsElemCol = rhsCol; + auto rhsElemValue = csrIndex( + OpBuilderThreetimesNested, locThreetimesNested, + rhsValuesMemRef, rhsColIdxsMemRef, rhsRowOffsetsMemRef, + rhsElemRow, rhsElemCol, matrixElementType); + + auto product = binaryWithConversionFunc( + OpBuilderThreetimesNested, locThreetimesNested, this->typeConverter, lhsElemValue, rhsElemValue); + auto newAcc = binaryWithConversionFunc( + OpBuilderThreetimesNested, locThreetimesNested, this->typeConverter, product, acc); + + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{newAcc}); + + } + ); + auto cond = OpBuilderTwiceNested.create( + locTwiceNested, arith::CmpFPredicate::OEQ, lhsColLoop.getResult(0), zeroElement); + auto newPtr = OpBuilderTwiceNested.create( + locTwiceNested, cond, + [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) + { + auto newResValuesPtr = resValuesPtr; + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{newResValuesPtr}); + }, + [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) + { + auto newResValuesPtr = OpBuilderThreetimesNested.create( + locThreetimesNested, resValuesPtr, one); + OpBuilderThreetimesNested.create( + locThreetimesNested, lhsColLoop.getResult(0), resValuesMemRef, ValueRange{resValuesPtr}); + OpBuilderThreetimesNested.create( + locThreetimesNested, rhsCol, resColIdxsMemRef, ValueRange{resValuesPtr}); + OpBuilderThreetimesNested.create( + locThreetimesNested, ValueRange{newResValuesPtr}); + } + ); + 
OpBuilderTwiceNested.create(locTwiceNested, ValueRange{newPtr.getResult(0)}); + } + ); + auto newResValuesPtr = rhsColLoop.getResult(0); + + OpBuilderNested.create( + locNested, + newResValuesPtr, + resRowOffsetsMemRef, + ValueRange{nextRowPtr}); + OpBuilderNested.create(locNested, ValueRange{newResValuesPtr}); + } + ); + + Value maxNumRowsValue = rewriter.create(loc, lhsRows); + Value numColsValue = rewriter.create(loc, rhsCols); + Value maxNumNonZerosValue = rewriter.create(loc, lhsRows * rhsCols); + + Value resCSRMatrix = convertMemRefToCSRMatrix(loc, rewriter, + resValuesMemRef, resColIdxsMemRef, resRowOffsetsMemRef, + maxNumRowsValue, numColsValue, maxNumNonZerosValue, op.getType()); + + if (!resCSRMatrix) { + llvm::errs() << "Error: resCSRMatrix is null!\n"; + } + + rewriter.replaceOp(op, resCSRMatrix); + + return mlir::success(); + } LogicalResult matchAndRewriteSparseDenseMat(daphne::MatMulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op->getLoc(); @@ -418,6 +629,10 @@ class MatMulLowering : public OpConversionPattern { if (lhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Sparse && rhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Dense) return matchAndRewriteSparseDenseMat(op, adaptor, rewriter); + + if (lhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Sparse && + rhsMatrixType.getRepresentation() == daphne::MatrixRepresentation::Sparse) + return matchAndRewriteSparseSparseMat(op, adaptor, rewriter); auto lhsRows = lhsMatrixType.getNumRows(); auto lhsCols = lhsMatrixType.getNumCols(); From d4a7cdb0a3d5dcc002d086e2a6fb1e41964e6bb3 Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Sun, 23 Feb 2025 16:04:21 +0100 Subject: [PATCH 26/33] add a script level test case for gemm codegen --- scripts/examples/gemm-codegen.daph | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 scripts/examples/gemm-codegen.daph diff --git 
a/scripts/examples/gemm-codegen.daph b/scripts/examples/gemm-codegen.daph new file mode 100644 index 000000000..3bf641822 --- /dev/null +++ b/scripts/examples/gemm-codegen.daph @@ -0,0 +1,15 @@ +# bench.daph +size=$size; +sparsity=$sparsity; + +alpha = 2; +beta = 3; +A = rand(size, size, 1.0, 1.0, sparsity, -1); +B = rand(size, size, 1.0, 1.0, sparsity, -1); +C = rand(size, size, 1.0, 1.0, sparsity, -1); +start = now(); +D = beta * C + alpha * A @ B ; +end = now(); +print((end-start) / 1000000000.0); +x = aggMax(D); +print(x); \ No newline at end of file From e2476a2000d9790bc5f9b755f7601c85eddab067 Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Mon, 24 Feb 2025 12:28:31 +0100 Subject: [PATCH 27/33] clean comment-outs --- src/compiler/lowering/EwOpsLowering.cpp | 38 ++-------------------- src/compiler/lowering/MatMulOpLowering.cpp | 2 +- src/runtime/local/kernels/EwUnaryMat.h | 12 +++++-- 3 files changed, 13 insertions(+), 39 deletions(-) diff --git a/src/compiler/lowering/EwOpsLowering.cpp b/src/compiler/lowering/EwOpsLowering.cpp index 207f54bb3..23fb9f728 100644 --- a/src/compiler/lowering/EwOpsLowering.cpp +++ b/src/compiler/lowering/EwOpsLowering.cpp @@ -88,14 +88,12 @@ template struct UnaryOpLowering : publi ssize_t numCols = sparseMatType.getNumCols(); if (numRows < 0 || numCols < 0) { - std::cout<<"here 5"<( @@ -116,8 +114,6 @@ template struct UnaryOpLowering : publi OpBuilderNested.create(locNested, resValue); }); - - //rewriter.replaceOp(op, resMemref); MemRefType sparseColIdxsMemRefType = MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); MemRefType sparseRowOffsetsMemRefType = MemRefType::get({numRows + 1}, rewriter.getIndexType()); @@ -129,12 +125,11 @@ template struct UnaryOpLowering : publi Value maxNumRowsValue = rewriter.create(loc, numRows); Value numColsValue = rewriter.create(loc, numCols); Value maxNumNonZerosValue = rewriter.create(loc, numCols * numRows); - //auto resCSRMatrix = convertMemRefToCSRMatrix(loc, 
rewriter, resMemref, op.getType()); auto resCSRMatrix = convertMemRefToCSRMatrix(loc, rewriter, resMemref, argColIdxsMemref, argRowOffsetsMemref, - maxNumRowsValue, numColsValue, maxNumNonZerosValue, op.getType()); - //maxNumRowsValue, numColsValue, maxNumNonZerosValue, adaptor.getArg().getType()); + //maxNumRowsValue, numColsValue, maxNumNonZerosValue, op.getType()); + maxNumRowsValue, numColsValue, maxNumNonZerosValue, adaptor.getArg().getType()); rewriter.replaceOp(op, resCSRMatrix); @@ -160,7 +155,6 @@ template struct UnaryOpLowering : publi ssize_t numCols = matrixType.getNumCols(); if (numRows < 0 || numCols < 0) { - std::cout<<"here 6"< { if (lhsRows != 1 && rhsRows == 1) { // rhs is a row vector, broadcast along columns if (lhsCols != rhsCols) { - std::cout<<"here 7"< { } else if (lhsCols != 1 && rhsCols == 1) { // rhs is a column vector, broadcast along rows if (lhsRows != rhsRows) { - std::cout<<"here 8"< { } else { // rhs is not broadcasted, return identity mapping if (lhsRows != rhsRows || lhsCols != rhsCols) { - std::cout<<"here 9"< { auto resSparseMemRef = rewriter.create(loc, sparseLhsValuesMemRefType, ValueRange{one}); rewriter.create( - // loc, rowPtr, nextRowPtr, rewriter.create(loc, 1), loc, zero, numSparseLhsRowsValue, one, ValueRange{}, - // [&](OpBuilder &OpBuilderNested, Location locNested, Value loopIdx) [&](OpBuilder &OpBuilderNested, Location locNested, Value loopIdx, ValueRange loopInvariants) { auto rowPtr = loopIdx; @@ -402,12 +391,9 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { locNested, sparseLhsRowOffsetsMemRef, ValueRange{nextRowPtr}); OpBuilderNested.create( - // locNested, colIdxLowerIncl, colIdxUpperExcl, one, ValueRange{rowPtr}, locNested, colIdxLowerIncl, colIdxUpperExcl, one, ValueRange{}, - // [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested, Value loopIdxNested, ValueRange loopInvariantsNested) [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested, Value loopIdxNested, ValueRange 
loopInvariants) { - // auto rowIdx = loopInvariantsNested[0]; auto rowIdx = rowPtr; auto colIdx = OpBuilderTwiceNested.create( locTwiceNested, sparseLhsColIdxsMemRef, ValueRange{loopIdxNested}); @@ -420,33 +406,24 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { Value resValue = binaryFunc( OpBuilderTwiceNested, locTwiceNested, this->typeConverter, sparseLhsValue, denseRhsValue); - - //Value store; if (llvm::isa(op)) { - // auto store = OpBuilderTwiceNested.create( OpBuilderTwiceNested.create( locTwiceNested, resValue, resDenseMemRef, ValueRange{rowIdx, colIdx}); } else if (llvm::isa(op)) { - // auto store = OpBuilderTwiceNested.create( OpBuilderTwiceNested.create( locTwiceNested, resValue, resSparseMemRef, ValueRange{loopIdxNested}); } else { - std::cout<<"here 10"<(locTwiceNested, resValue); - // OpBuilderTwiceNested.create(locTwiceNested); } ); - - // OpBuilderNested.create(locNested, resValue); - // auto resValue = colLoop.getResult(0); OpBuilderNested.create(locNested); } ); @@ -454,14 +431,12 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { if (llvm::isa(op)) { Value resDenseMatrix = convertMemRefToDenseMatrix(loc, rewriter, resDenseMemRef, op.getType()); - std::cout<<"here 1"<(op)) { - llvm::errs()<(loc, sparseLhsRows); Value numColsValue = rewriter.create(loc, sparseLhsCols); Value maxNumNonZerosValue = rewriter.create(loc, sparseLhsCols * sparseLhsRows); @@ -473,16 +448,11 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { if (!resCSRMatrix) { llvm::errs() << "Error: resCSRMatrix is null!\n"; } - std::cout<<"here 2"< { auto lhsEndFirst = OpBuilderNested.create( - // locNested, arith::CmpIPredicate::ult, lhsColIdxLowerIncl, lhsColIdxUpperExcl); locNested, arith::CmpIPredicate::ult, lhsColUpper, rhsColUpper); auto lhsAllZero = OpBuilderNested.create( - // locNested, arith::CmpIPredicate::eq, lhsColUpper, rhsColUpper); locNested, arith::CmpIPredicate::eq, lhsColIdxLowerIncl, lhsColIdxUpperExcl); auto 
rhsAllZero = OpBuilderNested.create( - // locNested, arith::CmpIPredicate::eq, lhsColUpper, rhsColUpper); locNested, arith::CmpIPredicate::eq, rhsColIdxLowerIncl, rhsColIdxUpperExcl); auto operation = OpBuilderNested.create( @@ -881,7 +848,6 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { ssize_t rhsCols = rhsMatrixType.getNumCols(); if (lhsRows < 0 || lhsCols < 0 || rhsRows < 0 || rhsCols < 0) { - std::cout<<"here 4"< { OpBuilderNested.create(locNested, ValueRange{}); } ); - // Value resDenseMatrix = convertMemRefToDenseMatrix(loc, rewriter, resMemRef, op.getType()); + Value resDenseMatrix = convertMemRefToDenseMatrix(loc, rewriter, resMemRef, rhs.getType()); rewriter.replaceOp(op, resDenseMatrix); return mlir::success(); diff --git a/src/runtime/local/kernels/EwUnaryMat.h b/src/runtime/local/kernels/EwUnaryMat.h index d438a2ba5..ac34789b6 100644 --- a/src/runtime/local/kernels/EwUnaryMat.h +++ b/src/runtime/local/kernels/EwUnaryMat.h @@ -110,10 +110,18 @@ template struct EwUnaryMat, CSRMatrix> { res = DataObjectFactory::create>(numRows, numCols, maxNumNonZeros, false); const VT *valuesArg = arg->getValues(); + const size_t *colIdxsArg = arg->getColIdxs(); + const size_t *rowOffsetsArg = arg->getRowOffsets(); + VT *valuesRes = res->getValues(); + size_t *colIdxsRes = res->getColIdxs(); + size_t *rowOffsetsRes = res->getRowOffsets(); + + for (size_t i = 0; i < numNonZeros; i++) + colIdxsRes[i] = colIdxsArg[i]; - res->getColIdxsSharedPtr() = arg->getColIdxsSharedPtr(); - res->getRowOffsetsSharedPtr() = arg->getRowOffsetsSharedPtr(); + for (size_t i = 0; i < numRows + 1; i++) + rowOffsetsRes[i] = rowOffsetsArg[i]; EwUnaryScaFuncPtr func = getEwUnaryScaFuncPtr(opCode); From d5361d4f22ec50a03ed20c7316c9fd944c5707b5 Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Mon, 24 Feb 2025 21:45:59 +0100 Subject: [PATCH 28/33] add comments and fix a bug --- src/compiler/lowering/EwOpsLowering.cpp | 112 +++++++++++++----------- 1 file 
changed, 63 insertions(+), 49 deletions(-) diff --git a/src/compiler/lowering/EwOpsLowering.cpp b/src/compiler/lowering/EwOpsLowering.cpp index 23fb9f728..52bde1432 100644 --- a/src/compiler/lowering/EwOpsLowering.cpp +++ b/src/compiler/lowering/EwOpsLowering.cpp @@ -374,8 +374,10 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { auto one = rewriter.create(loc, 1); auto numSparseLhsRowsValue = rewriter.create(loc, sparseLhsRows); + // return a dense matrix if the op is add auto resDenseMemRef = rewriter.create(loc, denseRhsMemRefType); rewriter.create(loc, denseRhsMemRef, resDenseMemRef); + // return a sparse matrix if the op is mul auto resSparseMemRef = rewriter.create(loc, sparseLhsValuesMemRefType, ValueRange{one}); rewriter.create( @@ -552,16 +554,19 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { locTwiceNested, rhsAllZero, [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) { + //if lhs and rhs are all-zero in this row, move to next row auto newResValuesPtr = resValuesPtr; OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{newResValuesPtr}); }, [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) { + //if lhs is all-zero in this row but rhs is not if (llvm::isa(op)){ auto forLoop = OpBuilderThreetimesNested.create( locThreetimesNested, rhsColIdxLowerIncl, rhsColIdxUpperExcl, one, ValueRange{resValuesPtr}, [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested, Value loopIdx, ValueRange loopIterArgs) { + //copy this row of rhs to the res memref if the op is add auto resValue = OpBuilderFourtimesNested.create( locFourtimesNested, rhsValuesMemRef, ValueRange{loopIdx}); auto resCol = OpBuilderFourtimesNested.create( @@ -581,6 +586,7 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { } else { + //else move to the next row auto newResValuesPtr = resValuesPtr; OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{newResValuesPtr}); 
} @@ -651,7 +657,8 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { locFourtimesNested, arith::CmpIPredicate::ult, lhsCol, rhsCol); auto case2 = OpBuilderFourtimesNested.create( locFourtimesNested, arith::CmpIPredicate::ult, rhsCol, lhsCol); - + // copy the element whose col num is smaller to the res if the op is add + // then load the next element of that side auto newArg = OpBuilderFourtimesNested.create( locFourtimesNested, case1, [&](OpBuilder &OpBuilderFivetimesNested, Location locFivetimesNested) @@ -702,6 +709,7 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { }, [&](OpBuilder &OpBuilderSixtimesNested, Location locSixtimesNested) { + //perform computation on elements if their num col is equal to each other auto lhsValue = OpBuilderSixtimesNested.create( locSixtimesNested, lhsValuesMemRef, ValueRange{args[0]}); auto rhsValue = OpBuilderSixtimesNested.create( @@ -734,54 +742,60 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { locFourtimesNested, ValueRange{newArg0, newArg1, newArg2}); } ); - auto rest = OpBuilderThreetimesNested.create( - locThreetimesNested, lhsEndFirst, - [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested) - { - auto rhsRest = OpBuilderFourtimesNested.create( - locFourtimesNested, whileLoop.getResult(1), rhsColIdxUpperExcl, one, ValueRange{whileLoop.getResult(2)}, - [&](OpBuilder &OpBuilderFivetimesNested, Location locFivetimesNested, Value loopIdx, ValueRange loopIterArgs) - { - auto resValue = OpBuilderFivetimesNested.create( - locFivetimesNested, rhsValuesMemRef, ValueRange{loopIdx}); - auto resCol = OpBuilderFivetimesNested.create( - locFivetimesNested, rhsColIdxsMemRef, ValueRange{loopIdx}); - auto resIndex = loopIterArgs[0]; - OpBuilderFivetimesNested.create( - locFivetimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); - OpBuilderFivetimesNested.create( - locFivetimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); - auto newResValuesPtr = 
OpBuilderFivetimesNested.create( - locFivetimesNested, resIndex, one); - OpBuilderFivetimesNested.create(locFivetimesNested, ValueRange{newResValuesPtr}); - } - ); - OpBuilderFourtimesNested.create(locFourtimesNested, ValueRange{rhsRest.getResult(0)}); - }, - [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested) - { - auto lhsRest = OpBuilderFourtimesNested.create( - locFourtimesNested, whileLoop.getResult(0), lhsColIdxUpperExcl, one, ValueRange{whileLoop.getResult(2)}, - [&](OpBuilder &OpBuilderFivetimesNested, Location locFivetimesNested, Value loopIdx, ValueRange loopIterArgs) - { - auto resValue = OpBuilderFivetimesNested.create( - locFivetimesNested, lhsValuesMemRef, ValueRange{loopIdx}); - auto resCol = OpBuilderFivetimesNested.create( - locFivetimesNested, lhsColIdxsMemRef, ValueRange{loopIdx}); - auto resIndex = loopIterArgs[0]; - OpBuilderFivetimesNested.create( - locFivetimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); - OpBuilderFivetimesNested.create( - locFivetimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); - auto newResValuesPtr = OpBuilderFivetimesNested.create( - locFivetimesNested, resIndex, one); - OpBuilderFivetimesNested.create(locFivetimesNested, ValueRange{newResValuesPtr}); - } - ); - OpBuilderFourtimesNested.create(locFourtimesNested, ValueRange{lhsRest.getResult(0)}); - } - ); - OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{rest.getResult(0)}); + // if lhs ends first, the rest will be in rhs and copy them to the res if the op is add + if (llvm::isa(op)) + { + auto rest = OpBuilderThreetimesNested.create( + locThreetimesNested, lhsEndFirst, + [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested) + { + auto rhsRest = OpBuilderFourtimesNested.create( + locFourtimesNested, whileLoop.getResult(1), rhsColIdxUpperExcl, one, ValueRange{whileLoop.getResult(2)}, + [&](OpBuilder &OpBuilderFivetimesNested, Location locFivetimesNested, Value loopIdx, ValueRange loopIterArgs) + 
{ + auto resValue = OpBuilderFivetimesNested.create( + locFivetimesNested, rhsValuesMemRef, ValueRange{loopIdx}); + auto resCol = OpBuilderFivetimesNested.create( + locFivetimesNested, rhsColIdxsMemRef, ValueRange{loopIdx}); + auto resIndex = loopIterArgs[0]; + OpBuilderFivetimesNested.create( + locFivetimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderFivetimesNested.create( + locFivetimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); + auto newResValuesPtr = OpBuilderFivetimesNested.create( + locFivetimesNested, resIndex, one); + OpBuilderFivetimesNested.create(locFivetimesNested, ValueRange{newResValuesPtr}); + } + ); + OpBuilderFourtimesNested.create(locFourtimesNested, ValueRange{rhsRest.getResult(0)}); + }, + [&](OpBuilder &OpBuilderFourtimesNested, Location locFourtimesNested) + { + auto lhsRest = OpBuilderFourtimesNested.create( + locFourtimesNested, whileLoop.getResult(0), lhsColIdxUpperExcl, one, ValueRange{whileLoop.getResult(2)}, + [&](OpBuilder &OpBuilderFivetimesNested, Location locFivetimesNested, Value loopIdx, ValueRange loopIterArgs) + { + auto resValue = OpBuilderFivetimesNested.create( + locFivetimesNested, lhsValuesMemRef, ValueRange{loopIdx}); + auto resCol = OpBuilderFivetimesNested.create( + locFivetimesNested, lhsColIdxsMemRef, ValueRange{loopIdx}); + auto resIndex = loopIterArgs[0]; + OpBuilderFivetimesNested.create( + locFivetimesNested, resValue, resValuesMemRef, ValueRange{resIndex}); + OpBuilderFivetimesNested.create( + locFivetimesNested, resCol, resColIdxsMemRef, ValueRange{resIndex}); + auto newResValuesPtr = OpBuilderFivetimesNested.create( + locFivetimesNested, resIndex, one); + OpBuilderFivetimesNested.create(locFivetimesNested, ValueRange{newResValuesPtr}); + } + ); + OpBuilderFourtimesNested.create(locFourtimesNested, ValueRange{lhsRest.getResult(0)}); + } + ); + OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{rest.getResult(0)}); + } + else + 
OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{whileLoop.getResult(2)}); } ); OpBuilderTwiceNested.create(locTwiceNested, ValueRange{elseRegion.getResult(0)}); From 1cbc9f854925ba3271b6b16869cf3e1f78c4ab1e Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Mon, 24 Feb 2025 21:59:43 +0100 Subject: [PATCH 29/33] edit the comments --- src/compiler/lowering/EwOpsLowering.cpp | 2 +- src/compiler/lowering/MatMulOpLowering.cpp | 7 +++++-- src/compiler/lowering/SliceOpLowering.cpp | 21 +++------------------ 3 files changed, 9 insertions(+), 21 deletions(-) diff --git a/src/compiler/lowering/EwOpsLowering.cpp b/src/compiler/lowering/EwOpsLowering.cpp index 52bde1432..64ecda24c 100644 --- a/src/compiler/lowering/EwOpsLowering.cpp +++ b/src/compiler/lowering/EwOpsLowering.cpp @@ -794,7 +794,7 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { ); OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{rest.getResult(0)}); } - else + else //TODO: Support ops other than add OpBuilderThreetimesNested.create(locThreetimesNested, ValueRange{whileLoop.getResult(2)}); } ); diff --git a/src/compiler/lowering/MatMulOpLowering.cpp b/src/compiler/lowering/MatMulOpLowering.cpp index d1e98f955..1b8eba64f 100644 --- a/src/compiler/lowering/MatMulOpLowering.cpp +++ b/src/compiler/lowering/MatMulOpLowering.cpp @@ -300,7 +300,7 @@ class MatMulLowering : public OpConversionPattern { loops.push_back(fmaLoop); return loops; } - + template Value binaryWithConversionFunc(OpBuilder &rewriter, Location loc, TypeConverter *typeConverter, Value lhs, Value rhs) const { Type resType = lhs.getType(); @@ -330,7 +330,7 @@ class MatMulLowering : public OpConversionPattern { } return res; } - + Value csrIndex(OpBuilder &rewriter, Location loc, Value valuesMemRef, Value colIdxsMemRef, Value rowOffsetsMemRef, Value row, Value col, Type type) const { @@ -349,6 +349,7 @@ class MatMulLowering : public OpConversionPattern { auto getCol = 
OpBuilderNested.create(locNested, colIdxsMemRef, ValueRange{loopIdx}); auto getValue = OpBuilderNested.create(locNested, valuesMemRef, ValueRange{loopIdx}); auto cond = OpBuilderNested.create(locNested, arith::CmpIPredicate::eq, getCol, col); + // return the value of non-zero element if exists, else return a zero value auto res = OpBuilderNested.create( locNested, cond, [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested) @@ -461,6 +462,7 @@ class MatMulLowering : public OpConversionPattern { auto rhsElemRow = lhsElemCol; auto rhsElemCol = rhsCol; + // locate the required element in rhs corresponding to the lhs element auto rhsElemValue = csrIndex( OpBuilderThreetimesNested, locThreetimesNested, rhsValuesMemRef, rhsColIdxsMemRef, rhsRowOffsetsMemRef, @@ -477,6 +479,7 @@ class MatMulLowering : public OpConversionPattern { ); auto cond = OpBuilderTwiceNested.create( locTwiceNested, arith::CmpFPredicate::OEQ, lhsColLoop.getResult(0), zeroElement); + // store the result if it is not zero auto newPtr = OpBuilderTwiceNested.create( locTwiceNested, cond, [&](OpBuilder &OpBuilderThreetimesNested, Location locThreetimesNested) diff --git a/src/compiler/lowering/SliceOpLowering.cpp b/src/compiler/lowering/SliceOpLowering.cpp index e92fd2b3c..57197a211 100644 --- a/src/compiler/lowering/SliceOpLowering.cpp +++ b/src/compiler/lowering/SliceOpLowering.cpp @@ -60,7 +60,6 @@ static constexpr size_t COL = 1; template class SliceOpLowering : public OpConversionPattern { public: - //using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename OpConversionPattern::OpAdaptor; explicit SliceOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) @@ -69,9 +68,9 @@ class SliceOpLowering : public OpConversionPattern { } /** - * @brief Replaces a Transpose operation with a Linalg TransposeOp if possible. + * @brief Replaces a Slice operation with a MemRef SubviewOp if possible. * - * @return mlir::success if Transpose has been replaced, else mlir::failure. 
+ * @return mlir::success if Slice has been replaced, else mlir::failure. */ LogicalResult matchAndRewrite(SliceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -104,17 +103,6 @@ class SliceOpLowering : public OpConversionPattern { DenseI64ArrayAttr sizes = sliceAlongDim == ROW ? rewriter.getDenseI64ArrayAttr({(upperExcl-lowerIncl), numCols}) : rewriter.getDenseI64ArrayAttr({numRows, (upperExcl-lowerIncl)}); - // if (sliceAlongDim == ROW) - // { - // DenseI64ArrayAttr offset = rewriter.getDenseI64ArrayAttr({lowerIncl, 0}); - // DenseI64ArrayAttr sizes = rewriter.getDenseI64ArrayAttr({(upperExcl-lowerIncl), numCols}); - // } - // else - // { - // DenseI64ArrayAttr offset = rewriter.getDenseI64ArrayAttr({0, lowerIncl}); - // DenseI64ArrayAttr sizes = rewriter.getDenseI64ArrayAttr({numRows, (upperExcl-lowerIncl)}); - // } - DenseI64ArrayAttr strides = rewriter.getDenseI64ArrayAttr({1, 1}); Value resMemref = rewriter.create(loc, argMemref, offset, sizes, strides); @@ -132,10 +120,7 @@ using SliceColOpLowering = SliceOpLowering; namespace { /** - * @brief Lowers the daphne::Transpose operator to a Linalg TransposeOp. - * - * This rewrite may enable loop fusion on the affine loops TransposeOp is - * lowered to by running the loop fusion pass. + * @brief Lowers the daphne::Slice operator to a Memref SubviewOp. 
*/ struct SliceLoweringPass : public mlir::PassWrapper> { explicit SliceLoweringPass() {} From ee3959acc153376195ce380f4f51ad66cf88fdc7 Mon Sep 17 00:00:00 2001 From: WangYuyao <30401063+WangYuyao@users.noreply.github.com> Date: Mon, 24 Feb 2025 22:05:25 +0100 Subject: [PATCH 30/33] Delete src/compiler/lowering/ExtractOpLowering.cpp --- src/compiler/lowering/ExtractOpLowering.cpp | 205 -------------------- 1 file changed, 205 deletions(-) delete mode 100644 src/compiler/lowering/ExtractOpLowering.cpp diff --git a/src/compiler/lowering/ExtractOpLowering.cpp b/src/compiler/lowering/ExtractOpLowering.cpp deleted file mode 100644 index 3d6dd5977..000000000 --- a/src/compiler/lowering/ExtractOpLowering.cpp +++ /dev/null @@ -1,205 +0,0 @@ -/* - * Copyright 2025 The DAPHNE Consortium - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "compiler/utils/LoweringUtils.h" -#include "ir/daphneir/Daphne.h" -#include "ir/daphneir/Passes.h" - -#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" -#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" -#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" -#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/Transforms/FuncConversions.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinDialect.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/UseDefLists.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" - - -using namespace mlir; -using namespace std; - -static constexpr size_t ROW = 0; -static constexpr size_t COL = 1; - -template -class ExtractOpLowering : public OpConversionPattern { - public: - //using OpConversionPattern::OpConversionPattern; - using OpAdaptor = typename OpConversionPattern::OpAdaptor; - - explicit ExtractOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) - : OpConversionPattern(typeConverter, ctx, PatternBenefit(1)) { - this->setDebugName("ExtractOpLowering"); - } - - /** - * @brief Replaces a Transpose operation with a Linalg 
TransposeOp if possible. - * - * @return mlir::success if Transpose has been replaced, else mlir::failure. - */ - LogicalResult matchAndRewrite(ExtractOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - daphne::MatrixType matrixType = adaptor.getSource().getType().template dyn_cast(); - if (!matrixType) { - return failure(); - } - - Location loc = op->getLoc(); - - Type matrixElementType = matrixType.getElementType(); - ssize_t numRows = matrixType.getNumRows(); - ssize_t numCols = matrixType.getNumCols(); - - if (numRows < 0 || numCols < 0) { - return rewriter.notifyMatchFailure( - op, "extractOp codegen currently only works with matrix dimensions that are known at compile time"); - } - - Value argMemref = rewriter.create( - loc, MemRefType::get({numRows, numCols}, matrixElementType), adaptor.getSource()); - - - daphne::MatrixType selectionType = adaptor.getSelectedRows().getType().template dyn_cast(); - if (!matrixType) { - return failure(); - } - - - - Type selectionElementType = selectionType.getElementType(); - ssize_t numSelectedRows = selectionType.getNumRows(); - - Value selectionMemref = rewriter.create( - loc, MemRefType::get({numSelectedRows, 1}, matrixElementType), adaptor.getSelectedRows()); - - Value resMemref = rewriter.create(loc, MemRefType::get({numSelectedRows, numCols}, matrixElementType)); - - for (ssize_t i = 0; i < numSelectedRows; i++) - { - - Value des = rewriter.create(loc, resMemref, - rewriter.getDenseI64ArrayAttr({i, 0}), - rewriter.getDenseI64ArrayAttr({1, numCols}), - rewriter.getDenseI64ArrayAttr({1, 1})); - - Value select = rewriter.create(loc, selectionMemref, - ValueRange{rewriter.create(loc, i), - rewriter.create(loc, 0)}); - - Value zero = rewriter.create(loc, rewriter.getI32IntegerAttr(0)); - - ValueRange offsets = {select, zero}; - ValueRange sizes = {rewriter.create(loc, 1), - rewriter.create(loc, numCols)}; - ValueRange strides = {rewriter.create(loc, 1), - rewriter.create(loc, 1)}; - - 
Value src = rewriter.create(loc, argMemref, offsets, sizes, strides); - - rewriter.create(loc, src, des); - - } - - - Value resDenseMatrix = convertMemRefToDenseMatrix(loc, rewriter, resMemref, op.getType()); - - rewriter.replaceOp(op, resDenseMatrix); - - return success(); - } -}; - -using ExtractRowOpLowering = ExtractOpLowering; -//using ExtractColOpLowering = ExtractOpLowering; - -namespace { -/** - * @brief Lowers the daphne::Transpose operator to a Linalg TransposeOp. - * - * This rewrite may enable loop fusion on the affine loops TransposeOp is - * lowered to by running the loop fusion pass. - */ -struct ExtractLoweringPass : public mlir::PassWrapper> { - explicit ExtractLoweringPass() {} - - StringRef getArgument() const final { return "lower-extract"; } - StringRef getDescription() const final { return "Lowers ExtractRow/ExtractCol operators to a Memref SubViewOp."; } - - void getDependentDialects(mlir::DialectRegistry ®istry) const override { - registry.insert(); - } - void runOnOperation() final; -}; -} // end anonymous namespace - -void ExtractLoweringPass::runOnOperation() { - mlir::ConversionTarget target(getContext()); - mlir::RewritePatternSet patterns(&getContext()); - LowerToLLVMOptions llvmOptions(&getContext()); - LLVMTypeConverter typeConverter(&getContext(), llvmOptions); - - typeConverter.addConversion(convertInteger); - typeConverter.addConversion(convertFloat); - typeConverter.addConversion([](Type type) { return type; }); - typeConverter.addArgumentMaterialization(materializeCastFromIllegal); - typeConverter.addSourceMaterialization(materializeCastToIllegal); - typeConverter.addTargetMaterialization(materializeCastFromIllegal); - - target.addLegalDialect(); - - target.addDynamicallyLegalOp([](Operation *op) { - Type operand = op->getOperand(0).getType(); - daphne::MatrixType matType = operand.dyn_cast(); - if (matType && matType.getRepresentation() == daphne::MatrixRepresentation::Dense) { - return false; - } - return true; - }); - - 
patterns.insert(typeConverter, &getContext()); - auto module = getOperation(); - if (failed(applyPartialConversion(module, target, std::move(patterns)))) { - signalPassFailure(); - } -} - -std::unique_ptr daphne::createExtractOpLoweringPass() { - return std::make_unique(); -} \ No newline at end of file From f743c81a8480360ca54ad88376037b36a84a1ffb Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Mon, 24 Feb 2025 22:25:51 +0100 Subject: [PATCH 31/33] remove ExtractOp lowering related --- src/compiler/lowering/CMakeLists.txt | 1 - src/ir/daphneir/Passes.h | 1 - 2 files changed, 2 deletions(-) diff --git a/src/compiler/lowering/CMakeLists.txt b/src/compiler/lowering/CMakeLists.txt index 937b07d1a..9b169396f 100644 --- a/src/compiler/lowering/CMakeLists.txt +++ b/src/compiler/lowering/CMakeLists.txt @@ -39,7 +39,6 @@ add_mlir_dialect_library(MLIRDaphneTransforms SliceRowOpLowering.cpp SliceColOpLowering.cpp SliceOpLowering.cpp - ExtractOpLowering.cpp DEPENDS MLIRDaphneOpsIncGen diff --git a/src/ir/daphneir/Passes.h b/src/ir/daphneir/Passes.h index 5ff3755bd..9f20e7ed1 100644 --- a/src/ir/daphneir/Passes.h +++ b/src/ir/daphneir/Passes.h @@ -75,7 +75,6 @@ std::unique_ptr createVectorizeComputationsPass(); std::unique_ptr createSliceRowOpLoweringPass(); std::unique_ptr createSliceColOpLoweringPass(); std::unique_ptr createSliceOpLoweringPass(); -std::unique_ptr createExtractOpLoweringPass(); #ifdef USE_CUDA std::unique_ptr createMarkCUDAOpsPass(const DaphneUserConfig &cfg); From df2796d9a836eaecdbf01e4e7d89059558e11f1c Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Thu, 27 Feb 2025 20:04:46 +0100 Subject: [PATCH 32/33] optimize CSR Matrix Index in MatMulOpLowering --- src/compiler/lowering/MatMulOpLowering.cpp | 46 ++++++++++++++-------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/src/compiler/lowering/MatMulOpLowering.cpp b/src/compiler/lowering/MatMulOpLowering.cpp index 1b8eba64f..c4576290c 100644 --- 
a/src/compiler/lowering/MatMulOpLowering.cpp +++ b/src/compiler/lowering/MatMulOpLowering.cpp @@ -336,35 +336,54 @@ class MatMulLowering : public OpConversionPattern { { auto zeroElem = rewriter.create(loc, rewriter.getZeroAttr(type)); auto one = rewriter.create(loc, 1); + auto zero = rewriter.create(loc, 0); auto rowPtr = row; auto nextRowPtr = rewriter.create(loc, row, one); auto colIdxLowerIncl = rewriter.create( loc, rowOffsetsMemRef, ValueRange{rowPtr}); auto colIdxUpperExcl = rewriter.create( loc, rowOffsetsMemRef, ValueRange{nextRowPtr}); - auto search = rewriter.create( - loc, colIdxLowerIncl, colIdxUpperExcl, one, ValueRange{zeroElem}, - [&](OpBuilder &OpBuilderNested, Location locNested, Value loopIdx, ValueRange loopIterArgs) + + auto search = rewriter.create( + loc,TypeRange{ + rewriter.getIndexType(), + rewriter.getIndexType(), + type}, + ValueRange{colIdxLowerIncl, one, zeroElem}, + [&](OpBuilder &OpBuilderNested, Location locNested, ValueRange args) { - auto getCol = OpBuilderNested.create(locNested, colIdxsMemRef, ValueRange{loopIdx}); - auto getValue = OpBuilderNested.create(locNested, valuesMemRef, ValueRange{loopIdx}); + auto cond1 = OpBuilderNested.create( + locNested, arith::CmpIPredicate::ult, args[0], colIdxUpperExcl); + auto cond2 = OpBuilderNested.create( + locNested, arith::CmpIPredicate::eq, args[1], one); + auto cond = OpBuilderNested.create(locNested, cond1, cond2); + OpBuilderNested.create(locNested, cond, args); + }, + [&](OpBuilder &OpBuilderNested, Location locNested, ValueRange args) + { + auto getCol = OpBuilderNested.create(locNested, colIdxsMemRef, ValueRange{args[0]}); + auto cond = OpBuilderNested.create(locNested, arith::CmpIPredicate::eq, getCol, col); // return the value of non-zero element if exists, else return a zero value auto res = OpBuilderNested.create( locNested, cond, [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested) { - OpBuilderTwiceNested.create(locTwiceNested, ValueRange{getValue}); + auto getValue 
= OpBuilderTwiceNested.create( + locTwiceNested, valuesMemRef, ValueRange{args[0]}); + OpBuilderTwiceNested.create(locTwiceNested, ValueRange{zero, getValue}); }, [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested) { - OpBuilderTwiceNested.create(locTwiceNested, ValueRange{zeroElem}); + OpBuilderTwiceNested.create(locTwiceNested, ValueRange{one, zeroElem}); } ); - OpBuilderNested.create(locNested, res.getResult(0)); + auto nextPtr = OpBuilderNested.create(locNested, args[0], one); + OpBuilderNested.create(locNested, ValueRange{nextPtr, res.getResult(0), res.getResult(1)}); } ); - return search.getResult(0); + + return search.getResult(2); } LogicalResult matchAndRewriteSparseSparseMat(daphne::MatMulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -389,12 +408,6 @@ class MatMulLowering : public OpConversionPattern { MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); auto lhsRowOffsetsMemRefType = MemRefType::get({lhsRows + 1}, rewriter.getIndexType()); - auto rhsValuesMemRefType = - MemRefType::get({ShapedType::kDynamic}, matrixElementType); - auto rhsColIdxsMemRefType = - MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); - auto rhsRowOffsetsMemRefType = - MemRefType::get({rhsRows + 1}, rewriter.getIndexType()); auto resValuesMemRefType = MemRefType::get({ShapedType::kDynamic}, matrixElementType); auto resColIdxsMemRefType = @@ -419,8 +432,6 @@ class MatMulLowering : public OpConversionPattern { auto one = rewriter.create(loc, 1); auto zeroElement = rewriter.create(loc, rewriter.getZeroAttr(matrixElementType)); auto numLhsRowsValue = rewriter.create(loc, lhsRows); - auto numRhsRowsValue = rewriter.create(loc, rhsRows); - auto numLhsColsValue = rewriter.create(loc, lhsCols); auto numRhsColsValue = rewriter.create(loc, rhsCols); auto resValuesMemRef = rewriter.create(loc, resValuesMemRefType, ValueRange{one}); @@ -462,6 +473,7 @@ class MatMulLowering : public OpConversionPattern { auto rhsElemRow = 
lhsElemCol; auto rhsElemCol = rhsCol; + //auto rhsElemCol = lhsElemRow; // locate the required element in rhs corresponding to the lhs element auto rhsElemValue = csrIndex( OpBuilderThreetimesNested, locThreetimesNested, From 23921b8f74dc001e83b8d26e041e860b3dd1b223 Mon Sep 17 00:00:00 2001 From: WangYuyao <924953347@qq.com> Date: Thu, 27 Feb 2025 21:07:09 +0100 Subject: [PATCH 33/33] add tests --- test/api/cli/codegen/SparsityLAOpsTest.cpp | 97 +++++++++++++++++++ .../cli/codegen/ewbinary_add_sparse.daphne | 6 ++ .../cli/codegen/ewbinary_mul_sparse.daphne | 7 ++ test/api/cli/codegen/ewunary_abs_sparse.daph | 2 + test/api/cli/codegen/matmul_sparse.daphne | 6 ++ 5 files changed, 118 insertions(+) create mode 100644 test/api/cli/codegen/SparsityLAOpsTest.cpp create mode 100644 test/api/cli/codegen/ewbinary_add_sparse.daphne create mode 100644 test/api/cli/codegen/ewbinary_mul_sparse.daphne create mode 100644 test/api/cli/codegen/ewunary_abs_sparse.daph create mode 100644 test/api/cli/codegen/matmul_sparse.daphne diff --git a/test/api/cli/codegen/SparsityLAOpsTest.cpp b/test/api/cli/codegen/SparsityLAOpsTest.cpp new file mode 100644 index 000000000..cc4d56796 --- /dev/null +++ b/test/api/cli/codegen/SparsityLAOpsTest.cpp @@ -0,0 +1,97 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +#include +#include + +const std::string dirPath = "test/api/cli/codegen/"; + +TEST_CASE("ewUnary_abs, sparse", TAG_CODEGEN) { + std::string result = "CSRMatrix(5x5, double)\n" + "0 0 0 0 1\n" + "0 0 0 0 0\n" + "0 0 0 0 0\n" + "0 1 1 0 0\n" + "0 0 0 0 0\n"; + + compareDaphneToStr(result, dirPath + "ewunary_abs_sparse.daph", "--mlir-codegen", "--no-obj-ref-mgnt", "--select-matrix-repr"); +} + +TEST_CASE("ewBinary_add, sparse", TAG_CODEGEN) { + std::string result = "DenseMatrix(8x8, double)\n" + "1 1 1 1 1 1 1 1\n" + "1 1 1 1 1 1 1 1\n" + "1 1 1 1 1 1 1 1\n" + "1 1 1 1 1 1 1 1\n" + "1 1 1 1 1 1 1 1\n" + "1 1 1 1 1 1 1 1\n" + "1 1 1 2 1 1 1 1\n" + "1 1 1 1 1 1 1 1\n" + "CSRMatrix(8x8, double)\n" + "0 0 0 0 0 0 0 0\n" + "0 0 0 0 0 0 0 0\n" + "0 0 0 0 0 0 0 0\n" + "0 0 0 0 0 0 0 0\n" + "0 0 0 0 0 0 0 1\n" + "0 0 0 0 0 0 0 0\n" + "0 0 0 1 0 0 0 0\n" + "0 0 0 0 0 0 0 0\n"; + + compareDaphneToStr(result, dirPath + "ewbinary_add_sparse.daphne", "--mlir-codegen", "--no-obj-ref-mgnt", "--select-matrix-repr"); +} + +TEST_CASE("ewBinary_mul, sparse", TAG_CODEGEN) { + std::string result = "CSRMatrix(5x5, double)\n" + "0 0 0 0 2\n" + "0 0 0 0 0\n" + "0 0 0 0 0\n" + "0 2 2 0 0\n" + "0 0 0 0 0\n" + "CSRMatrix(5x5, double)\n" + "0 0 0 0 1\n" + "0 0 0 0 0\n" + "0 0 0 0 0\n" + "0 1 1 0 0\n" + "0 0 0 0 0\n" + "CSRMatrix(5x5, double)\n" + "0 0 0 0 1\n" + "0 0 0 0 0\n" + "0 0 0 0 0\n" + "0 0 0 0 0\n" + "0 0 0 0 0\n"; + + compareDaphneToStr(result, dirPath + "ewbinary_mul_sparse.daphne", "--mlir-codegen", "--no-obj-ref-mgnt", "--select-matrix-repr"); +} + +TEST_CASE("matmul, sparse-dense", TAG_CODEGEN) { + std::string result = "DenseMatrix(5x5, double)\n" + "1 1 1 1 1\n" + "0 0 0 0 0\n" + "0 0 0 0 0\n" + "2 2 2 2 2\n" + "0 0 0 0 0\n" + "CSRMatrix(5x5, double)\n" + "0 0 0 0 0\n" + "0 0 0 0 0\n" + "0 0 0 0 0\n" + "1 0 0 1 0\n" + "0 0 0 0 0\n"; + + compareDaphneToStr(result, dirPath + "matmul_sparse.daphne", "--mlir-codegen", "--no-obj-ref-mgnt", 
"--select-matrix-repr"); +} diff --git a/test/api/cli/codegen/ewbinary_add_sparse.daphne b/test/api/cli/codegen/ewbinary_add_sparse.daphne new file mode 100644 index 000000000..ec1c7885f --- /dev/null +++ b/test/api/cli/codegen/ewbinary_add_sparse.daphne @@ -0,0 +1,6 @@ +W = rand(8, 8, 1.0, 1.0, 0.01, 1); +V = rand(8, 8, 1.0, 1.0, 0.01, 2); +X = rand(8, 8, 1.0, 1.0, 1, 3); + +print(W + X); // sparse + dense +print(W + V); // sparse + sparse \ No newline at end of file diff --git a/test/api/cli/codegen/ewbinary_mul_sparse.daphne b/test/api/cli/codegen/ewbinary_mul_sparse.daphne new file mode 100644 index 000000000..f4d369953 --- /dev/null +++ b/test/api/cli/codegen/ewbinary_mul_sparse.daphne @@ -0,0 +1,7 @@ +W = rand(5, 5, 1.0, 1.0, 0.1, 1); +V = rand(5, 5, 1.0, 1.0, 0.1, 2); +X = rand(5, 5, 1.0, 1.0, 1, 3); + +print(W * 2); // sparse * scalar +print(W * X); // sparse * dense +print(W * V); // sparse * sparse \ No newline at end of file diff --git a/test/api/cli/codegen/ewunary_abs_sparse.daph b/test/api/cli/codegen/ewunary_abs_sparse.daph new file mode 100644 index 000000000..0b11e40d7 --- /dev/null +++ b/test/api/cli/codegen/ewunary_abs_sparse.daph @@ -0,0 +1,2 @@ +W = rand(5, 5, -1.0, -1.0, 0.1, 1); +print(abs(W)); \ No newline at end of file diff --git a/test/api/cli/codegen/matmul_sparse.daphne b/test/api/cli/codegen/matmul_sparse.daphne new file mode 100644 index 000000000..3f3b19374 --- /dev/null +++ b/test/api/cli/codegen/matmul_sparse.daphne @@ -0,0 +1,6 @@ +W = rand(5, 5, 1.0, 1.0, 0.1, 1); +V = rand(5, 5, 1.0, 1.0, 0.1, 2); +X = rand(5, 5, 1.0, 1.0, 1, 3); + +print(W @ X); // sparse @ dense +print(W @ V); // sparse @ sparse