diff --git a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h new file mode 100644 index 0000000000000..44a1cc0adb6a0 --- /dev/null +++ b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h @@ -0,0 +1,27 @@ +//===- MeshToMPI.h - Convert Mesh to MPI dialect ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H +#define MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H + +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_CONVERTMESHTOMPIPASS +#include "mlir/Conversion/Passes.h.inc" + +/// Lowers Mesh communication operations (updateHalo, AllGater, ...) +/// to MPI primitives. +std::unique_ptr<::mlir::Pass> createConvertMeshToMPIPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index 2ab32836c80b1..b577aa83946f2 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -51,6 +51,7 @@ #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h" +#include "mlir/Conversion/MeshToMPI/MeshToMPI.h" #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h" #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 4d272ba219c6f..4d6be8d18d1fe 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -878,6 +878,29 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> { ]; } +//===----------------------------------------------------------------------===// +// MeshToMPI +//===----------------------------------------------------------------------===// + +def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> { + let summary = "Convert Mesh dialect to MPI dialect."; + let constructor = "mlir::createConvertMeshToMPIPass()"; + let description = [{ + This pass converts communication operations from the Mesh dialect to the + MPI dialect. + If it finds a global named "static_mpi_rank" it will use that splat value + instead of calling MPI_Comm_rank. This allows optimizations like constant + shape propagation and fusion because shard/partition sizes depend on the + rank. + }]; + let dependentDialects = [ + "memref::MemRefDialect", + "mpi::MPIDialect", + "scf::SCFDialect", + "bufferization::BufferizationDialect" + ]; +} + //===----------------------------------------------------------------------===// // NVVMToLLVM //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 768f376e24da4..240fac5104c34 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -84,6 +84,7 @@ def MPI_SendOp : MPI_Op<"send", []> { let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` " "type($ref) `,` type($tag) `,` type($rank)" "(`->` type($retval)^)?"; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// @@ -114,6 +115,7 @@ def MPI_RecvOp : MPI_Op<"recv", []> { let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` " "type($ref) `,` type($tag) `,` type($rank)" "(`->` type($retval)^)?"; + let hasCanonicalizer = 1; } diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td index 19498fe5a32d6..6039e61a93fad 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td @@ -156,6 +156,40 @@ def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [ ]; } +def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [ + Pure, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods +]> { + let summary = + "For given mesh index get the linear indices of the direct neighbor processes along the given split."; + let description = [{ + Example: + ``` + mesh.mesh @mesh0(shape = 10x20x30) + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %idx = mesh.neighbors_linear_indices on @mesh[%c1, %c2, %c3] split_axes = [1] : index + ``` + The above returns two indices, `633` and `693`, which correspond to the + index of the previous process `(1, 1, 3)`, and the next process + `(1, 3, 3) along the split axis `1`. + + A negative value is returned if there is no neighbor in the respective + direction along the given `split_axes`. + }]; + let arguments = (ins FlatSymbolRefAttr:$mesh, + Variadic:$device, + Mesh_MeshAxesAttr:$split_axes); + let results = (outs Index:$neighbor_down, Index:$neighbor_up); + let assemblyFormat = [{ + `on` $mesh `[` $device `]` + `split_axes` `=` $split_axes + attr-dict `:` type(results) + }]; +} + //===----------------------------------------------------------------------===// // Sharding operations. //===----------------------------------------------------------------------===// @@ -1058,12 +1092,12 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [ } def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [ + Pure, DestinationStyleOpInterface, TypesMatchWith< "result has same type as destination", "result", "destination", "$_self">, - DeclareOpInterfaceMethods, - AttrSizedOperandSegments + DeclareOpInterfaceMethods ]> { let summary = "Update halo data."; let description = [{ @@ -1072,7 +1106,7 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [ on the remote devices. Changes might be caused by mutating operations and/or if the new halo regions are larger than the existing ones. - Source and destination might have different halo sizes. + Destination is supposed to be initialized with the local data (not halos). Assumes all devices hold tensors with same-sized halo data as specified by `source_halo_sizes/static_source_halo_sizes` and @@ -1084,25 +1118,21 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [ }]; let arguments = (ins - AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source, AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination, FlatSymbolRefAttr:$mesh, Mesh_MeshAxesArrayAttr:$split_axes, - Variadic:$source_halo_sizes, - DefaultValuedAttr:$static_source_halo_sizes, - Variadic:$destination_halo_sizes, - DefaultValuedAttr:$static_destination_halo_sizes + Variadic:$halo_sizes, + DefaultValuedAttr:$static_halo_sizes ); let results = (outs AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result ); let assemblyFormat = [{ - $source `into` $destination + $destination `on` $mesh `split_axes` `=` $split_axes - (`source_halo_sizes` `=` custom($source_halo_sizes, $static_source_halo_sizes)^)? - (`destination_halo_sizes` `=` custom($destination_halo_sizes, $static_destination_halo_sizes)^)? - attr-dict `:` type($source) `->` type($result) + (`halo_sizes` `=` custom($halo_sizes, $static_halo_sizes)^)? + attr-dict `:` type($result) }]; let extraClassDeclaration = [{ MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); } diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 6651d87162257..62461c0cea08a 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -41,6 +41,7 @@ add_subdirectory(MathToSPIRV) add_subdirectory(MemRefToEmitC) add_subdirectory(MemRefToLLVM) add_subdirectory(MemRefToSPIRV) +add_subdirectory(MeshToMPI) add_subdirectory(NVGPUToNVVM) add_subdirectory(NVVMToLLVM) add_subdirectory(OpenACCToSCF) diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt new file mode 100644 index 0000000000000..95815a683f6d6 --- /dev/null +++ b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt @@ -0,0 +1,22 @@ +add_mlir_conversion_library(MLIRMeshToMPI + MeshToMPI.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MeshToMPI + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRFuncDialect + MLIRIR + MLIRLinalgTransforms + MLIRMemRefDialect + MLIRPass + MLIRMeshDialect + MLIRMPIDialect + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp new file mode 100644 index 0000000000000..6dd89ecf4d5c2 --- /dev/null +++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp @@ -0,0 +1,440 @@ +//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation of Mesh communication ops tp MPI ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/MeshToMPI/MeshToMPI.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/MPI/IR/MPI.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Mesh/IR/MeshOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "mesh-to-mpi" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") + +namespace mlir { +#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::mesh; + +namespace { +// Create operations converting a linear index to a multi-dimensional index +static SmallVector linearToMultiIndex(Location loc, OpBuilder b, + Value linearIndex, + ValueRange dimensions) { + int n = dimensions.size(); + SmallVector multiIndex(n); + + for (int i = n - 1; i >= 0; --i) { + multiIndex[i] = b.create(loc, linearIndex, dimensions[i]); + if (i > 0) { + linearIndex = b.create(loc, linearIndex, dimensions[i]); + } + } + + return multiIndex; +} + +// Create operations converting a multi-dimensional index to a linear index +Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex, + ValueRange dimensions) { + + auto linearIndex = b.create(loc, 0).getResult(); + auto stride = b.create(loc, 1).getResult(); + + for (int i = multiIndex.size() - 1; i >= 0; --i) { + auto off = b.create(loc, multiIndex[i], stride); + linearIndex = b.create(loc, linearIndex, off); + stride = b.create(loc, stride, dimensions[i]); + } + + return linearIndex; +} + +struct ConvertProcessMultiIndexOp + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op, + mlir::PatternRewriter &rewriter) const override { + + // Currently converts its linear index to a multi-dimensional index. + + SymbolTableCollection symbolTableCollection; + auto loc = op.getLoc(); + auto meshOp = getMesh(op, symbolTableCollection); + // For now we only support static mesh shapes + if (ShapedType::isDynamicShape(meshOp.getShape())) { + return mlir::failure(); + } + + SmallVector dims; + llvm::transform( + meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { + return rewriter.create(loc, i).getResult(); + }); + auto rank = + rewriter.create(op.getLoc(), meshOp).getResult(); + auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims); + + // optionally extract subset of mesh axes + auto axes = op.getAxes(); + if (!axes.empty()) { + SmallVector subIndex; + for (auto axis : axes) { + subIndex.push_back(mIdx[axis]); + } + mIdx = subIndex; + } + + rewriter.replaceOp(op, mIdx); + return mlir::success(); + } +}; + +struct ConvertProcessLinearIndexOp + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op, + mlir::PatternRewriter &rewriter) const override { + + // Finds a global named "static_mpi_rank" it will use that splat value. + // Otherwise it defaults to mpi.comm_rank. + + auto loc = op.getLoc(); + auto rankOpName = StringAttr::get(op->getContext(), "static_mpi_rank"); + if (auto globalOp = SymbolTable::lookupNearestSymbolFrom( + op, rankOpName)) { + if (auto initTnsr = globalOp.getInitialValueAttr()) { + auto val = cast(initTnsr).getSplatValue(); + rewriter.replaceOp(op, + rewriter.create(loc, val)); + return mlir::success(); + } + } + auto rank = + rewriter + .create( + op.getLoc(), TypeRange{mpi::RetvalType::get(op->getContext()), + rewriter.getI32Type()}) + .getRank(); + rewriter.replaceOpWithNewOp(op, rewriter.getIndexType(), + rank); + return mlir::success(); + } +}; + +struct ConvertNeighborsLinearIndicesOp + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op, + mlir::PatternRewriter &rewriter) const override { + + // Computes the neighbors indices along a split axis by simply + // adding/subtracting 1 to the current index in that dimension. + // Assigns -1 if neighbor is out of bounds. + + auto axes = op.getSplitAxes(); + // For now only single axis sharding is supported + if (axes.size() != 1) { + return mlir::failure(); + } + + auto loc = op.getLoc(); + SymbolTableCollection symbolTableCollection; + auto meshOp = getMesh(op, symbolTableCollection); + auto mIdx = op.getDevice(); + auto orgIdx = mIdx[axes[0]]; + SmallVector dims; + llvm::transform( + meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { + return rewriter.create(loc, i).getResult(); + }); + auto dimSz = dims[axes[0]]; + auto one = rewriter.create(loc, 1).getResult(); + auto minus1 = rewriter.create(loc, -1).getResult(); + auto atBorder = rewriter.create( + loc, arith::CmpIPredicate::sle, orgIdx, + rewriter.create(loc, 0).getResult()); + auto down = rewriter.create( + loc, atBorder, + [&](OpBuilder &builder, Location loc) { + builder.create(loc, minus1); + }, + [&](OpBuilder &builder, Location loc) { + SmallVector tmp = mIdx; + tmp[axes[0]] = + rewriter.create(op.getLoc(), orgIdx, one) + .getResult(); + builder.create( + loc, multiToLinearIndex(loc, rewriter, tmp, dims)); + }); + atBorder = rewriter.create( + loc, arith::CmpIPredicate::sge, orgIdx, + rewriter.create(loc, dimSz, one).getResult()); + auto up = rewriter.create( + loc, atBorder, + [&](OpBuilder &builder, Location loc) { + builder.create(loc, minus1); + }, + [&](OpBuilder &builder, Location loc) { + SmallVector tmp = mIdx; + tmp[axes[0]] = + rewriter.create(op.getLoc(), orgIdx, one) + .getResult(); + builder.create( + loc, multiToLinearIndex(loc, rewriter, tmp, dims)); + }); + rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)}); + return mlir::success(); + } +}; + +struct ConvertUpdateHaloOp + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mlir::mesh::UpdateHaloOp op, + mlir::PatternRewriter &rewriter) const override { + + // The input/output memref is assumed to be in C memory order. + // Halos are exchanged as 2 blocks per dimension (one for each side: down + // and up). For each haloed dimension `d`, the exchanged blocks are + // expressed as multi-dimensional subviews. The subviews include potential + // halos of higher dimensions `dh > d`, no halos for the lower dimensions + // `dl < d` and for dimension `d` the currently exchanged halo only. + // By iterating form higher to lower dimensions this also updates the halos + // in the 'corners'. + // memref.subview is used to read and write the halo data from and to the + // local data. Because subviews and halos can have mixed dynamic and static + // shapes, OpFoldResults are used whenever possible. + + SymbolTableCollection symbolTableCollection; + auto loc = op.getLoc(); + + // convert a OpFoldResult into a Value + auto toValue = [&rewriter, &loc](OpFoldResult &v) { + return v.is() + ? v.get() + : rewriter.create<::mlir::arith::ConstantOp>( + loc, + rewriter.getIndexAttr( + cast(v.get()).getInt())); + }; + + auto dest = op.getDestination(); + auto dstShape = cast(dest.getType()).getShape(); + Value array = dest; + if (isa(array.getType())) { + // If the destination is a memref, we need to cast it to a tensor + auto tensorType = MemRefType::get( + dstShape, cast(array.getType()).getElementType()); + array = rewriter.create(loc, tensorType, array) + .getResult(); + } + auto rank = cast(array.getType()).getRank(); + auto opSplitAxes = op.getSplitAxes().getAxes(); + auto mesh = op.getMesh(); + auto meshOp = getMesh(op, symbolTableCollection); + auto haloSizes = + getMixedValues(op.getStaticHaloSizes(), op.getHaloSizes(), rewriter); + // subviews need Index values + for (auto &sz : haloSizes) { + if (sz.is()) { + sz = rewriter + .create(loc, rewriter.getIndexType(), + sz.get()) + .getResult(); + } + } + + // most of the offset/size/stride data is the same for all dims + SmallVector offsets(rank, rewriter.getIndexAttr(0)); + SmallVector strides(rank, rewriter.getIndexAttr(1)); + SmallVector shape(rank), dimSizes(rank); + auto currHaloDim = -1; // halo sizes are provided for split dimensions only + // we need the actual shape to compute offsets and sizes + for (auto i = 0; i < rank; ++i) { + auto s = dstShape[i]; + if (ShapedType::isDynamic(s)) { + shape[i] = rewriter.create(loc, array, s).getResult(); + } else { + shape[i] = rewriter.getIndexAttr(s); + } + + if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) { + ++currHaloDim; + // the offsets for lower dim sstarts after their down halo + offsets[i] = haloSizes[currHaloDim * 2]; + + // prepare shape and offsets of highest dim's halo exchange + auto _haloSz = + rewriter + .create(loc, toValue(haloSizes[currHaloDim * 2]), + toValue(haloSizes[currHaloDim * 2 + 1])) + .getResult(); + // the halo shape of lower dims exlude the halos + dimSizes[i] = + rewriter.create(loc, toValue(shape[i]), _haloSz) + .getResult(); + } else { + dimSizes[i] = shape[i]; + } + } + + auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something + auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr); + auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0 + auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr); + + SmallVector indexResultTypes(meshOp.getShape().size(), + rewriter.getIndexType()); + auto myMultiIndex = + rewriter.create(loc, indexResultTypes, mesh) + .getResult(); + // traverse all split axes from high to low dim + for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) { + auto splitAxes = opSplitAxes[dim]; + if (splitAxes.empty()) { + continue; + } + assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2); + // Get the linearized ids of the neighbors (down and up) for the + // given split + auto tmp = rewriter + .create(loc, mesh, myMultiIndex, + splitAxes) + .getResults(); + // MPI operates on i32... + Value neighbourIDs[2] = {rewriter.create( + loc, rewriter.getI32Type(), tmp[0]), + rewriter.create( + loc, rewriter.getI32Type(), tmp[1])}; + + auto lowerRecvOffset = rewriter.getIndexAttr(0); + auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]); + auto upperRecvOffset = rewriter.create( + loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1])); + auto upperSendOffset = rewriter.create( + loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2])); + + // Make sure we send/recv in a way that does not lead to a dead-lock. + // The current approach is by far not optimal, this should be at least + // be a red-black pattern or using MPI_sendrecv. + // Also, buffers should be re-used. + // Still using temporary contiguous buffers for MPI communication... + // Still yielding a "serialized" communication pattern... + auto genSendRecv = [&](bool upperHalo) { + auto orgOffset = offsets[dim]; + dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1] + : haloSizes[currHaloDim * 2]; + // Check if we need to send and/or receive + // Processes on the mesh borders have only one neighbor + auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0]; + auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1]; + auto hasFrom = rewriter.create( + loc, arith::CmpIPredicate::sge, from, zero); + auto hasTo = rewriter.create( + loc, arith::CmpIPredicate::sge, to, zero); + auto buffer = rewriter.create( + loc, dimSizes, cast(array.getType()).getElementType()); + // if has neighbor: copy halo data from array to buffer and send + rewriter.create( + loc, hasTo, [&](OpBuilder &builder, Location loc) { + offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset) + : OpFoldResult(upperSendOffset); + auto subview = builder.create( + loc, array, offsets, dimSizes, strides); + builder.create(loc, subview, buffer); + builder.create(loc, TypeRange{}, buffer, tag, to); + builder.create(loc); + }); + // if has neighbor: receive halo data into buffer and copy to array + rewriter.create( + loc, hasFrom, [&](OpBuilder &builder, Location loc) { + offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset) + : OpFoldResult(lowerRecvOffset); + builder.create(loc, TypeRange{}, buffer, tag, from); + auto subview = builder.create( + loc, array, offsets, dimSizes, strides); + builder.create(loc, buffer, subview); + builder.create(loc); + }); + rewriter.create(loc, buffer); + offsets[dim] = orgOffset; + }; + + genSendRecv(false); + genSendRecv(true); + + // the shape for lower dims include higher dims' halos + dimSizes[dim] = shape[dim]; + // -> the offset for higher dims is always 0 + offsets[dim] = rewriter.getIndexAttr(0); + // on to next halo + --currHaloDim; + } + + if (isa(op.getResult().getType())) { + rewriter.replaceOp(op, array); + } else { + assert(isa(op.getResult().getType())); + rewriter.replaceOp(op, rewriter.create( + loc, op.getResult().getType(), array, + /*restrict=*/true, /*writable=*/true)); + } + return mlir::success(); + } +}; + +struct ConvertMeshToMPIPass + : public impl::ConvertMeshToMPIPassBase { + using Base::Base; + + /// Run the dialect converter on the module. + void runOnOperation() override { + auto *ctx = &getContext(); + mlir::RewritePatternSet patterns(ctx); + + patterns.insert( + ctx); + + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)); + } +}; + +} // namespace + +// Create a pass that convert Mesh to MPI +std::unique_ptr<::mlir::Pass> mlir::createConvertMeshToMPIPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp index ddd77b8f586ee..dcb55d8921364 100644 --- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp +++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp @@ -7,12 +7,52 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/MPI/IR/MPI.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/PatternMatch.h" using namespace mlir; using namespace mlir::mpi; +namespace { + +// If input memref has dynamic shape and is a cast and if the cast's input has +// static shape, fold the cast's static input into the given operation. +template +struct FoldCast final : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpT op, + mlir::PatternRewriter &b) const override { + auto mRef = op.getRef(); + if (mRef.getType().hasStaticShape()) { + return mlir::failure(); + } + auto defOp = mRef.getDefiningOp(); + if (!defOp || !mlir::isa(defOp)) { + return mlir::failure(); + } + auto src = mlir::cast(defOp).getSource(); + if (!src.getType().hasStaticShape()) { + return mlir::failure(); + } + op.getRefMutable().assign(src); + return mlir::success(); + } +}; +} // namespace + +void mlir::mpi::SendOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &results, mlir::MLIRContext *context) { + results.add>(context); +} + +void mlir::mpi::RecvOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &results, mlir::MLIRContext *context) { + results.add>(context); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index c5570d8ee8a44..33460ff25e9e4 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -837,6 +837,25 @@ void ProcessLinearIndexOp::getAsmResultNames( setNameFn(getResult(), "proc_linear_idx"); } +//===----------------------------------------------------------------------===// +// mesh.neighbors_linear_indices op +//===----------------------------------------------------------------------===// + +LogicalResult +NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); + if (failed(mesh)) { + return failure(); + } + return success(); +} + +void NeighborsLinearIndicesOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getNeighborDown(), "down_linear_idx"); + setNameFn(getNeighborUp(), "up_linear_idx"); +} + //===----------------------------------------------------------------------===// // collective communication ops //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index b4d088cbd7088..327ea0991e4e1 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -496,11 +496,9 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, sourceShard.getLoc(), RankedTensorType::get(outShape, sourceShard.getType().getElementType()), - sourceShard, initOprnd, mesh.getSymName(), + initOprnd, mesh.getSymName(), MeshAxesArrayAttr::get(builder.getContext(), sourceSharding.getSplitAxes()), - sourceSharding.getDynamicHaloSizes(), - sourceSharding.getStaticHaloSizes(), targetSharding.getDynamicHaloSizes(), targetSharding.getStaticHaloSizes()) .getResult(); diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir new file mode 100644 index 0000000000000..25d585a108c8a --- /dev/null +++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir @@ -0,0 +1,208 @@ +// RUN: mlir-opt %s -convert-mesh-to-mpi -canonicalize -split-input-file | FileCheck %s + +// ----- +// CHECK: mesh.mesh @mesh0 +mesh.mesh @mesh0(shape = 3x4x5) +func.func @process_multi_index() -> (index, index, index) { + // CHECK: mpi.comm_rank : !mpi.retval, i32 + // CHECK-DAG: %[[v4:.*]] = arith.remsi + // CHECK-DAG: %[[v0:.*]] = arith.remsi + // CHECK-DAG: %[[v1:.*]] = arith.remsi + %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index + // CHECK: return %[[v1]], %[[v0]], %[[v4]] : index, index, index + return %0#0, %0#1, %0#2 : index, index, index +} + +// CHECK-LABEL: func @process_linear_index +func.func @process_linear_index() -> index { + // CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank : !mpi.retval, i32 + // CHECK: %[[cast:.*]] = arith.index_cast %[[rank]] : i32 to index + %0 = mesh.process_linear_index on @mesh0 : index + // CHECK: return %[[cast]] : index + return %0 : index +} + +// CHECK-LABEL: func @neighbors_dim0 +func.func @neighbors_dim0(%arg0 : tensor<120x120x120xi8>) -> (index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + // CHECK-DAG: [[up:%.*]] = arith.constant 44 : index + // CHECK-DAG: [[down:%.*]] = arith.constant 4 : index + %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [0] : index, index + // CHECK: return [[down]], [[up]] : index, index + return %idx#0, %idx#1 : index, index +} + +// CHECK-LABEL: func @neighbors_dim1 +func.func @neighbors_dim1(%arg0 : tensor<120x120x120xi8>) -> (index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + // CHECK-DAG: [[up:%.*]] = arith.constant 29 : index + // CHECK-DAG: [[down:%.*]] = arith.constant -1 : index + %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [1] : index, index + // CHECK: return [[down]], [[up]] : index, index + return %idx#0, %idx#1 : index, index +} + +// CHECK-LABEL: func @neighbors_dim2 +func.func @neighbors_dim2(%arg0 : tensor<120x120x120xi8>) -> (index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + // CHECK-DAG: [[up:%.*]] = arith.constant -1 : index + // CHECK-DAG: [[down:%.*]] = arith.constant 23 : index + %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [2] : index, index + // CHECK: return [[down]], [[up]] : index, index + return %idx#0, %idx#1 : index, index +} + +// ----- +// CHECK: mesh.mesh @mesh0 +mesh.mesh @mesh0(shape = 3x4x5) +memref.global constant @static_mpi_rank : memref = dense<24> +func.func @process_multi_index() -> (index, index, index) { + // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index + // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index + %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index + // CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index + return %0#0, %0#1, %0#2 : index, index, index +} + +// CHECK-LABEL: func @process_linear_index +func.func @process_linear_index() -> index { + // CHECK: %[[c24:.*]] = arith.constant 24 : index + %0 = mesh.process_linear_index on @mesh0 : index + // CHECK: return %[[c24]] : index + return %0 : index +} + +// ----- +mesh.mesh @mesh0(shape = 3x4x5) +// CHECK-LABEL: func @update_halo_1d_first +func.func @update_halo_1d_first( + // CHECK-SAME: [[arg0:%.*]]: memref<120x120x120xi8> + %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> { + // CHECK: memref.subview [[arg0]][115, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8 + // CHECK: mpi.send( + // CHECK-SAME: : memref<2x120x120xi8>, i32, i32 + // CHECK: mpi.recv( + // CHECK-SAME: : memref<2x120x120xi8>, i32, i32 + // CHECK-NEXT: memref.subview [[arg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8 + // CHECK: memref.subview [[arg0]][2, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8 + // CHECK: mpi.send( + // CHECK-SAME: : memref<3x120x120xi8>, i32, i32 + // CHECK: mpi.recv( + // CHECK-SAME: : memref<3x120x120xi8>, i32, i32 + // CHECK-NEXT: memref.subview [[arg0]][117, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8 + %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 3] : memref<120x120x120xi8> + // CHECK: return [[res:%.*]] : memref<120x120x120xi8> + return %res : memref<120x120x120xi8> +} + +// ----- +mesh.mesh @mesh0(shape = 3x4x5) +memref.global constant @static_mpi_rank : memref = dense<24> +// CHECK-LABEL: func @update_halo_3d +func.func @update_halo_3d( + // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8> + %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> { + // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32 + // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32 + // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32 + // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32 + // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32 + // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8> + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> + // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8> + // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32 + // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32 + // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> + // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> + // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8> + // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8> + // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> + // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8> + // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32 + // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32 + // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> + // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> + // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8> + // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8> + // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32 + // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>> + // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>> + // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8> + // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8> + // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> + // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8> + // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32 + // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8> + // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8> + // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[varg0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> + // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8> + // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32 + // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8> + // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8> + // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32 + // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>> + // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>> + // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8> + %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8> + // CHECK: return [[varg0]] : memref<120x120x120xi8> + return %res : memref<120x120x120xi8> +} + +// CHECK-LABEL: func @update_halo_3d_tensor +func.func @update_halo_3d_tensor( + // CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8> + %arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> { + // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32 + // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32 + // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32 + // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32 + // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32 + // CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : memref<120x120x120xi8> + // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8> + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> + // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8> + // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32 + // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32 + // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> + // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> + // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8> + // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8> + // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> + // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8> + // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32 + // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32 + // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> + // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> + // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8> + // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8> + // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32 + // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>> + // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>> + // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8> + // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8> + // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> + // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8> + // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32 + // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8> + // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8> + // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[v0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> + // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8> + // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32 + // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8> + // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8> + // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32 + // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[v0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>> + // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>> + // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8> + // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> + %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8> + // CHECK: return [[v1]] : tensor<120x120x120xi8> + return %res : tensor<120x120x120xi8> +} diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir index d8df01c3d6520..978de4939ee77 100644 --- a/mlir/test/Dialect/Mesh/ops.mlir +++ b/mlir/test/Dialect/Mesh/ops.mlir @@ -615,16 +615,16 @@ func.func @update_halo( // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8> %arg0 : memref<12x12xi8>) { // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64 - // CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] into %[[ARG]] on @mesh0 + // CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] on @mesh0 // CHECK-SAME: split_axes = {{\[\[}}0]] - // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8> -> memref<12x12xi8> + // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8> %c2 = arith.constant 2 : i64 - %uh1 = mesh.update_halo %arg0 into %arg0 on @mesh0 split_axes = [[0]] - source_halo_sizes = [2, %c2] : memref<12x12xi8> -> memref<12x12xi8> - // CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[ARG]] into %[[UH1]] on @mesh0 + %uh1 = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] + halo_sizes = [2, %c2] : memref<12x12xi8> + // CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[UH1]] on @mesh0 // CHECK-SAME: split_axes = {{\[\[}}0], [1]] - // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8> -> memref<12x12xi8> - %uh2 = mesh.update_halo %arg0 into %uh1 on @mesh0 split_axes = [[0], [1]] - source_halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8> -> memref<12x12xi8> + // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8> + %uh2 = mesh.update_halo %uh1 on @mesh0 split_axes = [[0], [1]] + halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8> return } diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir index 22ddb72569835..c1b96fda0f4a7 100644 --- a/mlir/test/Dialect/Mesh/spmdization.mlir +++ b/mlir/test/Dialect/Mesh/spmdization.mlir @@ -226,7 +226,7 @@ func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1 %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] : !mesh.sharding // CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64> // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64> - // CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} destination_halo_sizes = [2, 2] : tensor<300x1200xi64> -> tensor<304x1200xi64> + // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} halo_sizes = [2, 2] : tensor<304x1200xi64> %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64> %sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !mesh.sharding %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64> @@ -242,7 +242,7 @@ func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200 %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] : !mesh.sharding // CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64> // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64> - // CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] destination_halo_sizes = [1, 2, 3, 4] : tensor<300x300xi64> -> tensor<303x307xi64> + // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 3, 4] : tensor<303x307xi64> %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64> %sharding_0 = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !mesh.sharding %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>