From 07af99212212ce6ee856e5d4235cb53d5e05573b Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Wed, 2 Apr 2025 18:12:54 +0200 Subject: [PATCH] fixing in-place and 0d all_reduce --- mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 34 ++++++++++++++++--- mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir | 8 +++-- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp index 9df5e992e8ebd..5575b295ae20a 100644 --- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp +++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp @@ -15,8 +15,10 @@ #include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/MPI/IR/MPI.h" #include "mlir/Transforms/DialectConversion.h" #include @@ -57,9 +59,14 @@ std::pair getRawPtrAndSize(const Location loc, loc, rewriter.getI64Type(), memRef, 2); Value resPtr = rewriter.create(loc, ptrType, elType, dataPtr, offset); - Value size = rewriter.create(loc, memRef, - ArrayRef{3, 0}); - size = rewriter.create(loc, rewriter.getI32Type(), size); + Value size; + if (cast(memRef.getType()).getBody().size() > 3) { + size = rewriter.create(loc, memRef, + ArrayRef{3, 0}); + size = rewriter.create(loc, rewriter.getI32Type(), size); + } else { + size = rewriter.create(loc, 1, 32); + } return {resPtr, size}; } @@ -97,6 +104,9 @@ class MPIImplTraits { /// Get the MPI_STATUS_IGNORE value (typically a pointer type). virtual intptr_t getStatusIgnore() = 0; + /// Get the MPI_IN_PLACE value (void *). + virtual void *getInPlace() = 0; + /// Gets or creates an MPI datatype as a value which corresponds to the given /// type. virtual Value getDataType(const Location loc, @@ -158,6 +168,8 @@ class MPICHImplTraits : public MPIImplTraits { intptr_t getStatusIgnore() override { return 1; } + void *getInPlace() override { return reinterpret_cast(-1); } + Value getDataType(const Location loc, ConversionPatternRewriter &rewriter, Type type) override { int32_t mtype = 0; @@ -283,6 +295,8 @@ class OMPIImplTraits : public MPIImplTraits { intptr_t getStatusIgnore() override { return 0; } + void *getInPlace() override { return reinterpret_cast(1); } + Value getDataType(const Location loc, ConversionPatternRewriter &rewriter, Type type) override { StringRef mtype; @@ -516,7 +530,8 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern { outPtr.getRes()}); // load the communicator into a register - auto res = rewriter.create(loc, i32, outPtr.getResult()); + Value res = rewriter.create(loc, i32, outPtr.getResult()); + res = rewriter.create(loc, rewriter.getI64Type(), res); // if retval is checked, replace uses of retval with the results from the // call op @@ -525,7 +540,7 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern { replacements.push_back(callOp.getResult()); // replace op - replacements.push_back(res.getRes()); + replacements.push_back(res); rewriter.replaceOp(op, replacements); return success(); @@ -709,6 +724,7 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern { Location loc = op.getLoc(); MLIRContext *context = rewriter.getContext(); Type i32 = rewriter.getI32Type(); + Type i64 = rewriter.getI64Type(); Type elemType = op.getSendbuf().getType().getElementType(); // ptrType `!llvm.ptr` @@ -719,6 +735,14 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern { getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType); auto [recvPtr, recvSize] = getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType); + + // If input and output are the same, request in-place operation. + if (adaptor.getSendbuf() == adaptor.getRecvbuf()) { + sendPtr = rewriter.create( + loc, i64, reinterpret_cast(mpiTraits->getInPlace())); + sendPtr = rewriter.create(loc, ptrType, sendPtr); + } + Value dataType = mpiTraits->getDataType(loc, rewriter, elemType); Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp()); Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir index 174f7c79b9d50..35fc0f5d2e754 100644 --- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir +++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir @@ -98,10 +98,12 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} { // CHECK: [[v66:%.*]] = llvm.getelementptr [[v64]][[[v65]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK: [[v67:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: [[v68:%.*]] = llvm.trunc [[v67]] : i64 to i32 + // CHECK: [[ip:%.*]] = llvm.mlir.constant(-1 : i64) : i64 + // CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr // CHECK: [[v69:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32 // CHECK: [[v70:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32 // CHECK: [[v71:%.*]] = llvm.trunc [[comm]] : i64 to i32 - // CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[v61]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32 + // CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32 mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> // CHECK: llvm.call @MPI_Finalize() : () -> i32 @@ -202,10 +204,12 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } { // CHECK: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32 + // CHECK: [[ip:%.*]] = llvm.mlir.constant(1 : i64) : i64 + // CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr // CHECK: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr // CHECK: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr // CHECK: [[v61:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr - // CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32 + // CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32 mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> // CHECK: [[v71:%.*]] = llvm.mlir.constant(10 : i32) : i32