From 1aa51f74277eace6dcaf6372ba645b4627548bb4 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank"
Date: Wed, 14 Aug 2024 19:29:23 +0200
Subject: [PATCH 01/15] initial hack lowering mesh.update_halo to MPI

---
 .../mlir/Conversion/MeshToMPI/MeshToMPI.h     |  27 +++
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |  17 ++
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  33 ++++
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 mlir/lib/Conversion/MeshToMPI/CMakeLists.txt  |  22 +++
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 171 ++++++++++++++++++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          |  19 ++
 .../MeshToMPI/convert-mesh-to-mpi.mlir        |  34 ++++
 9 files changed, 325 insertions(+)
 create mode 100644 mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
 create mode 100644 mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
 create mode 100644 mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir

diff --git a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
new file mode 100644
index 0000000000000..6a2c196da4557
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
@@ -0,0 +1,27 @@
+//===- MeshToMPI.h - Convert Mesh to MPI dialect --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
+#define MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMESHTOMPIPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Lowers Mesh communication operations (updateHalo, AllGather, ...)
+/// to MPI primitives.
+std::unique_ptr<Pass> createConvertMeshToMPIPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
\ No newline at end of file
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 2ab32836c80b1..b577aa83946f2 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -51,6 +51,7 @@
 #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
+#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
 #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
 #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 4d272ba219c6f..83e0c5a06c43f 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -878,6 +878,23 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// MeshToMPI
+//===----------------------------------------------------------------------===//
+
+def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
+  let summary = "Convert Mesh dialect to MPI dialect.";
+  let description = [{
+    This pass converts communication operations
+    from the Mesh dialect to operations from the MPI dialect.
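+
+    A minimal example of an op this pass rewrites (taken from the lit test
+    added in this patch); the summary of the generated code below is an
+    informal sketch, not normative output:
+
+    ```mlir
+    %c2 = arith.constant 2 : i64
+    mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
+      halo_sizes = [2, %c2] : memref<12x12xi8>
+    ```
+
+    The halo exchange is lowered to `mpi.send`/`mpi.recv` pairs operating
+    on temporary contiguous buffers, each guarded by an `scf.if` that
+    checks whether the respective neighbor process exists in the mesh.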
+  }];
+  let dependentDialects = [
+    "memref::MemRefDialect",
+    "mpi::MPIDialect",
+    "scf::SCFDialect"
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 19498fe5a32d6..2c2b6e20f3654 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -156,6 +156,39 @@ def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
   ];
 }
 
+def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
+  Pure,
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+  let summary =
+      "For the given split axes, get the linear indices of the direct neighbor processes.";
+  let description = [{
+    Example:
+    ```
+    %down_idx, %up_idx = mesh.neighbors_linear_indices on @mesh[$device]
+        split_axes = $split_axes : index, index
+    ```
+    Given `@mesh` with shape `(10, 20, 30)`,
+    `device` = `(1, 2, 3)`, and
+    `$split_axes` = `[1]`,
+    it returns the linear indices of the processes at positions `(1, 1, 3)`: `633`
+    and `(1, 3, 3)`: `693`.
+
+    A negative value is returned if `$device` has no neighbor in the given
+    direction along the given `split_axes`.
+  }];
+  let arguments = (ins FlatSymbolRefAttr:$mesh,
+                       Variadic<Index>:$device,
+                       Mesh_MeshAxesAttr:$split_axes);
+  let results = (outs Index:$neighbor_down, Index:$neighbor_up);
+  let assemblyFormat = [{
+      `on` $mesh `[` $device `]`
+      `split_axes` `=` $split_axes
+      attr-dict `:` type(results)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Sharding operations.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 6651d87162257..62461c0cea08a 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -41,6 +41,7 @@ add_subdirectory(MathToSPIRV)
 add_subdirectory(MemRefToEmitC)
 add_subdirectory(MemRefToLLVM)
 add_subdirectory(MemRefToSPIRV)
+add_subdirectory(MeshToMPI)
 add_subdirectory(NVGPUToNVVM)
 add_subdirectory(NVVMToLLVM)
 add_subdirectory(OpenACCToSCF)
diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
new file mode 100644
index 0000000000000..95815a683f6d6
--- /dev/null
+++ b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
@@ -0,0 +1,22 @@
+add_mlir_conversion_library(MLIRMeshToMPI
+  MeshToMPI.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MeshToMPI
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRFuncDialect
+  MLIRIR
+  MLIRLinalgTransforms
+  MLIRMemRefDialect
+  MLIRPass
+  MLIRMeshDialect
+  MLIRMPIDialect
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
new file mode 100644
index 0000000000000..b4cf9da8497a2
--- /dev/null
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -0,0 +1,171 @@
+//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation of Mesh communicatin ops tp MPI ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/MeshToMPI/MeshToMPI.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MPI/IR/MPI.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Mesh/IR/MeshOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" + +#define DEBUG_TYPE "mesh-to-mpi" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") + +namespace mlir { +#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::mesh; + +namespace { +struct ConvertMeshToMPIPass + : public impl::ConvertMeshToMPIPassBase { + using Base::Base; + + /// Run the dialect converter on the module. + void runOnOperation() override { + getOperation()->walk([&](UpdateHaloOp op) { + SymbolTableCollection symbolTableCollection; + OpBuilder builder(op); + auto loc = op.getLoc(); + + auto toValue = [&builder, &loc](OpFoldResult &v) { + return v.is() + ? v.get() + : builder.create<::mlir::arith::ConstantOp>( + loc, + builder.getIndexAttr( + cast(v.get()).getInt())); + }; + + auto array = op.getInput(); + auto rank = array.getType().getRank(); + auto mesh = op.getMesh(); + auto meshOp = getMesh(op, symbolTableCollection); + auto haloSizes = getMixedValues(op.getStaticHaloSizes(), + op.getDynamicHaloSizes(), builder); + for (auto &sz : haloSizes) { + if (sz.is()) { + sz = builder + .create(loc, builder.getIndexType(), + sz.get()) + .getResult(); + } + } + + SmallVector offsets(rank, builder.getIndexAttr(0)); + SmallVector strides(rank, builder.getIndexAttr(1)); + SmallVector shape(rank); + for (auto [i, s] : llvm::enumerate(array.getType().getShape())) { + if (ShapedType::isDynamic(s)) { + shape[i] = builder.create(loc, array, s).getResult(); + } else { + shape[i] = builder.getIndexAttr(s); + } + } + + auto tagAttr = builder.getI32IntegerAttr(91); // whatever + auto tag = builder.create<::mlir::arith::ConstantOp>(loc, tagAttr); + auto zeroAttr = builder.getI32IntegerAttr(0); // whatever + auto zero = builder.create<::mlir::arith::ConstantOp>(loc, zeroAttr); + SmallVector indexResultTypes(meshOp.getShape().size(), + builder.getIndexType()); + auto myMultiIndex = + builder.create(loc, indexResultTypes, mesh) + .getResult(); + auto currHaloDim = 0; + + for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) { + if (!splitAxes.empty()) { + auto tmp = builder + .create( + loc, mesh, myMultiIndex, splitAxes) + .getResults(); + Value neighbourIDs[2] = {builder.create( + loc, builder.getI32Type(), tmp[0]), + builder.create( + loc, builder.getI32Type(), tmp[1])}; + auto orgDimSize = shape[dim]; + auto upperOffset = builder.create( + loc, toValue(shape[dim]), toValue(haloSizes[dim * 2 + 1])); + + // make sure we send/recv in a way that does not lead to a dead-lock + // This is by far not optimal, this should be at least MPI_sendrecv + // and - probably even more importantly - buffers should be re-used + // Currently using temporary, contiguous buffer for MPI communication + auto genSendRecv = [&](auto dim, bool upperHalo) { + auto orgOffset = offsets[dim]; + 
shape[dim] = + upperHalo ? haloSizes[dim * 2 + 1] : haloSizes[dim * 2]; + auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0]; + auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1]; + auto hasFrom = builder.create( + loc, arith::CmpIPredicate::sge, from, zero); + auto hasTo = builder.create( + loc, arith::CmpIPredicate::sge, to, zero); + auto buffer = builder.create( + loc, shape, array.getType().getElementType()); + builder.create( + loc, hasTo, [&](OpBuilder &builder, Location loc) { + offsets[dim] = upperHalo + ? OpFoldResult(builder.getIndexAttr(0)) + : OpFoldResult(upperOffset); + auto subview = builder.create( + loc, array, offsets, shape, strides); + builder.create(loc, subview, buffer); + builder.create(loc, TypeRange{}, buffer, tag, + to); + builder.create(loc); + }); + builder.create( + loc, hasFrom, [&](OpBuilder &builder, Location loc) { + offsets[dim] = upperHalo + ? OpFoldResult(upperOffset) + : OpFoldResult(builder.getIndexAttr(0)); + builder.create(loc, TypeRange{}, buffer, tag, + from); + auto subview = builder.create( + loc, array, offsets, shape, strides); + builder.create(loc, buffer, subview); + builder.create(loc); + }); + builder.create(loc, buffer); + offsets[dim] = orgOffset; + }; + + genSendRecv(dim, false); + genSendRecv(dim, true); + + shape[dim] = builder + .create( + loc, toValue(orgDimSize), + builder + .create( + loc, toValue(haloSizes[dim * 2]), + toValue(haloSizes[dim * 2 + 1])) + .getResult()) + .getResult(); + offsets[dim] = haloSizes[dim * 2]; + ++currHaloDim; + } + } + }); + } +}; +} // namespace \ No newline at end of file diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index c5570d8ee8a44..33460ff25e9e4 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -837,6 +837,25 @@ void ProcessLinearIndexOp::getAsmResultNames( setNameFn(getResult(), "proc_linear_idx"); } +//===----------------------------------------------------------------------===// +// mesh.neighbors_linear_indices op +//===----------------------------------------------------------------------===// + +LogicalResult +NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); + if (failed(mesh)) { + return failure(); + } + return success(); +} + +void NeighborsLinearIndicesOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getNeighborDown(), "down_linear_idx"); + setNameFn(getNeighborUp(), "up_linear_idx"); +} + //===----------------------------------------------------------------------===// // collective communication ops //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir new file mode 100644 index 0000000000000..9ef826ca0cdac --- /dev/null +++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir @@ -0,0 +1,34 @@ +// RUN: mlir-opt %s -split-input-file -convert-mesh-to-mpi | FileCheck %s + +// CHECK: mesh.mesh @mesh0 +mesh.mesh @mesh0(shape = 2x2x4) + +// ----- + +// CHECK-LABEL: func @update_halo +func.func @update_halo_1d( + // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8> + %arg0 : memref<12x12xi8>) { + // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64 + // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0 + // CHECK-SAME: split_axes = {{\[\[}}0]] + // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8> + %c2 = arith.constant 2 
: i64
+  mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
+    halo_sizes = [2, %c2] : memref<12x12xi8>
+  return
+}
+
+func.func @update_halo_2d(
+    // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
+    %arg0 : memref<12x12xi8>) {
+  %c2 = arith.constant 2 : i64
+  // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
+  // CHECK-SAME: split_axes = {{\[\[}}0], [1]]
+  // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2]
+  // CHECK-SAME: target_halo_sizes = [3, 3, 2, 2] : memref<12x12xi8>
+  mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
+    halo_sizes = [2, 2, %c2, 2] target_halo_sizes = [3, 3, 2, 2]
+    : memref<12x12xi8>
+  return
+}

From 8b8c6e4a12e1301d126b6dcd78ae69a506a58e12 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank"
Date: Fri, 16 Aug 2024 10:55:28 +0200
Subject: [PATCH 02/15] dim fixes, proper testing

---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 306 ++++++++++--------
 .../MeshToMPI/convert-mesh-to-mpi.mlir        | 179 ++++++++--
 2 files changed, 339 insertions(+), 146 deletions(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index b4cf9da8497a2..42d885a109ee7 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements a translation of Mesh communicatin ops tp MPI ops.
+// This file implements a translation of Mesh communication ops to MPI ops.
 //
 //===----------------------------------------------------------------------===//
 
@@ -21,6 +21,8 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "mesh-to-mpi"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -34,138 +36,190 @@ using namespace mlir;
 using namespace mlir::mesh;
 
 namespace {
+
+// This pattern converts the mesh.update_halo operation to MPI calls
+struct ConvertUpdateHaloOp
+    : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::mesh::UpdateHaloOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    // Halos are exchanged as 2 blocks per dimension (one for each side: down
+    // and up). It is assumed that the last dim in a default memref is
+    // contiguous, hence iteration starts with the complete halo on the first
+    // dim which should be contiguous (unless the source is not). The size of
+    // the exchanged data will decrease when iterating over dimensions. That's
+    // good because the halos of last dim will be most fragmented.
+    // memref.subview is used to read and write the halo data from and to the
+    // local data. subviews and halos have dynamic and static values, so
+    // OpFoldResults are used whenever possible.
+
+    SymbolTableCollection symbolTableCollection;
+    auto loc = op.getLoc();
+
+    // convert an OpFoldResult into a Value
+    auto toValue = [&rewriter, &loc](OpFoldResult &v) {
+      return v.is<Value>()
+                 ? 
v.get() + : rewriter.create<::mlir::arith::ConstantOp>( + loc, + rewriter.getIndexAttr( + cast(v.get()).getInt())); + }; + + auto array = op.getInput(); + auto rank = array.getType().getRank(); + auto mesh = op.getMesh(); + auto meshOp = getMesh(op, symbolTableCollection); + auto haloSizes = getMixedValues(op.getStaticHaloSizes(), + op.getDynamicHaloSizes(), rewriter); + // subviews need Index values + for (auto &sz : haloSizes) { + if (sz.is()) { + sz = rewriter + .create(loc, rewriter.getIndexType(), + sz.get()) + .getResult(); + } + } + + // most of the offset/size/stride data is the same for all dims + SmallVector offsets(rank, rewriter.getIndexAttr(0)); + SmallVector strides(rank, rewriter.getIndexAttr(1)); + SmallVector shape(rank); + // we need the actual shape to compute offsets and sizes + for (auto [i, s] : llvm::enumerate(array.getType().getShape())) { + if (ShapedType::isDynamic(s)) { + shape[i] = rewriter.create(loc, array, s).getResult(); + } else { + shape[i] = rewriter.getIndexAttr(s); + } + } + + auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something + auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr); + auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0 + auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr); + SmallVector indexResultTypes(meshOp.getShape().size(), + rewriter.getIndexType()); + auto myMultiIndex = + rewriter.create(loc, indexResultTypes, mesh) + .getResult(); + // halo sizes are provided for split dimensions only + auto currHaloDim = 0; + + for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) { + if (splitAxes.empty()) { + continue; + } + // Get the linearized ids of the neighbors (down and up) for the + // given split + auto tmp = rewriter + .create(loc, mesh, myMultiIndex, + splitAxes) + .getResults(); + // MPI operates on i32... + Value neighbourIDs[2] = {rewriter.create( + loc, rewriter.getI32Type(), tmp[0]), + rewriter.create( + loc, rewriter.getI32Type(), tmp[1])}; + // store for later + auto orgDimSize = shape[dim]; + // this dim's offset to the start of the upper halo + auto upperOffset = rewriter.create( + loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1])); + + // Make sure we send/recv in a way that does not lead to a dead-lock. + // The current approach is by far not optimal, this should be at least + // be a red-black pattern or using MPI_sendrecv. + // Also, buffers should be re-used. + // Still using temporary contiguous buffers for MPI communication... + // Still yielding a "serialized" communication pattern... + auto genSendRecv = [&](auto dim, bool upperHalo) { + auto orgOffset = offsets[dim]; + shape[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1] + : haloSizes[currHaloDim * 2]; + // Check if we need to send and/or receive + // Processes on the mesh borders have only one neighbor + auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0]; + auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1]; + auto hasFrom = rewriter.create( + loc, arith::CmpIPredicate::sge, from, zero); + auto hasTo = rewriter.create( + loc, arith::CmpIPredicate::sge, to, zero); + auto buffer = rewriter.create( + loc, shape, array.getType().getElementType()); + // if has neighbor: copy halo data from array to buffer and send + rewriter.create( + loc, hasTo, [&](OpBuilder &builder, Location loc) { + offsets[dim] = upperHalo ? 
OpFoldResult(builder.getIndexAttr(0)) + : OpFoldResult(upperOffset); + auto subview = builder.create( + loc, array, offsets, shape, strides); + builder.create(loc, subview, buffer); + builder.create(loc, TypeRange{}, buffer, tag, to); + builder.create(loc); + }); + // if has neighbor: receive halo data into buffer and copy to array + rewriter.create( + loc, hasFrom, [&](OpBuilder &builder, Location loc) { + offsets[dim] = upperHalo ? OpFoldResult(upperOffset) + : OpFoldResult(builder.getIndexAttr(0)); + builder.create(loc, TypeRange{}, buffer, tag, from); + auto subview = builder.create( + loc, array, offsets, shape, strides); + builder.create(loc, buffer, subview); + builder.create(loc); + }); + rewriter.create(loc, buffer); + offsets[dim] = orgOffset; + }; + + genSendRecv(dim, false); + genSendRecv(dim, true); + + // prepare shape and offsets for next split dim + auto _haloSz = + rewriter + .create(loc, toValue(haloSizes[currHaloDim * 2]), + toValue(haloSizes[currHaloDim * 2 + 1])) + .getResult(); + // the shape for next halo excludes the halo on both ends for the + // current dim + shape[dim] = + rewriter.create(loc, toValue(orgDimSize), _haloSz) + .getResult(); + // the offsets for next halo starts after the down halo for the + // current dim + offsets[dim] = haloSizes[currHaloDim * 2]; + // on to next halo + ++currHaloDim; + } + rewriter.eraseOp(op); + return mlir::success(); + } +}; + struct ConvertMeshToMPIPass : public impl::ConvertMeshToMPIPassBase { using Base::Base; /// Run the dialect converter on the module. void runOnOperation() override { - getOperation()->walk([&](UpdateHaloOp op) { - SymbolTableCollection symbolTableCollection; - OpBuilder builder(op); - auto loc = op.getLoc(); - - auto toValue = [&builder, &loc](OpFoldResult &v) { - return v.is() - ? 
v.get() - : builder.create<::mlir::arith::ConstantOp>( - loc, - builder.getIndexAttr( - cast(v.get()).getInt())); - }; + auto *ctx = &getContext(); + mlir::RewritePatternSet patterns(ctx); - auto array = op.getInput(); - auto rank = array.getType().getRank(); - auto mesh = op.getMesh(); - auto meshOp = getMesh(op, symbolTableCollection); - auto haloSizes = getMixedValues(op.getStaticHaloSizes(), - op.getDynamicHaloSizes(), builder); - for (auto &sz : haloSizes) { - if (sz.is()) { - sz = builder - .create(loc, builder.getIndexType(), - sz.get()) - .getResult(); - } - } - - SmallVector offsets(rank, builder.getIndexAttr(0)); - SmallVector strides(rank, builder.getIndexAttr(1)); - SmallVector shape(rank); - for (auto [i, s] : llvm::enumerate(array.getType().getShape())) { - if (ShapedType::isDynamic(s)) { - shape[i] = builder.create(loc, array, s).getResult(); - } else { - shape[i] = builder.getIndexAttr(s); - } - } + patterns.insert(ctx); - auto tagAttr = builder.getI32IntegerAttr(91); // whatever - auto tag = builder.create<::mlir::arith::ConstantOp>(loc, tagAttr); - auto zeroAttr = builder.getI32IntegerAttr(0); // whatever - auto zero = builder.create<::mlir::arith::ConstantOp>(loc, zeroAttr); - SmallVector indexResultTypes(meshOp.getShape().size(), - builder.getIndexType()); - auto myMultiIndex = - builder.create(loc, indexResultTypes, mesh) - .getResult(); - auto currHaloDim = 0; - - for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) { - if (!splitAxes.empty()) { - auto tmp = builder - .create( - loc, mesh, myMultiIndex, splitAxes) - .getResults(); - Value neighbourIDs[2] = {builder.create( - loc, builder.getI32Type(), tmp[0]), - builder.create( - loc, builder.getI32Type(), tmp[1])}; - auto orgDimSize = shape[dim]; - auto upperOffset = builder.create( - loc, toValue(shape[dim]), toValue(haloSizes[dim * 2 + 1])); - - // make sure we send/recv in a way that does not lead to a dead-lock - // This is by far not optimal, this should be at least MPI_sendrecv - // and - probably even more importantly - buffers should be re-used - // Currently using temporary, contiguous buffer for MPI communication - auto genSendRecv = [&](auto dim, bool upperHalo) { - auto orgOffset = offsets[dim]; - shape[dim] = - upperHalo ? haloSizes[dim * 2 + 1] : haloSizes[dim * 2]; - auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0]; - auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1]; - auto hasFrom = builder.create( - loc, arith::CmpIPredicate::sge, from, zero); - auto hasTo = builder.create( - loc, arith::CmpIPredicate::sge, to, zero); - auto buffer = builder.create( - loc, shape, array.getType().getElementType()); - builder.create( - loc, hasTo, [&](OpBuilder &builder, Location loc) { - offsets[dim] = upperHalo - ? OpFoldResult(builder.getIndexAttr(0)) - : OpFoldResult(upperOffset); - auto subview = builder.create( - loc, array, offsets, shape, strides); - builder.create(loc, subview, buffer); - builder.create(loc, TypeRange{}, buffer, tag, - to); - builder.create(loc); - }); - builder.create( - loc, hasFrom, [&](OpBuilder &builder, Location loc) { - offsets[dim] = upperHalo - ? 
OpFoldResult(upperOffset) - : OpFoldResult(builder.getIndexAttr(0)); - builder.create(loc, TypeRange{}, buffer, tag, - from); - auto subview = builder.create( - loc, array, offsets, shape, strides); - builder.create(loc, buffer, subview); - builder.create(loc); - }); - builder.create(loc, buffer); - offsets[dim] = orgOffset; - }; - - genSendRecv(dim, false); - genSendRecv(dim, true); - - shape[dim] = builder - .create( - loc, toValue(orgDimSize), - builder - .create( - loc, toValue(haloSizes[dim * 2]), - toValue(haloSizes[dim * 2 + 1])) - .getResult()) - .getResult(); - offsets[dim] = haloSizes[dim * 2]; - ++currHaloDim; - } - } - }); + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)); } }; -} // namespace \ No newline at end of file + +} // namespace + +// Create a pass that convert Mesh to MPI +std::unique_ptr<::mlir::OperationPass> createConvertMeshToMPIPass() { + return std::make_unique(); +} diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir index 9ef826ca0cdac..5f563364272d9 100644 --- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir +++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir @@ -1,34 +1,173 @@ -// RUN: mlir-opt %s -split-input-file -convert-mesh-to-mpi | FileCheck %s +// RUN: mlir-opt %s -convert-mesh-to-mpi | FileCheck %s // CHECK: mesh.mesh @mesh0 mesh.mesh @mesh0(shape = 2x2x4) -// ----- - -// CHECK-LABEL: func @update_halo -func.func @update_halo_1d( - // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8> +// CHECK-LABEL: func @update_halo_1d_first +func.func @update_halo_1d_first( + // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8> %arg0 : memref<12x12xi8>) { - // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64 - // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0 - // CHECK-SAME: split_axes = {{\[\[}}0]] - // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8> - %c2 = arith.constant 2 : i64 + // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index + // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32 + // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32 + // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index + // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index + // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32 + // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32 + // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32 + // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32 + // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x12xi8> + // CHECK-NEXT: scf.if [[v3]] { + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>> + // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8> + // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32 + // CHECK-NEXT: } + // CHECK-NEXT: scf.if [[v2]] { + // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32 + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>> + // CHECK-NEXT: memref.copy [[valloc]], [[vsubview]] 
: memref<2x12xi8> to memref<2x12xi8, strided<[12, 1]>> + // CHECK-NEXT: } + // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x12xi8> + // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32 + // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32 + // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<3x12xi8> + // CHECK-NEXT: scf.if [[v5]] { + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1]>> + // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1]>> to memref<3x12xi8> + // CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<3x12xi8>, i32, i32 + // CHECK-NEXT: } + // CHECK-NEXT: scf.if [[v4]] { + // CHECK-NEXT: mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<3x12xi8>, i32, i32 + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>> + // CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<3x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>> + // CHECK-NEXT: } + // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<3x12xi8> + // CHECK-NEXT: return mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] - halo_sizes = [2, %c2] : memref<12x12xi8> + halo_sizes = [2, 3] : memref<12x12xi8> + return +} + +// CHECK-LABEL: func @update_halo_1d_second +func.func @update_halo_1d_second( + // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8> + %arg0 : memref<12x12xi8>) { + // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index + // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32 + // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32 + // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index + // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [3] : index, index + // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32 + // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32 + // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32 + // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32 + // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<12x2xi8> + // CHECK-NEXT: scf.if [[v3]] { + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1], offset: ?>> + // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<12x2xi8, strided<[12, 1], offset: ?>> to memref<12x2xi8> + // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<12x2xi8>, i32, i32 + // CHECK-NEXT: } + // CHECK-NEXT: scf.if [[v2]] { + // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<12x2xi8>, i32, i32 + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1]>> + // CHECK-NEXT: memref.copy [[valloc]], [[vsubview]] : memref<12x2xi8> to memref<12x2xi8, strided<[12, 1]>> + // CHECK-NEXT: } + // CHECK-NEXT: memref.dealloc [[valloc]] : memref<12x2xi8> + // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32 + // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32 + // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<12x3xi8> + // CHECK-NEXT: scf.if [[v5]] { + // CHECK-NEXT: 
[[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1]>> + // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<12x3xi8, strided<[12, 1]>> to memref<12x3xi8> + // CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<12x3xi8>, i32, i32 + // CHECK-NEXT: } + // CHECK-NEXT: scf.if [[v4]] { + // CHECK-NEXT: mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<12x3xi8>, i32, i32 + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>> + // CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<12x3xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>> + // CHECK-NEXT: } + // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<12x3xi8> + // CHECK-NEXT: return + mesh.update_halo %arg0 on @mesh0 split_axes = [[], [3]] + halo_sizes = [2, 3] : memref<12x12xi8> return } +// CHECK-LABEL: func @update_halo_2d func.func @update_halo_2d( - // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8> + // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8> %arg0 : memref<12x12xi8>) { - %c2 = arith.constant 2 : i64 - // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0 - // CHECK-SAME: split_axes = {{\[\[}}0], [1]] - // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] - // CHECK-SAME: target_halo_sizes = [3, 3, 2, 2] : memref<12x12xi8> + // CHECK-NEXT: [[vc8:%.*]] = arith.constant 8 : index + // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index + // CHECK-NEXT: [[vc10:%.*]] = arith.constant 10 : index + // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32 + // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32 + // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index + // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index + // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32 + // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32 + // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32 + // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32 + // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<1x12xi8> + // CHECK-NEXT: scf.if [[v3]] { + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1], offset: ?>> + // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<1x12xi8, strided<[12, 1], offset: ?>> to memref<1x12xi8> + // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<1x12xi8>, i32, i32 + // CHECK-NEXT: } + // CHECK-NEXT: scf.if [[v2]] { + // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<1x12xi8>, i32, i32 + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1]>> + // CHECK-NEXT: memref.copy [[valloc]], [[vsubview]] : memref<1x12xi8> to memref<1x12xi8, strided<[12, 1]>> + // CHECK-NEXT: } + // CHECK-NEXT: memref.dealloc [[valloc]] : memref<1x12xi8> + // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32 + // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32 + // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<2x12xi8> + // CHECK-NEXT: scf.if [[v5]] { + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [2, 12] 
[1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>> + // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<2x12xi8, strided<[12, 1]>> to memref<2x12xi8> + // CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32 + // CHECK-NEXT: } + // CHECK-NEXT: scf.if [[v4]] { + // CHECK-NEXT: mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32 + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>> + // CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>> + // CHECK-NEXT: } + // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<2x12xi8> + // CHECK-NEXT: [[vdown_linear_idx_1:%.*]], [[vup_linear_idx_2:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [1] : index, index + // CHECK-NEXT: [[v6:%.*]] = arith.index_cast [[vdown_linear_idx_1]] : index to i32 + // CHECK-NEXT: [[v7:%.*]] = arith.index_cast [[vup_linear_idx_2]] : index to i32 + // CHECK-NEXT: [[v8:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32 + // CHECK-NEXT: [[v9:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32 + // CHECK-NEXT: [[valloc_3:%.*]] = memref.alloc([[vc9]]) : memref + // CHECK-NEXT: scf.if [[v9]] { + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref> + // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_3]] : memref> to memref + // CHECK-NEXT: mpi.send([[valloc_3]], [[vc91_i32]], [[v6]]) : memref, i32, i32 + // CHECK-NEXT: } + // CHECK-NEXT: scf.if [[v8]] { + // CHECK-NEXT: mpi.recv([[valloc_3]], [[vc91_i32]], [[v7]]) : memref, i32, i32 + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref> + // CHECK-NEXT: memref.copy [[valloc_3]], [[vsubview]] : memref to memref> + // CHECK-NEXT: } + // CHECK-NEXT: memref.dealloc [[valloc_3]] : memref + // CHECK-NEXT: [[v10:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32 + // CHECK-NEXT: [[v11:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32 + // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc([[vc9]]) : memref + // CHECK-NEXT: scf.if [[v11]] { + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref> + // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_4]] : memref> to memref + // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[v7]]) : memref, i32, i32 + // CHECK-NEXT: } + // CHECK-NEXT: scf.if [[v10]] { + // CHECK-NEXT: mpi.recv([[valloc_4]], [[vc91_i32]], [[v6]]) : memref, i32, i32 + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref> + // CHECK-NEXT: memref.copy [[valloc_4]], [[vsubview]] : memref to memref> + // CHECK-NEXT: } + // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref + // CHECK-NEXT: return mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]] - halo_sizes = [2, 2, %c2, 2] target_halo_sizes = [3, 3, 2, 2] - : memref<12x12xi8> + halo_sizes = [1, 2, 3, 4] + : memref<12x12xi8> return } From aeee16c6e3454ab7c04168364d544ffda2e7f344 Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Tue, 20 Aug 2024 19:23:13 +0200 Subject: [PATCH 03/15] fixed corner halos by reversing data-exchanges from high to low dims --- mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 91 +++++----- .../MeshToMPI/convert-mesh-to-mpi.mlir | 161 
+++++++++---------
 2 files changed, 137 insertions(+), 115 deletions(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 42d885a109ee7..9cf9458ce2b68 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -70,6 +70,7 @@ struct ConvertUpdateHaloOp
 
     auto array = op.getInput();
     auto rank = array.getType().getRank();
+    auto opSplitAxes = op.getSplitAxes().getAxes();
     auto mesh = op.getMesh();
    auto meshOp = getMesh(op, symbolTableCollection);
     auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
@@ -87,32 +88,54 @@ struct ConvertUpdateHaloOp
     // most of the offset/size/stride data is the same for all dims
     SmallVector offsets(rank, rewriter.getIndexAttr(0));
     SmallVector strides(rank, rewriter.getIndexAttr(1));
-    SmallVector shape(rank);
+    SmallVector shape(rank), dimSizes(rank);
+    auto currHaloDim = -1; // halo sizes are provided for split dimensions only
     // we need the actual shape to compute offsets and sizes
-    for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
+    for (auto i = 0; i < rank; ++i) {
+      auto s = array.getType().getShape()[i];
       if (ShapedType::isDynamic(s)) {
         shape[i] = rewriter.create(loc, array, s).getResult();
       } else {
         shape[i] = rewriter.getIndexAttr(s);
       }
+
+      if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
+        ++currHaloDim;
+        // the offsets for lower dims start after their down halo
+        offsets[i] = haloSizes[currHaloDim * 2];
+
+        // prepare shape and offsets of highest dim's halo exchange
+        auto _haloSz =
+            rewriter
+                .create(loc, toValue(haloSizes[currHaloDim * 2]),
+                        toValue(haloSizes[currHaloDim * 2 + 1]))
+                .getResult();
+        // the halo shape of lower dims excludes the halos
+        dimSizes[i] =
+            rewriter.create(loc, toValue(shape[i]), _haloSz)
+                .getResult();
+      } else {
+        dimSizes[i] = shape[i];
+      }
     }
 
     auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
     auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr);
     auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
     auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
+
     SmallVector indexResultTypes(meshOp.getShape().size(),
                                  rewriter.getIndexType());
     auto myMultiIndex =
         rewriter.create(loc, indexResultTypes, mesh)
             .getResult();
-    // halo sizes are provided for split dimensions only
-    auto currHaloDim = 0;
-
-    for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
+    // traverse all split axes from high to low dim
+    for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
+      auto splitAxes = opSplitAxes[dim];
       if (splitAxes.empty()) {
         continue;
       }
+      assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
       // Get the linearized ids of the neighbors (down and up) for the
       // given split
       auto tmp = rewriter
                      .create(loc, mesh, myMultiIndex,
                              splitAxes)
                      .getResults();
       // MPI operates on i32...
       Value neighbourIDs[2] = {rewriter.create(
                                    loc, rewriter.getI32Type(), tmp[0]),
                                rewriter.create(
                                    loc, rewriter.getI32Type(), tmp[1])};
-      // store for later
-      auto orgDimSize = shape[dim];
-      // this dim's offset to the start of the upper halo
-      auto upperOffset = rewriter.create(
+
+      auto lowerRecvOffset = rewriter.getIndexAttr(0);
+      auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
+      auto upperRecvOffset = rewriter.create(
           loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
+      auto upperSendOffset = rewriter.create(
+          loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
 
       // Make sure we send/recv in a way that does not lead to a dead-lock.
// The current approach is by far not optimal, this should be at least @@ -136,10 +161,10 @@ struct ConvertUpdateHaloOp // Also, buffers should be re-used. // Still using temporary contiguous buffers for MPI communication... // Still yielding a "serialized" communication pattern... - auto genSendRecv = [&](auto dim, bool upperHalo) { + auto genSendRecv = [&](bool upperHalo) { auto orgOffset = offsets[dim]; - shape[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1] - : haloSizes[currHaloDim * 2]; + dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1] + : haloSizes[currHaloDim * 2]; // Check if we need to send and/or receive // Processes on the mesh borders have only one neighbor auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0]; @@ -149,14 +174,14 @@ struct ConvertUpdateHaloOp auto hasTo = rewriter.create( loc, arith::CmpIPredicate::sge, to, zero); auto buffer = rewriter.create( - loc, shape, array.getType().getElementType()); + loc, dimSizes, array.getType().getElementType()); // if has neighbor: copy halo data from array to buffer and send rewriter.create( loc, hasTo, [&](OpBuilder &builder, Location loc) { - offsets[dim] = upperHalo ? OpFoldResult(builder.getIndexAttr(0)) - : OpFoldResult(upperOffset); + offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset) + : OpFoldResult(upperSendOffset); auto subview = builder.create( - loc, array, offsets, shape, strides); + loc, array, offsets, dimSizes, strides); builder.create(loc, subview, buffer); builder.create(loc, TypeRange{}, buffer, tag, to); builder.create(loc); @@ -164,11 +189,11 @@ struct ConvertUpdateHaloOp // if has neighbor: receive halo data into buffer and copy to array rewriter.create( loc, hasFrom, [&](OpBuilder &builder, Location loc) { - offsets[dim] = upperHalo ? OpFoldResult(upperOffset) - : OpFoldResult(builder.getIndexAttr(0)); + offsets[dim] = upperHalo ? 
OpFoldResult(upperRecvOffset) + : OpFoldResult(lowerRecvOffset); builder.create(loc, TypeRange{}, buffer, tag, from); auto subview = builder.create( - loc, array, offsets, shape, strides); + loc, array, offsets, dimSizes, strides); builder.create(loc, buffer, subview); builder.create(loc); }); @@ -176,25 +201,15 @@ struct ConvertUpdateHaloOp offsets[dim] = orgOffset; }; - genSendRecv(dim, false); - genSendRecv(dim, true); - - // prepare shape and offsets for next split dim - auto _haloSz = - rewriter - .create(loc, toValue(haloSizes[currHaloDim * 2]), - toValue(haloSizes[currHaloDim * 2 + 1])) - .getResult(); - // the shape for next halo excludes the halo on both ends for the - // current dim - shape[dim] = - rewriter.create(loc, toValue(orgDimSize), _haloSz) - .getResult(); - // the offsets for next halo starts after the down halo for the - // current dim - offsets[dim] = haloSizes[currHaloDim * 2]; + genSendRecv(false); + genSendRecv(true); + + // the shape for lower dims include higher dims' halos + dimSizes[dim] = shape[dim]; + // -> the offset for higher dims is always 0 + offsets[dim] = rewriter.getIndexAttr(0); // on to next halo - ++currHaloDim; + --currHaloDim; } rewriter.eraseOp(op); return mlir::success(); diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir index 5f563364272d9..c3b0dc12e6d74 100644 --- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir +++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir @@ -6,8 +6,10 @@ mesh.mesh @mesh0(shape = 2x2x4) // CHECK-LABEL: func @update_halo_1d_first func.func @update_halo_1d_first( // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8> - %arg0 : memref<12x12xi8>) { + %arg0 : memref<12x12xi8>) { + // CHECK-NEXT: [[vc7:%.*]] = arith.constant 7 : index // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index + // CHECK-NEXT: [[vc2:%.*]] = arith.constant 2 : index // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32 // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32 // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index @@ -18,7 +20,7 @@ func.func @update_halo_1d_first( // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32 // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x12xi8> // CHECK-NEXT: scf.if [[v3]] { - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>> + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc7]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>> // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8> // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32 // CHECK-NEXT: } @@ -32,8 +34,8 @@ func.func @update_halo_1d_first( // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32 // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<3x12xi8> // CHECK-NEXT: scf.if [[v5]] { - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1]>> - // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1]>> to memref<3x12xi8> + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc2]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>> + // CHECK-NEXT: memref.copy 
[[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1], offset: ?>> to memref<3x12xi8> // CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<3x12xi8>, i32, i32 // CHECK-NEXT: } // CHECK-NEXT: scf.if [[v4]] { @@ -42,9 +44,9 @@ func.func @update_halo_1d_first( // CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<3x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>> // CHECK-NEXT: } // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<3x12xi8> - // CHECK-NEXT: return mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 3] : memref<12x12xi8> + // CHECK-NEXT: return return } @@ -52,44 +54,46 @@ func.func @update_halo_1d_first( func.func @update_halo_1d_second( // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8> %arg0 : memref<12x12xi8>) { - // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index - // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32 - // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32 - // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index - // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [3] : index, index - // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32 - // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32 - // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32 - // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32 - // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<12x2xi8> - // CHECK-NEXT: scf.if [[v3]] { - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1], offset: ?>> - // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<12x2xi8, strided<[12, 1], offset: ?>> to memref<12x2xi8> - // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<12x2xi8>, i32, i32 - // CHECK-NEXT: } - // CHECK-NEXT: scf.if [[v2]] { - // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<12x2xi8>, i32, i32 - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1]>> - // CHECK-NEXT: memref.copy [[valloc]], [[vsubview]] : memref<12x2xi8> to memref<12x2xi8, strided<[12, 1]>> - // CHECK-NEXT: } - // CHECK-NEXT: memref.dealloc [[valloc]] : memref<12x2xi8> - // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32 - // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32 - // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<12x3xi8> - // CHECK-NEXT: scf.if [[v5]] { - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1]>> - // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<12x3xi8, strided<[12, 1]>> to memref<12x3xi8> - // CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<12x3xi8>, i32, i32 - // CHECK-NEXT: } - // CHECK-NEXT: scf.if [[v4]] { - // CHECK-NEXT: mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<12x3xi8>, i32, i32 - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>> - // CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<12x3xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>> - // CHECK-NEXT: } - // CHECK-NEXT: 
memref.dealloc [[valloc_0]] : memref<12x3xi8> - // CHECK-NEXT: return + //CHECK-NEXT: [[vc7:%.*]] = arith.constant 7 : index + //CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index + //CHECK-NEXT: [[vc2:%.*]] = arith.constant 2 : index + //CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32 + //CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32 + //CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index + //CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [3] : index, index + //CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32 + //CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32 + //CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32 + //CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32 + //CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<12x2xi8> + //CHECK-NEXT: scf.if [[v3]] { + //CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, %c7] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1], offset: ?>> + //CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<12x2xi8, strided<[12, 1], offset: ?>> to memref<12x2xi8> + //CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<12x2xi8>, i32, i32 + //CHECK-NEXT: } + //CHECK-NEXT: scf.if [[v2]] { + //CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<12x2xi8>, i32, i32 + //CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1]>> + //CHECK-NEXT: memref.copy [[valloc]], [[vsubview]] : memref<12x2xi8> to memref<12x2xi8, strided<[12, 1]>> + //CHECK-NEXT: } + //CHECK-NEXT: memref.dealloc [[valloc]] : memref<12x2xi8> + //CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32 + //CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32 + //CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<12x3xi8> + //CHECK-NEXT: scf.if [[v5]] { + //CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, %c2] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>> + //CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<12x3xi8, strided<[12, 1], offset: ?>> to memref<12x3xi8> + //CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<12x3xi8>, i32, i32 + //CHECK-NEXT: } + //CHECK-NEXT: scf.if [[v4]] { + //CHECK-NEXT: mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<12x3xi8>, i32, i32 + //CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>> + //CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<12x3xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>> + //CHECK-NEXT: } + //CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<12x3xi8> mesh.update_halo %arg0 on @mesh0 split_axes = [[], [3]] halo_sizes = [2, 3] : memref<12x12xi8> + //CHECK-NEXT: return return } @@ -97,77 +101,80 @@ func.func @update_halo_1d_second( func.func @update_halo_2d( // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8> %arg0 : memref<12x12xi8>) { + // CHECK-NEXT: [[vc10:%.*]] = arith.constant 10 : index + // CHECK-NEXT: [[vc1:%.*]] = arith.constant 1 : index + // CHECK-NEXT: [[vc5:%.*]] = arith.constant 5 : index // CHECK-NEXT: [[vc8:%.*]] = arith.constant 8 : index + // CHECK-NEXT: [[vc3:%.*]] = arith.constant 3 : index // CHECK-NEXT: 
[[vc9:%.*]] = arith.constant 9 : index - // CHECK-NEXT: [[vc10:%.*]] = arith.constant 10 : index // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32 // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32 // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index - // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index + // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [1] : index, index // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32 // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32 // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32 // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32 - // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<1x12xi8> + // CHECK-NEXT: [[valloc:%.*]] = memref.alloc([[vc9]]) : memref // CHECK-NEXT: scf.if [[v3]] { - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1], offset: ?>> - // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<1x12xi8, strided<[12, 1], offset: ?>> to memref<1x12xi8> - // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<1x12xi8>, i32, i32 + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, %c5] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref> + // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref> to memref + // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref, i32, i32 // CHECK-NEXT: } // CHECK-NEXT: scf.if [[v2]] { - // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<1x12xi8>, i32, i32 - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1]>> - // CHECK-NEXT: memref.copy [[valloc]], [[vsubview]] : memref<1x12xi8> to memref<1x12xi8, strided<[12, 1]>> + // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref, i32, i32 + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref> + // CHECK-NEXT: memref.copy [[valloc]], [[vsubview]] : memref to memref> // CHECK-NEXT: } - // CHECK-NEXT: memref.dealloc [[valloc]] : memref<1x12xi8> + // CHECK-NEXT: memref.dealloc [[valloc]] : memref // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32 // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32 - // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<2x12xi8> + // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc([[vc9]]) : memref // CHECK-NEXT: scf.if [[v5]] { - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>> - // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<2x12xi8, strided<[12, 1]>> to memref<2x12xi8> - // CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32 + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, %c3] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref> + // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref> to memref + // CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref, i32, i32 // CHECK-NEXT: 
} // CHECK-NEXT: scf.if [[v4]] { - // CHECK-NEXT: mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32 - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>> - // CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>> + // CHECK-NEXT: mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref, i32, i32 + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref> + // CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref to memref> // CHECK-NEXT: } - // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<2x12xi8> - // CHECK-NEXT: [[vdown_linear_idx_1:%.*]], [[vup_linear_idx_2:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [1] : index, index + // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref + // CHECK-NEXT: [[vdown_linear_idx_1:%.*]], [[vup_linear_idx_2:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index // CHECK-NEXT: [[v6:%.*]] = arith.index_cast [[vdown_linear_idx_1]] : index to i32 // CHECK-NEXT: [[v7:%.*]] = arith.index_cast [[vup_linear_idx_2]] : index to i32 // CHECK-NEXT: [[v8:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32 // CHECK-NEXT: [[v9:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32 - // CHECK-NEXT: [[valloc_3:%.*]] = memref.alloc([[vc9]]) : memref + // CHECK-NEXT: [[valloc_3:%.*]] = memref.alloc() : memref<1x12xi8> // CHECK-NEXT: scf.if [[v9]] { - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref> - // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_3]] : memref> to memref - // CHECK-NEXT: mpi.send([[valloc_3]], [[vc91_i32]], [[v6]]) : memref, i32, i32 + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1], offset: ?>> + // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_3]] : memref<1x12xi8, strided<[12, 1], offset: ?>> to memref<1x12xi8> + // CHECK-NEXT: mpi.send([[valloc_3]], [[vc91_i32]], [[v6]]) : memref<1x12xi8>, i32, i32 // CHECK-NEXT: } // CHECK-NEXT: scf.if [[v8]] { - // CHECK-NEXT: mpi.recv([[valloc_3]], [[vc91_i32]], [[v7]]) : memref, i32, i32 - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref> - // CHECK-NEXT: memref.copy [[valloc_3]], [[vsubview]] : memref to memref> + // CHECK-NEXT: mpi.recv([[valloc_3]], [[vc91_i32]], [[v7]]) : memref<1x12xi8>, i32, i32 + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1]>> + // CHECK-NEXT: memref.copy [[valloc_3]], [[vsubview]] : memref<1x12xi8> to memref<1x12xi8, strided<[12, 1]>> // CHECK-NEXT: } - // CHECK-NEXT: memref.dealloc [[valloc_3]] : memref + // CHECK-NEXT: memref.dealloc [[valloc_3]] : memref<1x12xi8> // CHECK-NEXT: [[v10:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32 // CHECK-NEXT: [[v11:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32 - // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc([[vc9]]) : memref + // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<2x12xi8> // CHECK-NEXT: scf.if [[v11]] { - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 4] [1, 1] : 
memref<12x12xi8> to memref> - // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_4]] : memref> to memref - // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[v7]]) : memref, i32, i32 + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc1]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>> + // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_4]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8> + // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[v7]]) : memref<2x12xi8>, i32, i32 // CHECK-NEXT: } // CHECK-NEXT: scf.if [[v10]] { - // CHECK-NEXT: mpi.recv([[valloc_4]], [[vc91_i32]], [[v6]]) : memref, i32, i32 - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref> - // CHECK-NEXT: memref.copy [[valloc_4]], [[vsubview]] : memref to memref> + // CHECK-NEXT: mpi.recv([[valloc_4]], [[vc91_i32]], [[v6]]) : memref<2x12xi8>, i32, i32 + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>> + // CHECK-NEXT: memref.copy [[valloc_4]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>> // CHECK-NEXT: } - // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref - // CHECK-NEXT: return + // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<2x12xi8> mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : memref<12x12xi8> + // CHECK-NEXT: return return } From 6e967fb9f51b542af7b5244eb609875e47433cd8 Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Wed, 21 Aug 2024 12:25:08 +0200 Subject: [PATCH 04/15] addressed review comments (docs, formatting) --- .../mlir/Conversion/MeshToMPI/MeshToMPI.h | 2 +- mlir/include/mlir/Conversion/Passes.td | 2 +- mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 6 +++--- mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 16 +++++++++------- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h index 6a2c196da4557..b8803f386f735 100644 --- a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h +++ b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h @@ -1,4 +1,4 @@ -//===- MeshToMPI.h - Convert Mesh to MPI dialect --*- C++ -*-===// +//===- MeshToMPI.h - Convert Mesh to MPI dialect ----------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 83e0c5a06c43f..2781fab917048 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -886,7 +886,7 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> { let summary = "Convert Mesh dialect to MPI dialect."; let description = [{ This pass converts communication operations - from the Mesh dialect to operations from the MPI dialect. + from the Mesh dialect to the MPI dialect. 
}]; let dependentDialects = [ "memref::MemRefDialect", diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td index 2c2b6e20f3654..e6f61aa84a131 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td @@ -162,7 +162,7 @@ def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [ DeclareOpInterfaceMethods ]> { let summary = - "For given split axes get the linear index the direct neighbor processes."; + "For given split axes get the linear indices of the direct neighbor processes."; let description = [{ Example: ``` @@ -172,8 +172,8 @@ def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [ Given `@mesh` with shape `(10, 20, 30)`, `device` = `(1, 2, 3)` `$split_axes` = `[1]` - it returns the linear indices of the processes at positions `(1, 1, 3)`: `633` - and `(1, 3, 3)`: `693`. + returns two indices, `633` and `693`, which correspond to the index of the previous + process `(1, 1, 3)`, and the next process `(1, 3, 3) along the split axis `1`. A negative value is returned if `$device` has no neighbor in the given direction along the given `split_axes`. diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp index 9cf9458ce2b68..ea1323e43462c 100644 --- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp +++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp @@ -45,15 +45,17 @@ struct ConvertUpdateHaloOp mlir::LogicalResult matchAndRewrite(mlir::mesh::UpdateHaloOp op, mlir::PatternRewriter &rewriter) const override { + // The input/output memref is assumed to be in C memory order. // Halos are exchanged as 2 blocks per dimension (one for each side: down - // and up). It is assumed that the last dim in a default memref is - // contiguous, hence iteration starts with the complete halo on the first - // dim which should be contiguous (unless the source is not). The size of - // the exchanged data will decrease when iterating over dimensions. That's - // good because the halos of last dim will be most fragmented. + // and up). For each haloed dimension `d`, the exchanged blocks are + // expressed as multi-dimensional subviews. The subviews include potential + // halos of higher dimensions `dh > d`, no halos for the lower dimensions + // `dl < d` and for dimension `d` the currently exchanged halo only. + // By iterating form higher to lower dimensions this also updates the halos + // in the 'corners'. // memref.subview is used to read and write the halo data from and to the - // local data. subviews and halos have dynamic and static values, so - // OpFoldResults are used whenever possible. + // local data. Because subviews and halos can have mixed dynamic and static + // shapes, OpFoldResults are used whenever possible. 
SymbolTableCollection symbolTableCollection; auto loc = op.getLoc(); From a63cfa3446a065001a34512e31998a51f738ca39 Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Tue, 3 Sep 2024 10:41:22 +0200 Subject: [PATCH 05/15] newline --- mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h index b8803f386f735..04271f8ab67b9 100644 --- a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h +++ b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h @@ -24,4 +24,4 @@ std::unique_ptr createConvertMeshToMPIPass(); } // namespace mlir -#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H \ No newline at end of file +#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H From 38c21af59efdd57b31f8d8daf5b02f84c4d83dbc Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Thu, 31 Oct 2024 17:54:52 +0100 Subject: [PATCH 06/15] removing source from UpdateHaloOp, because not required for destination passing style --- mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 19 +++++++------------ mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 10 +++++----- .../Dialect/Mesh/Transforms/Spmdization.cpp | 4 +--- mlir/test/Dialect/Mesh/ops.mlir | 16 ++++++++-------- mlir/test/Dialect/Mesh/spmdization.mlir | 4 ++-- 5 files changed, 23 insertions(+), 30 deletions(-) diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td index e6f61aa84a131..3c52c63330e95 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td @@ -1095,8 +1095,7 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [ TypesMatchWith< "result has same type as destination", "result", "destination", "$_self">, - DeclareOpInterfaceMethods, - AttrSizedOperandSegments + DeclareOpInterfaceMethods ]> { let summary = "Update halo data."; let description = [{ @@ -1105,7 +1104,7 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [ on the remote devices. Changes might be caused by mutating operations and/or if the new halo regions are larger than the existing ones. - Source and destination might have different halo sizes. + Destination is supposed to be initialized with the local data (not halos). Assumes all devices hold tensors with same-sized halo data as specified by `source_halo_sizes/static_source_halo_sizes` and @@ -1117,25 +1116,21 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [ }]; let arguments = (ins - AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source, AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination, FlatSymbolRefAttr:$mesh, Mesh_MeshAxesArrayAttr:$split_axes, - Variadic:$source_halo_sizes, - DefaultValuedAttr:$static_source_halo_sizes, - Variadic:$destination_halo_sizes, - DefaultValuedAttr:$static_destination_halo_sizes + Variadic:$halo_sizes, + DefaultValuedAttr:$static_halo_sizes ); let results = (outs AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result ); let assemblyFormat = [{ - $source `into` $destination + $destination `on` $mesh `split_axes` `=` $split_axes - (`source_halo_sizes` `=` custom($source_halo_sizes, $static_source_halo_sizes)^)? - (`destination_halo_sizes` `=` custom($destination_halo_sizes, $static_destination_halo_sizes)^)? - attr-dict `:` type($source) `->` type($result) + (`halo_sizes` `=` custom($halo_sizes, $static_halo_sizes)^)? 
+ attr-dict `:` type($result) }]; let extraClassDeclaration = [{ MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); } diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp index ea1323e43462c..11d7c0e08f1a6 100644 --- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp +++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp @@ -70,13 +70,13 @@ struct ConvertUpdateHaloOp cast(v.get()).getInt())); }; - auto array = op.getInput(); - auto rank = array.getType().getRank(); + auto array = op.getDestination(); + auto rank = cast(array.getType()).getRank(); auto opSplitAxes = op.getSplitAxes().getAxes(); auto mesh = op.getMesh(); auto meshOp = getMesh(op, symbolTableCollection); auto haloSizes = getMixedValues(op.getStaticHaloSizes(), - op.getDynamicHaloSizes(), rewriter); + op.getHaloSizes(), rewriter); // subviews need Index values for (auto &sz : haloSizes) { if (sz.is()) { @@ -94,7 +94,7 @@ struct ConvertUpdateHaloOp auto currHaloDim = -1; // halo sizes are provided for split dimensions only // we need the actual shape to compute offsets and sizes for (auto i = 0; i < rank; ++i) { - auto s = array.getType().getShape()[i]; + auto s = cast(array.getType()).getShape()[i]; if (ShapedType::isDynamic(s)) { shape[i] = rewriter.create(loc, array, s).getResult(); } else { @@ -176,7 +176,7 @@ struct ConvertUpdateHaloOp auto hasTo = rewriter.create( loc, arith::CmpIPredicate::sge, to, zero); auto buffer = rewriter.create( - loc, dimSizes, array.getType().getElementType()); + loc, dimSizes, cast(array.getType()).getElementType()); // if has neighbor: copy halo data from array to buffer and send rewriter.create( loc, hasTo, [&](OpBuilder &builder, Location loc) { diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index b4d088cbd7088..327ea0991e4e1 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -496,11 +496,9 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, sourceShard.getLoc(), RankedTensorType::get(outShape, sourceShard.getType().getElementType()), - sourceShard, initOprnd, mesh.getSymName(), + initOprnd, mesh.getSymName(), MeshAxesArrayAttr::get(builder.getContext(), sourceSharding.getSplitAxes()), - sourceSharding.getDynamicHaloSizes(), - sourceSharding.getStaticHaloSizes(), targetSharding.getDynamicHaloSizes(), targetSharding.getStaticHaloSizes()) .getResult(); diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir index d8df01c3d6520..978de4939ee77 100644 --- a/mlir/test/Dialect/Mesh/ops.mlir +++ b/mlir/test/Dialect/Mesh/ops.mlir @@ -615,16 +615,16 @@ func.func @update_halo( // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8> %arg0 : memref<12x12xi8>) { // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64 - // CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] into %[[ARG]] on @mesh0 + // CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] on @mesh0 // CHECK-SAME: split_axes = {{\[\[}}0]] - // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8> -> memref<12x12xi8> + // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8> %c2 = arith.constant 2 : i64 - %uh1 = mesh.update_halo %arg0 into %arg0 on @mesh0 split_axes = [[0]] - source_halo_sizes = [2, %c2] : memref<12x12xi8> -> memref<12x12xi8> - // CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[ARG]] into %[[UH1]] on @mesh0 + %uh1 = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] + halo_sizes = [2, %c2] : 
memref<12x12xi8> + // CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[UH1]] on @mesh0 // CHECK-SAME: split_axes = {{\[\[}}0], [1]] - // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8> -> memref<12x12xi8> - %uh2 = mesh.update_halo %arg0 into %uh1 on @mesh0 split_axes = [[0], [1]] - source_halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8> -> memref<12x12xi8> + // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8> + %uh2 = mesh.update_halo %uh1 on @mesh0 split_axes = [[0], [1]] + halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8> return } diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir index 22ddb72569835..c1b96fda0f4a7 100644 --- a/mlir/test/Dialect/Mesh/spmdization.mlir +++ b/mlir/test/Dialect/Mesh/spmdization.mlir @@ -226,7 +226,7 @@ func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1 %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] : !mesh.sharding // CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64> // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64> - // CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} destination_halo_sizes = [2, 2] : tensor<300x1200xi64> -> tensor<304x1200xi64> + // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} halo_sizes = [2, 2] : tensor<304x1200xi64> %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64> %sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !mesh.sharding %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64> @@ -242,7 +242,7 @@ func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200 %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] : !mesh.sharding // CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64> // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64> - // CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] destination_halo_sizes = [1, 2, 3, 4] : tensor<300x300xi64> -> tensor<303x307xi64> + // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 3, 4] : tensor<303x307xi64> %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64> %sharding_0 = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !mesh.sharding %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64> From 919498bdfdadd031449dc334e2d75dcd6794b514 Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Tue, 5 Nov 2024 16:55:08 +0100 Subject: [PATCH 07/15] clang-format --- mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp index 11d7c0e08f1a6..5d9ea9cfccf8d 100644 --- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp +++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp @@ -75,8 +75,8 @@ struct ConvertUpdateHaloOp auto opSplitAxes = op.getSplitAxes().getAxes(); auto mesh = op.getMesh(); auto meshOp = getMesh(op, symbolTableCollection); - auto haloSizes = getMixedValues(op.getStaticHaloSizes(), - op.getHaloSizes(), 
rewriter); + auto haloSizes = + getMixedValues(op.getStaticHaloSizes(), op.getHaloSizes(), rewriter); // subviews need Index values for (auto &sz : haloSizes) { if (sz.is()) { From 60de21f6a063637db4d5474b42afde078a7e05fc Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Tue, 5 Nov 2024 19:50:56 +0100 Subject: [PATCH 08/15] allow tensor as destination in UpdateHaloOp and fixing its tests --- mlir/include/mlir/Conversion/Passes.td | 3 +- mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 23 ++++++- .../MeshToMPI/convert-mesh-to-mpi.mlir | 65 ++++++++++++++++--- 3 files changed, 79 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 2781fab917048..43015ad5b11e6 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -891,7 +891,8 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> { let dependentDialects = [ "memref::MemRefDialect", "mpi::MPIDialect", - "scf::SCFDialect" + "scf::SCFDialect", + "bufferization::BufferizationDialect" ]; } diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp index 5d9ea9cfccf8d..b1b58584aaae2 100644 --- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp +++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/MeshToMPI/MeshToMPI.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/MPI/IR/MPI.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Mesh/IR/MeshOps.h" @@ -70,7 +71,16 @@ struct ConvertUpdateHaloOp cast(v.get()).getInt())); }; - auto array = op.getDestination(); + auto dest = op.getDestination(); + auto dstShape = cast(dest.getType()).getShape(); + Value array = dest; + if (isa(array.getType())) { + // If the destination is a memref, we need to cast it to a tensor + auto tensorType = MemRefType::get( + dstShape, cast(array.getType()).getElementType()); + array = rewriter.create(loc, tensorType, array) + .getResult(); + } auto rank = cast(array.getType()).getRank(); auto opSplitAxes = op.getSplitAxes().getAxes(); auto mesh = op.getMesh(); @@ -94,7 +104,7 @@ struct ConvertUpdateHaloOp auto currHaloDim = -1; // halo sizes are provided for split dimensions only // we need the actual shape to compute offsets and sizes for (auto i = 0; i < rank; ++i) { - auto s = cast(array.getType()).getShape()[i]; + auto s = dstShape[i]; if (ShapedType::isDynamic(s)) { shape[i] = rewriter.create(loc, array, s).getResult(); } else { @@ -213,7 +223,14 @@ struct ConvertUpdateHaloOp // on to next halo --currHaloDim; } - rewriter.eraseOp(op); + + if (isa(op.getResult().getType())) { + rewriter.replaceOp(op, array); + } else { + assert(isa(op.getResult().getType())); + rewriter.replaceOp(op, rewriter.create( + loc, op.getResult().getType(), array)); + } return mlir::success(); } }; diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir index c3b0dc12e6d74..d05c53bd83aaf 100644 --- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir +++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir @@ -53,7 +53,7 @@ func.func @update_halo_1d_first( // CHECK-LABEL: func @update_halo_1d_second func.func @update_halo_1d_second( // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8> - %arg0 : memref<12x12xi8>) { + %arg0 : memref<12x12xi8>) -> memref<12x12xi8> { //CHECK-NEXT: [[vc7:%.*]] = arith.constant 7 : index //CHECK-NEXT: 
[[vc9:%.*]] = arith.constant 9 : index //CHECK-NEXT: [[vc2:%.*]] = arith.constant 2 : index @@ -91,16 +91,16 @@ func.func @update_halo_1d_second( //CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<12x3xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>> //CHECK-NEXT: } //CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<12x3xi8> - mesh.update_halo %arg0 on @mesh0 split_axes = [[], [3]] + %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[], [3]] halo_sizes = [2, 3] : memref<12x12xi8> - //CHECK-NEXT: return - return + //CHECK-NEXT: return [[varg0]] : memref<12x12xi8> + return %res : memref<12x12xi8> } // CHECK-LABEL: func @update_halo_2d func.func @update_halo_2d( // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8> - %arg0 : memref<12x12xi8>) { + %arg0 : memref<12x12xi8>) -> memref<12x12xi8> { // CHECK-NEXT: [[vc10:%.*]] = arith.constant 10 : index // CHECK-NEXT: [[vc1:%.*]] = arith.constant 1 : index // CHECK-NEXT: [[vc5:%.*]] = arith.constant 5 : index @@ -172,9 +172,58 @@ func.func @update_halo_2d( // CHECK-NEXT: memref.copy [[valloc_4]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>> // CHECK-NEXT: } // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<2x12xi8> - mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]] + %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : memref<12x12xi8> - // CHECK-NEXT: return - return + // CHECK-NEXT: return [[varg0]] : memref<12x12xi8> + return %res : memref<12x12xi8> +} + +// CHECK-LABEL: func @update_halo_1d_tnsr +func.func @update_halo_1d_tnsr( + // CHECK-SAME: [[varg0:%.*]]: tensor<12x12xi8> + %arg0 : tensor<12x12xi8>) -> tensor<12x12xi8> { + // CHECK-NEXT: [[vc7:%.*]] = arith.constant 7 : index + // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index + // CHECK-NEXT: [[vc2:%.*]] = arith.constant 2 : index + // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32 + // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32 + // CHECK-NEXT: [[mref:%.*]] = bufferization.to_memref %arg0 : memref<12x12xi8> + // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index + // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index + // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32 + // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32 + // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32 + // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32 + // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x12xi8> + // CHECK-NEXT: scf.if [[v3]] { + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[mref]][[[vc7]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>> + // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8> + // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32 + // CHECK-NEXT: } + // CHECK-NEXT: scf.if [[v2]] { + // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32 + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[mref]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>> + // CHECK-NEXT: memref.copy [[valloc]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1]>> + // CHECK-NEXT: } + // 
CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x12xi8> + // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32 + // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32 + // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<3x12xi8> + // CHECK-NEXT: scf.if [[v5]] { + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[mref]][[[vc2]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>> + // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1], offset: ?>> to memref<3x12xi8> + // CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<3x12xi8>, i32, i32 + // CHECK-NEXT: } + // CHECK-NEXT: scf.if [[v4]] { + // CHECK-NEXT: mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<3x12xi8>, i32, i32 + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[mref]][[[vc9]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>> + // CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<3x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>> + // CHECK-NEXT: } + // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<3x12xi8> + // CHECK-NEXT: [[res:%.*]] = bufferization.to_tensor [[mref]] : memref<12x12xi8> + %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] + halo_sizes = [2, 3] : tensor<12x12xi8> + // CHECK-NEXT: return [[res]] + return %res : tensor<12x12xi8> } From a90787028263f632b63ce7f277e59c86c6bd890b Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Thu, 7 Nov 2024 13:17:00 +0100 Subject: [PATCH 09/15] converting LinearIndex, MultiIndex and NeighborsIndex to MPI --- .../mlir/Conversion/MeshToMPI/MeshToMPI.h | 2 +- mlir/include/mlir/Conversion/Passes.td | 1 + mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 128 +++++++++++++++++- 3 files changed, 128 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h index 04271f8ab67b9..44a1cc0adb6a0 100644 --- a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h +++ b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h @@ -20,7 +20,7 @@ class Pass; /// Lowers Mesh communication operations (updateHalo, AllGater, ...) /// to MPI primitives. -std::unique_ptr createConvertMeshToMPIPass(); +std::unique_ptr<::mlir::Pass> createConvertMeshToMPIPass(); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 43015ad5b11e6..15fc13f5e12d8 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -884,6 +884,7 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> { def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> { let summary = "Convert Mesh dialect to MPI dialect."; + let constructor = "mlir::createConvertMeshToMPIPass()"; let description = [{ This pass converts communication operations from the Mesh dialect to the MPI dialect. 
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp index b1b58584aaae2..0f0cc28ca363a 100644 --- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp +++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp @@ -37,6 +37,130 @@ using namespace mlir; using namespace mlir::mesh; namespace { +static SmallVector linearToMultiIndex(Location loc, OpBuilder b, Value linearIndex, ValueRange dimensions) { + int n = dimensions.size(); + SmallVector multiIndex(n); + + for (int i = n - 1; i >= 0; --i) { + b.create(loc, linearIndex, dimensions[i]); + multiIndex[i] = b.create(loc, linearIndex, dimensions[i]); + if(i > 0) { + linearIndex = b.create(loc, linearIndex, dimensions[i]); + } + } + + return multiIndex; +} + +Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex, ValueRange dimensions) { + auto linearIndex = b.create(loc, 0).getResult(); + auto stride = b.create(loc, 1).getResult(); + + for (int i = multiIndex.size() - 1; i >= 0; --i) { + linearIndex = b.create(loc, multiIndex[i], stride); + stride = b.create(loc, stride, dimensions[i]); + } + + return linearIndex; +} + +// This pattern converts the mesh.update_halo operation to MPI calls +struct ConvertProcessMultiIndexOp + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op, + mlir::PatternRewriter &rewriter) const override { + SymbolTableCollection symbolTableCollection; + auto loc = op.getLoc(); + auto meshOp = getMesh(op, symbolTableCollection); + // For now we only support static mesh shapes + if(ShapedType::isDynamicShape(meshOp.getShape())) { + return mlir::failure(); + } + + SmallVector dims; + llvm::transform(meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { + return rewriter.create(loc, i).getResult(); + }); + auto rank = rewriter.create(op.getLoc(), meshOp).getResult(); + auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims); + + // optionally extract subset of mesh axes + auto axes = op.getAxes(); + if(!axes.empty()) { + SmallVector subIndex; + for(auto axis : axes) { + subIndex.push_back(mIdx[axis]); + } + mIdx = subIndex; + } + + rewriter.replaceOp(op, mIdx); + return mlir::success(); + } +}; + +// This pattern converts the mesh.update_halo operation to MPI calls +struct ConvertProcessLinearIndexOp + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op, + mlir::PatternRewriter &rewriter) const override { + auto rank = rewriter.create(op.getLoc(), TypeRange{mpi::RetvalType::get(op->getContext()), rewriter.getI32Type()}).getRank(); + rewriter.replaceOpWithNewOp(op, rewriter.getIndexType(), rank); + return mlir::success(); + } +}; + +// This pattern converts the mesh.update_halo operation to MPI calls +struct ConvertNeighborsLinearIndicesOp + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op, + mlir::PatternRewriter &rewriter) const override { + auto axes = op.getSplitAxes(); + // For now only single axis sharding is supported + if(axes.size() != 1) { + return mlir::failure(); + } + + auto loc = op.getLoc(); + SymbolTableCollection symbolTableCollection; + auto meshOp = getMesh(op, symbolTableCollection); + auto mIdx = op.getDevice(); + auto orgIdx = mIdx[axes[0]]; + SmallVector dims; + llvm::transform(meshOp.getShape(), std::back_inserter(dims), 
[&](int64_t i) { + return rewriter.create(loc, i).getResult(); + }); + auto dimSz = dims[axes[0]]; + auto minus1 = rewriter.create(loc, -1).getResult(); + auto atBorder = rewriter.create(loc, arith::CmpIPredicate::sle, dimSz, rewriter.create(loc, 0).getResult()); + auto down = rewriter.create( + loc, atBorder, [&](OpBuilder &builder, Location loc) { + builder.create(loc, minus1); + }, [&](OpBuilder &builder, Location loc) { + mIdx[axes[0]] = rewriter.create(op.getLoc(), orgIdx, dimSz).getResult(); + builder.create(loc, multiToLinearIndex(loc, rewriter, mIdx, dims)); + }); + atBorder = rewriter.create(loc, arith::CmpIPredicate::sge, dimSz, rewriter.create(loc, dimSz, minus1).getResult()); + auto up = rewriter.create( + loc, atBorder, [&](OpBuilder &builder, Location loc) { + builder.create(loc, minus1); + }, [&](OpBuilder &builder, Location loc) { + mIdx[axes[0]] = rewriter.create(op.getLoc(), orgIdx, dimSz).getResult(); + builder.create(loc, multiToLinearIndex(loc, rewriter, mIdx, dims)); + }); + rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)}); + return mlir::success(); + } +}; // This pattern converts the mesh.update_halo operation to MPI calls struct ConvertUpdateHaloOp @@ -244,7 +368,7 @@ struct ConvertMeshToMPIPass auto *ctx = &getContext(); mlir::RewritePatternSet patterns(ctx); - patterns.insert(ctx); + patterns.insert(ctx); (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); @@ -254,6 +378,6 @@ struct ConvertMeshToMPIPass } // namespace // Create a pass that convert Mesh to MPI -std::unique_ptr<::mlir::OperationPass> createConvertMeshToMPIPass() { +std::unique_ptr<::mlir::Pass> mlir::createConvertMeshToMPIPass() { return std::make_unique(); } From 7c3eddcac2c9cb9e84f2d59406d759174c4fa2bb Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Thu, 7 Nov 2024 16:57:50 +0100 Subject: [PATCH 10/15] allow constant shape propagation & fusion thoguh static_mpi_rank --- mlir/include/mlir/Conversion/Passes.td | 8 +- mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 1 + mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 106 +++++++++++++------ 3 files changed, 83 insertions(+), 32 deletions(-) diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 15fc13f5e12d8..4d6be8d18d1fe 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -886,8 +886,12 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> { let summary = "Convert Mesh dialect to MPI dialect."; let constructor = "mlir::createConvertMeshToMPIPass()"; let description = [{ - This pass converts communication operations - from the Mesh dialect to the MPI dialect. + This pass converts communication operations from the Mesh dialect to the + MPI dialect. + If it finds a global named "static_mpi_rank" it will use that splat value + instead of calling MPI_Comm_rank. This allows optimizations like constant + shape propagation and fusion because shard/partition sizes depend on the + rank. 
}]; let dependentDialects = [ "memref::MemRefDialect", diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td index 3c52c63330e95..726c92d6ec469 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td @@ -1091,6 +1091,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [ } def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [ + Pure, DestinationStyleOpInterface, TypesMatchWith< "result has same type as destination", diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp index 0f0cc28ca363a..f20068c9a43df 100644 --- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp +++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp @@ -18,11 +18,13 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Mesh/IR/MeshOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #define DEBUG_TYPE "mesh-to-mpi" @@ -37,14 +39,17 @@ using namespace mlir; using namespace mlir::mesh; namespace { -static SmallVector linearToMultiIndex(Location loc, OpBuilder b, Value linearIndex, ValueRange dimensions) { +// Create operations converting a linear index to a multi-dimensional index +static SmallVector linearToMultiIndex(Location loc, OpBuilder b, + Value linearIndex, + ValueRange dimensions) { int n = dimensions.size(); SmallVector multiIndex(n); for (int i = n - 1; i >= 0; --i) { b.create(loc, linearIndex, dimensions[i]); multiIndex[i] = b.create(loc, linearIndex, dimensions[i]); - if(i > 0) { + if (i > 0) { linearIndex = b.create(loc, linearIndex, dimensions[i]); } } @@ -52,13 +57,16 @@ static SmallVector linearToMultiIndex(Location loc, OpBuilder b, Value li return multiIndex; } -Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex, ValueRange dimensions) { +// Create operations converting a multi-dimensional index to a linear index +Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex, + ValueRange dimensions) { + auto linearIndex = b.create(loc, 0).getResult(); auto stride = b.create(loc, 1).getResult(); for (int i = multiIndex.size() - 1; i >= 0; --i) { - linearIndex = b.create(loc, multiIndex[i], stride); - stride = b.create(loc, stride, dimensions[i]); + linearIndex = b.create(loc, multiIndex[i], stride); + stride = b.create(loc, stride, dimensions[i]); } return linearIndex; @@ -76,22 +84,24 @@ struct ConvertProcessMultiIndexOp auto loc = op.getLoc(); auto meshOp = getMesh(op, symbolTableCollection); // For now we only support static mesh shapes - if(ShapedType::isDynamicShape(meshOp.getShape())) { + if (ShapedType::isDynamicShape(meshOp.getShape())) { return mlir::failure(); } SmallVector dims; - llvm::transform(meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { - return rewriter.create(loc, i).getResult(); - }); - auto rank = rewriter.create(op.getLoc(), meshOp).getResult(); + llvm::transform( + meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { + return rewriter.create(loc, i).getResult(); + }); + auto rank = + rewriter.create(op.getLoc(), meshOp).getResult(); auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims); // optionally extract subset of mesh axes auto axes = 
op.getAxes(); - if(!axes.empty()) { + if (!axes.empty()) { SmallVector subIndex; - for(auto axis : axes) { + for (auto axis : axes) { subIndex.push_back(mIdx[axis]); } mIdx = subIndex; @@ -102,7 +112,9 @@ struct ConvertProcessMultiIndexOp } }; -// This pattern converts the mesh.update_halo operation to MPI calls +// This pattern converts the mesh.update_halo operation to MPI calls. +// If it finds a global named "static_mpi_rank" it will use that splat value. +// Otherwise it defaults to mpi.comm_rank. struct ConvertProcessLinearIndexOp : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -110,8 +122,25 @@ struct ConvertProcessLinearIndexOp mlir::LogicalResult matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op, mlir::PatternRewriter &rewriter) const override { - auto rank = rewriter.create(op.getLoc(), TypeRange{mpi::RetvalType::get(op->getContext()), rewriter.getI32Type()}).getRank(); - rewriter.replaceOpWithNewOp(op, rewriter.getIndexType(), rank); + auto loc = op.getLoc(); + auto rankOpName = StringAttr::get(op->getContext(), "static_mpi_rank"); + if (auto globalOp = SymbolTable::lookupNearestSymbolFrom( + op, rankOpName)) { + if (auto initTnsr = globalOp.getInitialValueAttr()) { + auto val = cast(initTnsr).getSplatValue(); + rewriter.replaceOp(op, + rewriter.create(loc, val)); + return mlir::success(); + } + } + auto rank = + rewriter + .create( + op.getLoc(), TypeRange{mpi::RetvalType::get(op->getContext()), + rewriter.getI32Type()}) + .getRank(); + rewriter.replaceOpWithNewOp(op, rewriter.getIndexType(), + rank); return mlir::success(); } }; @@ -126,7 +155,7 @@ struct ConvertNeighborsLinearIndicesOp mlir::PatternRewriter &rewriter) const override { auto axes = op.getSplitAxes(); // For now only single axis sharding is supported - if(axes.size() != 1) { + if (axes.size() != 1) { return mlir::failure(); } @@ -136,26 +165,41 @@ struct ConvertNeighborsLinearIndicesOp auto mIdx = op.getDevice(); auto orgIdx = mIdx[axes[0]]; SmallVector dims; - llvm::transform(meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { - return rewriter.create(loc, i).getResult(); - }); + llvm::transform( + meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { + return rewriter.create(loc, i).getResult(); + }); auto dimSz = dims[axes[0]]; auto minus1 = rewriter.create(loc, -1).getResult(); - auto atBorder = rewriter.create(loc, arith::CmpIPredicate::sle, dimSz, rewriter.create(loc, 0).getResult()); + auto atBorder = rewriter.create( + loc, arith::CmpIPredicate::sle, dimSz, + rewriter.create(loc, 0).getResult()); auto down = rewriter.create( - loc, atBorder, [&](OpBuilder &builder, Location loc) { + loc, atBorder, + [&](OpBuilder &builder, Location loc) { builder.create(loc, minus1); - }, [&](OpBuilder &builder, Location loc) { - mIdx[axes[0]] = rewriter.create(op.getLoc(), orgIdx, dimSz).getResult(); - builder.create(loc, multiToLinearIndex(loc, rewriter, mIdx, dims)); + }, + [&](OpBuilder &builder, Location loc) { + mIdx[axes[0]] = + rewriter.create(op.getLoc(), orgIdx, dimSz) + .getResult(); + builder.create( + loc, multiToLinearIndex(loc, rewriter, mIdx, dims)); }); - atBorder = rewriter.create(loc, arith::CmpIPredicate::sge, dimSz, rewriter.create(loc, dimSz, minus1).getResult()); + atBorder = rewriter.create( + loc, arith::CmpIPredicate::sge, dimSz, + rewriter.create(loc, dimSz, minus1).getResult()); auto up = rewriter.create( - loc, atBorder, [&](OpBuilder &builder, Location loc) { + loc, atBorder, + [&](OpBuilder &builder, Location loc) { builder.create(loc, 
minus1); - }, [&](OpBuilder &builder, Location loc) { - mIdx[axes[0]] = rewriter.create(op.getLoc(), orgIdx, dimSz).getResult(); - builder.create(loc, multiToLinearIndex(loc, rewriter, mIdx, dims)); + }, + [&](OpBuilder &builder, Location loc) { + mIdx[axes[0]] = + rewriter.create(op.getLoc(), orgIdx, dimSz) + .getResult(); + builder.create( + loc, multiToLinearIndex(loc, rewriter, mIdx, dims)); }); rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)}); return mlir::success(); @@ -368,7 +412,9 @@ struct ConvertMeshToMPIPass auto *ctx = &getContext(); mlir::RewritePatternSet patterns(ctx); - patterns.insert(ctx); + patterns.insert( + ctx); (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); From f0695fd50ded1d204a981b87fbf8e3f2ae7081f5 Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Thu, 7 Nov 2024 19:08:02 +0100 Subject: [PATCH 11/15] fixing sned/recv border check --- mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp index f20068c9a43df..ee7c77b35b285 100644 --- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp +++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp @@ -47,7 +47,6 @@ static SmallVector linearToMultiIndex(Location loc, OpBuilder b, SmallVector multiIndex(n); for (int i = n - 1; i >= 0; --i) { - b.create(loc, linearIndex, dimensions[i]); multiIndex[i] = b.create(loc, linearIndex, dimensions[i]); if (i > 0) { linearIndex = b.create(loc, linearIndex, dimensions[i]); @@ -172,7 +171,7 @@ struct ConvertNeighborsLinearIndicesOp auto dimSz = dims[axes[0]]; auto minus1 = rewriter.create(loc, -1).getResult(); auto atBorder = rewriter.create( - loc, arith::CmpIPredicate::sle, dimSz, + loc, arith::CmpIPredicate::sle, orgIdx, rewriter.create(loc, 0).getResult()); auto down = rewriter.create( loc, atBorder, @@ -187,7 +186,7 @@ struct ConvertNeighborsLinearIndicesOp loc, multiToLinearIndex(loc, rewriter, mIdx, dims)); }); atBorder = rewriter.create( - loc, arith::CmpIPredicate::sge, dimSz, + loc, arith::CmpIPredicate::sge, orgIdx, rewriter.create(loc, dimSz, minus1).getResult()); auto up = rewriter.create( loc, atBorder, From 725c7343d1ce0af1e165665fbc8dab4c2162ad8b Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Fri, 8 Nov 2024 12:38:14 +0100 Subject: [PATCH 12/15] fixes and tests --- mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 19 +- mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 20 +- .../MeshToMPI/convert-mesh-to-mpi.mlir | 421 +++++++++--------- 3 files changed, 226 insertions(+), 234 deletions(-) diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td index 726c92d6ec469..6039e61a93fad 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td @@ -162,20 +162,21 @@ def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [ DeclareOpInterfaceMethods ]> { let summary = - "For given split axes get the linear indices of the direct neighbor processes."; + "For given mesh index get the linear indices of the direct neighbor processes along the given split."; let description = [{ Example: ``` - %idx = mesh.neighbor_linear_index on @mesh for $device - split_axes = $split_axes : index + mesh.mesh @mesh0(shape = 10x20x30) + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %idx = mesh.neighbors_linear_indices on @mesh[%c1, %c2, 
%c3] split_axes = [1] : index ``` - Given `@mesh` with shape `(10, 20, 30)`, - `device` = `(1, 2, 3)` - `$split_axes` = `[1]` - returns two indices, `633` and `693`, which correspond to the index of the previous - process `(1, 1, 3)`, and the next process `(1, 3, 3) along the split axis `1`. + The above returns two indices, `633` and `693`, which correspond to the + index of the previous process `(1, 1, 3)`, and the next process + `(1, 3, 3) along the split axis `1`. - A negative value is returned if `$device` has no neighbor in the given + A negative value is returned if there is no neighbor in the respective direction along the given `split_axes`. }]; let arguments = (ins FlatSymbolRefAttr:$mesh, diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp index ee7c77b35b285..c51c5335fc609 100644 --- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp +++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp @@ -64,7 +64,8 @@ Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex, auto stride = b.create(loc, 1).getResult(); for (int i = multiIndex.size() - 1; i >= 0; --i) { - linearIndex = b.create(loc, multiIndex[i], stride); + auto off = b.create(loc, multiIndex[i], stride); + linearIndex = b.create(loc, linearIndex, off); stride = b.create(loc, stride, dimensions[i]); } @@ -169,6 +170,7 @@ struct ConvertNeighborsLinearIndicesOp return rewriter.create(loc, i).getResult(); }); auto dimSz = dims[axes[0]]; + auto one = rewriter.create(loc, 1).getResult(); auto minus1 = rewriter.create(loc, -1).getResult(); auto atBorder = rewriter.create( loc, arith::CmpIPredicate::sle, orgIdx, @@ -179,26 +181,28 @@ struct ConvertNeighborsLinearIndicesOp builder.create(loc, minus1); }, [&](OpBuilder &builder, Location loc) { - mIdx[axes[0]] = - rewriter.create(op.getLoc(), orgIdx, dimSz) + SmallVector tmp = mIdx; + tmp[axes[0]] = + rewriter.create(op.getLoc(), orgIdx, one) .getResult(); builder.create( - loc, multiToLinearIndex(loc, rewriter, mIdx, dims)); + loc, multiToLinearIndex(loc, rewriter, tmp, dims)); }); atBorder = rewriter.create( loc, arith::CmpIPredicate::sge, orgIdx, - rewriter.create(loc, dimSz, minus1).getResult()); + rewriter.create(loc, dimSz, one).getResult()); auto up = rewriter.create( loc, atBorder, [&](OpBuilder &builder, Location loc) { builder.create(loc, minus1); }, [&](OpBuilder &builder, Location loc) { - mIdx[axes[0]] = - rewriter.create(op.getLoc(), orgIdx, dimSz) + SmallVector tmp = mIdx; + tmp[axes[0]] = + rewriter.create(op.getLoc(), orgIdx, one) .getResult(); builder.create( - loc, multiToLinearIndex(loc, rewriter, mIdx, dims)); + loc, multiToLinearIndex(loc, rewriter, tmp, dims)); }); rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)}); return mlir::success(); diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir index d05c53bd83aaf..38b7a12daef52 100644 --- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir +++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir @@ -1,229 +1,216 @@ -// RUN: mlir-opt %s -convert-mesh-to-mpi | FileCheck %s +// RUN: mlir-opt %s -convert-mesh-to-mpi -canonicalize -split-input-file | FileCheck %s +// ----- // CHECK: mesh.mesh @mesh0 -mesh.mesh @mesh0(shape = 2x2x4) +mesh.mesh @mesh0(shape = 3x4x5) +func.func @process_multi_index() -> (index, index, index) { + // CHECK: mpi.comm_rank : !mpi.retval, i32 + // CHECK-DAG: %[[v4:.*]] = arith.remsi + // CHECK-DAG: %[[v0:.*]] = arith.remsi + // 
CHECK-DAG: %[[v1:.*]] = arith.remsi + %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index + // CHECK: return %[[v1]], %[[v0]], %[[v4]] : index, index, index + return %0#0, %0#1, %0#2 : index, index, index +} -// CHECK-LABEL: func @update_halo_1d_first -func.func @update_halo_1d_first( - // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8> - %arg0 : memref<12x12xi8>) { - // CHECK-NEXT: [[vc7:%.*]] = arith.constant 7 : index - // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index - // CHECK-NEXT: [[vc2:%.*]] = arith.constant 2 : index - // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32 - // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32 - // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index - // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index - // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32 - // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32 - // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32 - // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32 - // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x12xi8> - // CHECK-NEXT: scf.if [[v3]] { - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc7]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>> - // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8> - // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32 - // CHECK-NEXT: } - // CHECK-NEXT: scf.if [[v2]] { - // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32 - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>> - // CHECK-NEXT: memref.copy [[valloc]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1]>> - // CHECK-NEXT: } - // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x12xi8> - // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32 - // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32 - // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<3x12xi8> - // CHECK-NEXT: scf.if [[v5]] { - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc2]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>> - // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1], offset: ?>> to memref<3x12xi8> - // CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<3x12xi8>, i32, i32 - // CHECK-NEXT: } - // CHECK-NEXT: scf.if [[v4]] { - // CHECK-NEXT: mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<3x12xi8>, i32, i32 - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>> - // CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<3x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>> - // CHECK-NEXT: } - // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<3x12xi8> - mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] - halo_sizes = [2, 3] : memref<12x12xi8> - // CHECK-NEXT: return - return +// CHECK-LABEL: func @process_linear_index +func.func @process_linear_index() -> 
index { + // CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank : !mpi.retval, i32 + // CHECK: %[[cast:.*]] = arith.index_cast %[[rank]] : i32 to index + %0 = mesh.process_linear_index on @mesh0 : index + // CHECK: return %[[cast]] : index + return %0 : index } -// CHECK-LABEL: func @update_halo_1d_second -func.func @update_halo_1d_second( - // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8> - %arg0 : memref<12x12xi8>) -> memref<12x12xi8> { - //CHECK-NEXT: [[vc7:%.*]] = arith.constant 7 : index - //CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index - //CHECK-NEXT: [[vc2:%.*]] = arith.constant 2 : index - //CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32 - //CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32 - //CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index - //CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [3] : index, index - //CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32 - //CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32 - //CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32 - //CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32 - //CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<12x2xi8> - //CHECK-NEXT: scf.if [[v3]] { - //CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, %c7] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1], offset: ?>> - //CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<12x2xi8, strided<[12, 1], offset: ?>> to memref<12x2xi8> - //CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<12x2xi8>, i32, i32 - //CHECK-NEXT: } - //CHECK-NEXT: scf.if [[v2]] { - //CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<12x2xi8>, i32, i32 - //CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1]>> - //CHECK-NEXT: memref.copy [[valloc]], [[vsubview]] : memref<12x2xi8> to memref<12x2xi8, strided<[12, 1]>> - //CHECK-NEXT: } - //CHECK-NEXT: memref.dealloc [[valloc]] : memref<12x2xi8> - //CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32 - //CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32 - //CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<12x3xi8> - //CHECK-NEXT: scf.if [[v5]] { - //CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, %c2] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>> - //CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<12x3xi8, strided<[12, 1], offset: ?>> to memref<12x3xi8> - //CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<12x3xi8>, i32, i32 - //CHECK-NEXT: } - //CHECK-NEXT: scf.if [[v4]] { - //CHECK-NEXT: mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<12x3xi8>, i32, i32 - //CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>> - //CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<12x3xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>> - //CHECK-NEXT: } - //CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<12x3xi8> - %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[], [3]] - halo_sizes = [2, 3] : memref<12x12xi8> - //CHECK-NEXT: return [[varg0]] : memref<12x12xi8> - return %res : memref<12x12xi8> +// CHECK-LABEL: 
func @neighbors_dim0
+func.func @neighbors_dim0(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  // CHECK-DAG: [[up:%.*]] = arith.constant 44 : index
+  // CHECK-DAG: [[down:%.*]] = arith.constant 4 : index
+  %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [0] : index, index
+  // CHECK: return [[down]], [[up]] : index, index
+  return %idx#0, %idx#1 : index, index
+}
+
+// CHECK-LABEL: func @neighbors_dim1
+func.func @neighbors_dim1(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  // CHECK-DAG: [[up:%.*]] = arith.constant 29 : index
+  // CHECK-DAG: [[down:%.*]] = arith.constant -1 : index
+  %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [1] : index, index
+  // CHECK: return [[down]], [[up]] : index, index
+  return %idx#0, %idx#1 : index, index
+}
+
+// CHECK-LABEL: func @neighbors_dim2
+func.func @neighbors_dim2(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  // CHECK-DAG: [[up:%.*]] = arith.constant -1 : index
+  // CHECK-DAG: [[down:%.*]] = arith.constant 23 : index
+  %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [2] : index, index
+  // CHECK: return [[down]], [[up]] : index, index
+  return %idx#0, %idx#1 : index, index
+}
+
+// -----
+// CHECK: mesh.mesh @mesh0
+mesh.mesh @mesh0(shape = 3x4x5)
+memref.global constant @static_mpi_rank : memref<index> = dense<24>
+func.func @process_multi_index() -> (index, index, index) {
+  // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+  %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
+  // CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// CHECK-LABEL: func @process_linear_index
+func.func @process_linear_index() -> index {
+  // CHECK: %[[c24:.*]] = arith.constant 24 : index
+  %0 = mesh.process_linear_index on @mesh0 : index
+  // CHECK: return %[[c24]] : index
+  return %0 : index
+}
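+
+// Note: the @static_mpi_rank splat is what lets mesh.process_linear_index
+// fold to the constant 24 above instead of emitting mpi.comm_rank. For mesh
+// shape 3x4x5 the row-major decomposition of rank 24 is 24 / (4*5) = 1,
+// (24 mod 20) / 5 = 0, 24 mod 5 = 4, i.e. the multi-index (1, 0, 4). The
+// same linearization explains the neighbor constants further up, e.g. for
+// device (1, 0, 4) and split axis 0 the neighbors (0, 0, 4) and (2, 0, 4)
+// map to (0*4+0)*5+4 = 4 and (2*4+0)*5+4 = 44.
+
+// -----
+mesh.mesh @mesh0(shape = 3x4x5)
+// CHECK-LABEL: func @update_halo_1d_first
+func.func @update_halo_1d_first(
+    // CHECK-SAME: [[arg0:%.*]]: memref<120x120x120xi8>
+    %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
+  // CHECK: memref.subview [[arg0]][115, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
+  // CHECK: mpi.send(
+  // CHECK-SAME: : memref<2x120x120xi8>, i32, i32
+  // CHECK: mpi.recv(
+  // CHECK-SAME: : memref<2x120x120xi8>, i32, i32
+  // CHECK-NEXT: memref.subview [[arg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
+  // CHECK: memref.subview [[arg0]][2, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8
+  // CHECK: mpi.send(
+  // CHECK-SAME: : memref<3x120x120xi8>, i32, i32
+  // CHECK: mpi.recv(
+  // CHECK-SAME: : memref<3x120x120xi8>, i32, i32
+  // CHECK-NEXT: memref.subview [[arg0]][117, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8
+  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 3] : memref<120x120x120xi8>
+  // CHECK: return [[res:%.*]] : memref<120x120x120xi8>
+  return %res : 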
memref<120x120x120xi8>
 }
 
-// CHECK-LABEL: func @update_halo_2d
-func.func @update_halo_2d(
-    // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
-    %arg0 : memref<12x12xi8>) -> memref<12x12xi8> {
-  // CHECK-NEXT: [[vc10:%.*]] = arith.constant 10 : index
-  // CHECK-NEXT: [[vc1:%.*]] = arith.constant 1 : index
-  // CHECK-NEXT: [[vc5:%.*]] = arith.constant 5 : index
-  // CHECK-NEXT: [[vc8:%.*]] = arith.constant 8 : index
-  // CHECK-NEXT: [[vc3:%.*]] = arith.constant 3 : index
-  // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+// -----
+mesh.mesh @mesh0(shape = 3x4x5)
+memref.global constant @static_mpi_rank : memref<index> = dense<24>
+// CHECK-LABEL: func @update_halo_3d
+func.func @update_halo_3d(
+    // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
+    %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
+  // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32
+  // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
   // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
-  // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
-  // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
-  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [1] : index, index
-  // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
-  // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
-  // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc([[vc9]]) : memref<?x3xi8>
-  // CHECK-NEXT: scf.if [[v3]] {
-  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, %c5] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<?x3xi8, strided<[12, 1], offset: ?>> to memref<?x3xi8>
-  // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<?x3xi8>, i32, i32
-  // CHECK-NEXT: }
-  // CHECK-NEXT: scf.if [[v2]] {
-  // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<?x3xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
-  // CHECK-NEXT: memref.copy [[valloc]], [[vsubview]] : memref<?x3xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<?x3xi8>
-  // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc([[vc9]]) : memref<?x4xi8>
-  // CHECK-NEXT: scf.if [[v5]] {
-  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, %c3] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<?x4xi8, strided<[12, 1], offset: ?>> to memref<?x4xi8>
-  // CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<?x4xi8>, i32, i32
-  // CHECK-NEXT: }
-  // CHECK-NEXT: scf.if [[v4]] {
-  // CHECK-NEXT: mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<?x4xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<?x4xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<?x4xi8>
-  // CHECK-NEXT: [[vdown_linear_idx_1:%.*]], [[vup_linear_idx_2:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
-  // CHECK-NEXT: [[v6:%.*]] = arith.index_cast [[vdown_linear_idx_1]] : index to i32
-  // CHECK-NEXT: [[v7:%.*]] = arith.index_cast [[vup_linear_idx_2]] : index to i32
-  // CHECK-NEXT: [[v8:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[v9:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc_3:%.*]] = memref.alloc() : memref<1x12xi8>
-  // CHECK-NEXT: scf.if [[v9]] {
-  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_3]] : memref<1x12xi8, strided<[12, 1], offset: ?>> to memref<1x12xi8>
-  // CHECK-NEXT: mpi.send([[valloc_3]], [[vc91_i32]], [[v6]]) : memref<1x12xi8>, i32, i32
-  // CHECK-NEXT: }
-  // CHECK-NEXT: scf.if [[v8]] {
-  // CHECK-NEXT: mpi.recv([[valloc_3]], [[vc91_i32]], [[v7]]) : memref<1x12xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1]>>
-  // CHECK-NEXT: memref.copy [[valloc_3]], [[vsubview]] : memref<1x12xi8> to memref<1x12xi8, strided<[12, 1]>>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc_3]] : memref<1x12xi8>
-  // CHECK-NEXT: [[v10:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[v11:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<2x12xi8>
-  // CHECK-NEXT: scf.if [[v11]] {
-  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc1]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_4]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8>
-  // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[v7]]) : memref<2x12xi8>, i32, i32
-  // CHECK-NEXT: }
-  // CHECK-NEXT: scf.if [[v10]] {
-  // CHECK-NEXT: mpi.recv([[valloc_4]], [[vc91_i32]], [[v6]]) : memref<2x12xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT: memref.copy [[valloc_4]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<2x12xi8>
-  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
-      halo_sizes = [1, 2, 3, 4]
-      : memref<12x12xi8>
-  // CHECK-NEXT: return [[varg0]] : memref<12x12xi8>
-  return %res : memref<12x12xi8>
+  // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
+  // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
+  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
+  // CHECK-NEXT: [[vcast:%.*]] = memref.cast [[valloc]] : memref<117x113x5xi8> to memref<?x?x?xi8>
+  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
+  // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
+  // CHECK-NEXT: mpi.send([[vcast]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x?xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[vcast]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x?xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
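+  // Note: the subview offsets above are the row-major linearizations of the
+  // start indices into the 120x120x120 buffer with strides [14400, 120, 1],
+  // e.g. start [1, 3, 109] gives 1*14400 + 3*120 + 109 = 14869 and start
+  // [1, 3, 0] gives 1*14400 + 3*120 = 14760.
+  // 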
CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
+  // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
+  // CHECK-NEXT: [[vcast_2:%.*]] = memref.cast [[valloc_1]] : memref<117x113x6xi8> to memref<?x?x?xi8>
+  // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
+  // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
+  // CHECK-NEXT: mpi.send([[vcast_2]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x?xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[vcast_2]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x?xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+  // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+  // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
+  // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
+  // CHECK-NEXT: [[vcast_6:%.*]] = memref.cast [[valloc_5]] : memref<117x3x120xi8> to memref<?x?x?xi8>
+  // CHECK-NEXT: mpi.recv([[vcast_6]], [[vc91_i32]], [[vc29_i32]]) : memref<?x?x?xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
+  // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
+  // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
+  // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
+  // CHECK-NEXT: [[vcast_9:%.*]] = memref.cast [[valloc_8]] : memref<117x4x120xi8> to memref<?x?x?xi8>
+  // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
+  // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
+  // CHECK-NEXT: mpi.send([[vcast_9]], [[vc91_i32]], [[vc29_i32]]) : memref<?x?x?xi8>, i32, i32
+  // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
+  // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
+  // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[varg0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
+  // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8>
+  // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+  // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8>
+  // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8>
+  // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
+  // 
CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>> + // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8> + %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8> + // CHECK: return [[varg0]] : memref<120x120x120xi8> + return %res : memref<120x120x120xi8> } -// CHECK-LABEL: func @update_halo_1d_tnsr -func.func @update_halo_1d_tnsr( - // CHECK-SAME: [[varg0:%.*]]: tensor<12x12xi8> - %arg0 : tensor<12x12xi8>) -> tensor<12x12xi8> { - // CHECK-NEXT: [[vc7:%.*]] = arith.constant 7 : index - // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index - // CHECK-NEXT: [[vc2:%.*]] = arith.constant 2 : index - // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32 +// CHECK-LABEL: func @update_halo_3d_tensor +func.func @update_halo_3d_tensor( + // CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8> + %arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> { + // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32 + // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32 + // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32 + // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32 // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32 - // CHECK-NEXT: [[mref:%.*]] = bufferization.to_memref %arg0 : memref<12x12xi8> - // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index - // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index - // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32 - // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32 - // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32 - // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32 - // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x12xi8> - // CHECK-NEXT: scf.if [[v3]] { - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[mref]][[[vc7]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>> - // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8> - // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32 - // CHECK-NEXT: } - // CHECK-NEXT: scf.if [[v2]] { - // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32 - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[mref]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>> - // CHECK-NEXT: memref.copy [[valloc]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1]>> - // CHECK-NEXT: } - // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x12xi8> - // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32 - // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32 - // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<3x12xi8> - // CHECK-NEXT: scf.if [[v5]] { - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[mref]][[[vc2]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>> - // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1], offset: ?>> to memref<3x12xi8> - // CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : 
memref<3x12xi8>, i32, i32
-  // CHECK-NEXT: }
-  // CHECK-NEXT: scf.if [[v4]] {
-  // CHECK-NEXT: mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<3x12xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[mref]][[[vc9]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<3x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<3x12xi8>
-  // CHECK-NEXT: [[res:%.*]] = bufferization.to_tensor [[mref]] : memref<12x12xi8>
-  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
-      halo_sizes = [2, 3] : tensor<12x12xi8>
-  // CHECK-NEXT: return [[res]]
-  return %res : tensor<12x12xi8>
+  // CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : memref<120x120x120xi8>
+  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
+  // CHECK-NEXT: [[vcast:%.*]] = memref.cast [[valloc]] : memref<117x113x5xi8> to memref<?x?x?xi8>
+  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
+  // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
+  // CHECK-NEXT: mpi.send([[vcast]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x?xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[vcast]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x?xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+  // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
+  // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
+  // CHECK-NEXT: [[vcast_2:%.*]] = memref.cast [[valloc_1]] : memref<117x113x6xi8> to memref<?x?x?xi8>
+  // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
+  // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
+  // CHECK-NEXT: mpi.send([[vcast_2]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x?xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[vcast_2]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x?xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+  // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+  // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
+  // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
+  // CHECK-NEXT: [[vcast_6:%.*]] = memref.cast [[valloc_5]] : memref<117x3x120xi8> to memref<?x?x?xi8>
+  // CHECK-NEXT: mpi.recv([[vcast_6]], [[vc91_i32]], [[vc29_i32]]) : memref<?x?x?xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
+  // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
+  // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
+  // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
+  // CHECK-NEXT: [[vcast_9:%.*]] = memref.cast [[valloc_8]] : memref<117x4x120xi8> to memref<?x?x?xi8>
+  // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
+  // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
+  // CHECK-NEXT: mpi.send([[vcast_9]], [[vc91_i32]], [[vc29_i32]]) : memref<?x?x?xi8>, i32, i32
+  // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
+  // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
+  // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[v0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
+  // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8>
+  // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+  // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8>
+  // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8>
+  // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[v0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
+  // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
+  // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
+  // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] : memref<120x120x120xi8>
+  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
+  // CHECK: return [[v1]] : tensor<120x120x120xi8>
+  return %res : tensor<120x120x120xi8>
 }

From b5013d0a542e78714df9925083c8d5e78433373c Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" 
Date: Fri, 8 Nov 2024 18:52:09 +0100
Subject: [PATCH 13/15] using restrict

---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index c51c5335fc609..1c82881f67d3a 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -400,7 +400,8 @@ struct ConvertUpdateHaloOp
     } else {
       assert(isa<RankedTensorType>(op.getResult().getType()));
       rewriter.replaceOp(op, rewriter.create<bufferization::ToTensorOp>(
-                                 loc, op.getResult().getType(), array));
+                                 loc, op.getResult().getType(), array,
+                                 /*restrict=*/true, /*writable=*/true));
     }
     return mlir::success();
   }

From 1ad7725a3aa4a0fa0b9bc1a8fd8e07d33bfe3a51 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" 
Date: Mon, 11 Nov 2024 16:45:31 +0100
Subject: [PATCH 14/15] canonicalizing send and recv towards static memref shapes
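The halo-exchange lowering currently erases the static shape of each halo
buffer with a memref.cast so that mpi.send/mpi.recv see a uniform dynamic
memref type. Registering a canonicalizer on both ops folds such casts back
into the operation whenever the cast's source has a static shape, so
--canonicalize recovers the statically shaped send/recv seen in the test
updates below.

---
 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td |  2 +
 mlir/lib/Dialect/MPI/IR/MPIOps.cpp         | 40 +++++++++++++++++++
 .../MeshToMPI/convert-mesh-to-mpi.mlir     | 34 ++++++----------
 3 files changed, 55 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td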
index 768f376e24da4..240fac5104c34 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -84,6 +84,7 @@ def MPI_SendOp : MPI_Op<"send", []> {
   let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
                        "type($ref) `,` type($tag) `,` type($rank)"
                        "(`->` type($retval)^)?";
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -114,6 +115,7 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
   let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
                        "type($ref) `,` type($tag) `,` type($rank)"
                        "(`->` type($retval)^)?";
+  let hasCanonicalizer = 1;
 }
 
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index ddd77b8f586ee..dcb55d8921364 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -7,12 +7,52 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/PatternMatch.h"
 
 using namespace mlir;
 using namespace mlir::mpi;
 
+namespace {
+
+// If the input memref has a dynamic shape and is defined by a cast whose
+// source has a static shape, fold the cast's static source into the given operation.
+template <typename OpT>
+struct FoldCast final : public mlir::OpRewritePattern<OpT> {
+  using mlir::OpRewritePattern<OpT>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpT op,
+                                mlir::PatternRewriter &b) const override {
+    auto mRef = op.getRef();
+    if (mRef.getType().hasStaticShape()) {
+      return mlir::failure();
+    }
+    auto defOp = mRef.getDefiningOp();
+    if (!defOp || !mlir::isa<memref::CastOp>(defOp)) {
+      return mlir::failure();
+    }
+    auto src = mlir::cast<memref::CastOp>(defOp).getSource();
+    if (!src.getType().hasStaticShape()) {
+      return mlir::failure();
+    }
+    op.getRefMutable().assign(src);
+    return mlir::success();
+  }
+};
+} // namespace
+
+void mlir::mpi::SendOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
+  results.add<FoldCast<SendOp>>(context);
+}
+
+void mlir::mpi::RecvOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
+  results.add<FoldCast<RecvOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
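The effect of the pattern is, roughly, the following rewrite (a sketch
adapted from the test diff below; %tag and %rank stand for arbitrary i32
values and are not names from the patch):

    %buf = memref.alloc() : memref<117x113x5xi8>
    %cast = memref.cast %buf : memref<117x113x5xi8> to memref<?x?x?xi8>
    mpi.send(%cast, %tag, %rank) : memref<?x?x?xi8>, i32, i32

becomes a send on the statically shaped buffer, after which the cast is
trivially dead:

    mpi.send(%buf, %tag, %rank) : memref<117x113x5xi8>, i32, i32

diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index 38b7a12daef52..25d585a108c8a 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -115,34 +115,30 @@ func.func @update_halo_3d(
   // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
   // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
   // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
-  // CHECK-NEXT: [[vcast:%.*]] = memref.cast [[valloc]] : memref<117x113x5xi8> to memref<?x?x?xi8>
   // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
   // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
-  // CHECK-NEXT: mpi.send([[vcast]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x?xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[vcast]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x?xi8>, i32, i32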
+  // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
   // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
   // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
   // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
   // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
-  // CHECK-NEXT: [[vcast_2:%.*]] = memref.cast [[valloc_1]] : memref<117x113x6xi8> to memref<?x?x?xi8>
   // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
   // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
-  // CHECK-NEXT: mpi.send([[vcast_2]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x?xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[vcast_2]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x?xi8>, i32, i32
+  // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
   // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
   // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
   // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
   // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
-  // CHECK-NEXT: [[vcast_6:%.*]] = memref.cast [[valloc_5]] : memref<117x3x120xi8> to memref<?x?x?xi8>
-  // CHECK-NEXT: mpi.recv([[vcast_6]], [[vc91_i32]], [[vc29_i32]]) : memref<?x?x?xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
   // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
   // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
   // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
   // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
-  // CHECK-NEXT: [[vcast_9:%.*]] = memref.cast [[valloc_8]] : memref<117x4x120xi8> to memref<?x?x?xi8>
   // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
   // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
-  // CHECK-NEXT: mpi.send([[vcast_9]], [[vc91_i32]], [[vc29_i32]]) : memref<?x?x?xi8>, i32, i32
+  // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
   // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
   // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
   // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[varg0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
@@ -170,34 +166,30 @@ func.func @update_halo_3d_tensor(
   // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
   // CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : memref<120x120x120xi8>
   // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
-  // CHECK-NEXT: [[vcast:%.*]] = memref.cast [[valloc]] : memref<117x113x5xi8> to memref<?x?x?xi8>
   // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
   // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
-  // CHECK-NEXT: mpi.send([[vcast]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x?xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[vcast]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x?xi8>, i32, i32
+  // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
   // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
   // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
   // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
   // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
-  // CHECK-NEXT: [[vcast_2:%.*]] = memref.cast [[valloc_1]] : memref<117x113x6xi8> to memref<?x?x?xi8>
   // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
   // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
-  // CHECK-NEXT: mpi.send([[vcast_2]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x?xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[vcast_2]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x?xi8>, i32, i32
+  // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
   // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
   // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
   // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
   // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
-  // CHECK-NEXT: [[vcast_6:%.*]] = memref.cast [[valloc_5]] : memref<117x3x120xi8> to memref<?x?x?xi8>
-  // CHECK-NEXT: mpi.recv([[vcast_6]], [[vc91_i32]], [[vc29_i32]]) : memref<?x?x?xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
   // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
   // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
   // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
   // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
-  // CHECK-NEXT: [[vcast_9:%.*]] = memref.cast [[valloc_8]] : memref<117x4x120xi8> to memref<?x?x?xi8>
   // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
   // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
-  // CHECK-NEXT: mpi.send([[vcast_9]], [[vc91_i32]], [[vc29_i32]]) : memref<?x?x?xi8>, i32, i32
+  // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
   // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
   // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
   // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[v0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
@@ -209,7 +201,7 @@ func.func @update_halo_3d_tensor(
   // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[v0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
   // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
   // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
-  // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] : memref<120x120x120xi8>
+  // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8>
   %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
   // CHECK: return [[v1]] : tensor<120x120x120xi8>
   return %res : tensor<120x120x120xi8>

From c48c7f0b25ec0b5f9a22bbdba28c79057c298fb0 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" 
Date: Wed, 27 Nov 2024 16:48:25 +0100
Subject: [PATCH 15/15] fixing comments

---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 1c82881f67d3a..6dd89ecf4d5c2 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -72,7 +72,6 @@ Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
   return linearIndex;
 }
 
-// This pattern converts the mesh.update_halo operation to MPI calls
 struct ConvertProcessMultiIndexOp
     : public mlir::OpRewritePattern<mlir::mesh::ProcessMultiIndexOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -80,6 +79,9 @@ struct ConvertProcessMultiIndexOp
   mlir::LogicalResult
   matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op,
                   mlir::PatternRewriter &rewriter) const override {
+
+    // Currently converts the process's linear index to a multi-dimensional index.
+
     SymbolTableCollection symbolTableCollection;
     auto loc = op.getLoc();
     auto meshOp = getMesh(op, symbolTableCollection);
@@ -112,9 +114,6 @@ struct ConvertProcessMultiIndexOp
   }
 };
 
-// This pattern converts the mesh.update_halo operation to MPI calls.
-// If it finds a global named "static_mpi_rank" it will use that splat value.
-// Otherwise it defaults to mpi.comm_rank.
 struct ConvertProcessLinearIndexOp
     : public mlir::OpRewritePattern<mlir::mesh::ProcessLinearIndexOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -122,6 +121,10 @@ struct ConvertProcessLinearIndexOp
   mlir::LogicalResult
   matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op,
                   mlir::PatternRewriter &rewriter) const override {
+
+    // If a global named "static_mpi_rank" is found, uses its splat value.
+    // Otherwise defaults to mpi.comm_rank.
+
     auto loc = op.getLoc();
     auto rankOpName = StringAttr::get(op->getContext(), "static_mpi_rank");
     if (auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
@@ -145,7 +148,6 @@ struct ConvertProcessLinearIndexOp
   }
 };
 
-// This pattern converts the mesh.update_halo operation to MPI calls
 struct ConvertNeighborsLinearIndicesOp
     : public mlir::OpRewritePattern<mlir::mesh::NeighborsLinearIndicesOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -153,6 +155,11 @@ struct ConvertNeighborsLinearIndicesOp
   mlir::LogicalResult
   matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op,
                   mlir::PatternRewriter &rewriter) const override {
+
+    // Computes the neighbors' indices along a split axis by simply
+    // adding/subtracting 1 to/from the current index in that dimension.
+    // Assigns -1 if the neighbor is out of bounds.
+
     auto axes = op.getSplitAxes();
     // For now only single axis sharding is supported
     if (axes.size() != 1) {
@@ -209,7 +216,6 @@ struct ConvertNeighborsLinearIndicesOp
   }
 };
 
-// This pattern converts the mesh.update_halo operation to MPI calls
 struct ConvertUpdateHaloOp
     : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -217,6 +223,7 @@ struct ConvertUpdateHaloOp
   mlir::LogicalResult
   matchAndRewrite(mlir::mesh::UpdateHaloOp op,
                   mlir::PatternRewriter &rewriter) const override {
+
     // The input/output memref is assumed to be in C memory order.
     // Halos are exchanged as 2 blocks per dimension (one for each side: down
    // and up). For each haloed dimension `d`, the exchanged blocks are