diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index cccdf0a8518bf..6074e0e8d822c 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -881,10 +881,10 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> { let description = [{ This pass converts communication operations from the Mesh dialect to the MPI dialect. - If it finds a global named "static_mpi_rank" it will use that splat value - instead of calling MPI_Comm_rank. This allows optimizations like constant - shape propagation and fusion because shard/partition sizes depend on the - rank. + If it finds the DLTI attribute "MPI:comm_world_rank" on the module, it will + use that integer value instead of calling MPI_Comm_rank. This allows + optimizations like constant shape propagation and fusion because + shard/partition sizes depend on the rank. }]; let dependentDialects = [ "memref::MemRefDialect", diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td index 031e6f63bcb42..f59c4c4c67517 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td @@ -345,24 +345,32 @@ def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> { }]; } -def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> { - let summary = "Get the shard shape of a given process/device."; +def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [ + Pure, AttrSizedOperandSegments, + DeclareOpInterfaceMethods + ]> { + let summary = "Get the shard shape for a given process/device."; let description = [{ - The device/process id is a linearized id of the device/process in the mesh. + The device/process id is a multi-index of the device/process in the mesh. This operation might be used during spmdization when the shard shape depends on (non-constant) values used in `mesh.sharding`.
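+
+    A minimal, illustrative example (mesh name and values are made up; the
+    exact printed form is defined by the assembly format below):
+    ```mlir
+    %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
+    %dev:3 = mesh.process_multi_index on @mesh0 : index, index, index
+    %shard_shape:2 = mesh.shard_shape dims = [8, 16] sharding = %sharding
+        device = [%dev#0, %dev#1, %dev#2] : index, index
+    ```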
}]; let arguments = (ins - DenseI64ArrayAttr:$shape, + DenseI64ArrayAttr:$dims, + Variadic:$dims_dynamic, Mesh_Sharding:$sharding, - Index:$device + DenseI64ArrayAttr:$device, + Variadic:$device_dynamic ); let results = (outs Variadic:$result); let assemblyFormat = [{ - custom($shape) $sharding $device attr-dict `:` type($result) + `dims` `=` custom($dims_dynamic, $dims) + `sharding` `=` $sharding + `device` `=` custom($device_dynamic, $device) + attr-dict `:` type(results) }]; let builders = [ - OpBuilder<(ins "ArrayRef":$shape, "Value":$sharding, "Value":$device)> + OpBuilder<(ins "ArrayRef":$dims, "ArrayRef":$dims_dyn, "Value":$sharding, "ValueRange":$device)> ]; } diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt index 95815a683f6d6..15560aa61e145 100644 --- a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt +++ b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMeshToMPI Core LINK_LIBS PUBLIC + MLIRDLTIDialect MLIRFuncDialect MLIRIR MLIRLinalgTransforms diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp index 48b3764d520c2..76b1978e6a025 100644 --- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp +++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp @@ -14,8 +14,13 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MPI/IR/MPI.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Mesh/IR/MeshDialect.h" #include "mlir/Dialect/Mesh/IR/MeshOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -25,6 +30,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #define DEBUG_TYPE "mesh-to-mpi" @@ -36,10 +42,34 @@ namespace mlir { } // namespace mlir using namespace mlir; -using namespace mlir::mesh; +using namespace mesh; namespace { -// Create operations converting a linear index to a multi-dimensional index +/// Converts a vector of OpFoldResults (ints) into a vector of Values of the +/// provided type. +static SmallVector getMixedAsValues(OpBuilder b, const Location &loc, + llvm::ArrayRef statics, + ValueRange dynamics, + Type type = Type()) { + SmallVector values; + auto dyn = dynamics.begin(); + Type i64 = b.getI64Type(); + if (!type) + type = i64; + assert((i64 == type || b.getIndexType() == type) && + "expected an i64 or an index type"); + for (auto s : statics) { + if (s == ShapedType::kDynamic) { + values.emplace_back(*(dyn++)); + } else { + TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s); + values.emplace_back(b.create(loc, type, val)); + } + } + return values; +} + +/// Create operations converting a linear index to a multi-dimensional index.
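+/// For example (illustrative numbers, cf. the process_multi_index test below): for a mesh of shape 3x4x5, linear index 24 maps to the multi-index (1, 0, 4), since 24 % 5 == 4, (24 / 5) % 4 == 0 and (24 / 20) % 3 == 1.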
static SmallVector linearToMultiIndex(Location loc, OpBuilder b, Value linearIndex, ValueRange dimensions) { @@ -48,23 +78,22 @@ static SmallVector linearToMultiIndex(Location loc, OpBuilder b, for (int i = n - 1; i >= 0; --i) { multiIndex[i] = b.create(loc, linearIndex, dimensions[i]); - if (i > 0) { + if (i > 0) linearIndex = b.create(loc, linearIndex, dimensions[i]); - } } return multiIndex; } -// Create operations converting a multi-dimensional index to a linear index +/// Create operations converting a multi-dimensional index to a linear index. Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex, ValueRange dimensions) { - auto linearIndex = b.create(loc, 0).getResult(); - auto stride = b.create(loc, 1).getResult(); + Value linearIndex = b.create(loc, 0); + Value stride = b.create(loc, 1); for (int i = multiIndex.size() - 1; i >= 0; --i) { - auto off = b.create(loc, multiIndex[i], stride); + Value off = b.create(loc, multiIndex[i], stride); linearIndex = b.create(loc, linearIndex, off); stride = b.create(loc, stride, dimensions[i]); } @@ -72,116 +101,259 @@ Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex, return linearIndex; } +/// Replace GetShardingOp with related/dependent ShardingOp. +struct ConvertGetShardingOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(GetShardingOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto shardOp = adaptor.getSource().getDefiningOp(); + if (!shardOp) + return failure(); + auto shardingOp = shardOp.getSharding().getDefiningOp(); + if (!shardingOp) + return failure(); + + rewriter.replaceOp(op, shardingOp.getResult()); + return success(); + } +}; + +/// Convert a sharding op to a tuple of tensors of its components +/// (SplitAxes, HaloSizes, ShardedDimsOffsets) +/// as defined by the type converter. +struct ConvertShardingOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ShardingOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto splitAxes = op.getSplitAxes().getAxes(); + int64_t maxNAxes = 0; + for (auto axes : splitAxes) + maxNAxes = std::max(maxNAxes, axes.size()); + + // To hold the split axes, create an empty 2d tensor with shape + // {splitAxes.size(), max-size-of-split-groups}. + // Set trailing elements for smaller split-groups to -1. + Location loc = op.getLoc(); + auto i16 = rewriter.getI16Type(); + auto i64 = rewriter.getI64Type(); + std::array shape = {static_cast(splitAxes.size()), + maxNAxes}; + Value resSplitAxes = rewriter.create(loc, shape, i16); + auto attr = IntegerAttr::get(i16, -1); + Value fillValue = rewriter.create(loc, i16, attr); + resSplitAxes = rewriter.create(loc, fillValue, resSplitAxes) + .getResult(0); + + // Explicitly write values into the tensor row by row. + std::array strides = {1, 1}; + int64_t nSplits = 0; + ValueRange empty = {}; + for (auto [i, axes] : llvm::enumerate(splitAxes)) { + int64_t size = axes.size(); + if (size > 0) + ++nSplits; + std::array offs = {(int64_t)i, 0}; + std::array sizes = {1, size}; + auto tensorType = RankedTensorType::get({size}, i16); + auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef()); + auto vals = rewriter.create(loc, tensorType, attrs); + resSplitAxes = rewriter.create( + loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides); + } + + // To hold halo sizes, create a 2d tensor with shape {nSplits, 2}.
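+ // Each row holds the lower and upper halo size of one split dimension; e.g. halo_sizes = [0, 4, 3, 1] becomes the rows [0, 4] and [3, 1] (cf. the return_sharding_halos test below).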
+ // Store the halo sizes in the tensor. + SmallVector haloSizes = + getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(), + adaptor.getDynamicHaloSizes()); + auto type = RankedTensorType::get({nSplits, 2}, i64); + Value resHaloSizes = + haloSizes.empty() + ? rewriter + .create(loc, std::array{0, 0}, + i64) + .getResult() + : rewriter.create(loc, type, haloSizes) + .getResult(); + + // To hold sharded dims offsets, create a tensor with shape {nSplits, + // maxSplitSize+1}. Store the offsets in the tensor but set trailing + // elements for smaller split-groups to ShapedType::kDynamic. Computing the + // max size of the split groups requires collectiveProcessGroupSize (which + // needs the MeshOp). + Value resOffsets; + if (adaptor.getStaticShardedDimsOffsets().empty()) { + resOffsets = rewriter.create( + loc, std::array{0, 0}, i64); + } else { + SymbolTableCollection symbolTableCollection; + auto meshOp = getMesh(op, symbolTableCollection); + int64_t maxSplitSize = 0; + for (auto axes : splitAxes) { + int64_t splitSize = + collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape()); + assert(splitSize != ShapedType::kDynamic); + maxSplitSize = std::max(maxSplitSize, splitSize); + } + assert(maxSplitSize); + ++maxSplitSize; // add one for the total size + + resOffsets = rewriter.create( + loc, std::array{nSplits, maxSplitSize}, i64); + Value zero = rewriter.create( + loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic)); + resOffsets = + rewriter.create(loc, zero, resOffsets).getResult(0); + SmallVector offsets = + getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(), + adaptor.getDynamicShardedDimsOffsets()); + int64_t curr = 0; + for (auto [i, axes] : llvm::enumerate(splitAxes)) { + int64_t splitSize = + collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape()); + assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize); + ++splitSize; // add one for the total size + ArrayRef values(&offsets[curr], splitSize); + Value vals = rewriter.create(loc, values); + std::array offs = {static_cast(i), 0}; + std::array sizes = {1, splitSize}; + resOffsets = rewriter.create( + loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides); + curr += splitSize; + } + } + + // Return a tuple of tensors as defined by the type converter. + SmallVector resTypes; + if (failed(getTypeConverter()->convertType(op.getResult().getType(), + resTypes))) + return failure(); + + resSplitAxes = + rewriter.create(loc, resTypes[0], resSplitAxes); + resHaloSizes = + rewriter.create(loc, resTypes[1], resHaloSizes); + resOffsets = rewriter.create(loc, resTypes[2], resOffsets); + + rewriter.replaceOpWithNewOp( + op, TupleType::get(op.getContext(), resTypes), + ValueRange{resSplitAxes, resHaloSizes, resOffsets}); + + return success(); + } +}; + struct ConvertProcessMultiIndexOp - : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - mlir::LogicalResult - matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op, - mlir::PatternRewriter &rewriter) const override { + LogicalResult + matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { // Currently converts its linear index to a multi-dimensional index.
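+ // If a subset of axes is requested, only those components of the multi-index are returned, e.g. axes = [1] on a 3x4x5 mesh yields just the second component (illustrative numbers).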
SymbolTableCollection symbolTableCollection; - auto loc = op.getLoc(); + Location loc = op.getLoc(); auto meshOp = getMesh(op, symbolTableCollection); // For now we only support static mesh shapes - if (ShapedType::isDynamicShape(meshOp.getShape())) { - return mlir::failure(); - } + if (ShapedType::isDynamicShape(meshOp.getShape())) + return failure(); SmallVector dims; llvm::transform( meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { return rewriter.create(loc, i).getResult(); }); - auto rank = - rewriter.create(op.getLoc(), meshOp).getResult(); + Value rank = rewriter.create(op.getLoc(), meshOp); auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims); // optionally extract subset of mesh axes - auto axes = op.getAxes(); + auto axes = adaptor.getAxes(); if (!axes.empty()) { SmallVector subIndex; for (auto axis : axes) { - subIndex.push_back(mIdx[axis]); + subIndex.emplace_back(mIdx[axis]); } - mIdx = subIndex; + mIdx = std::move(subIndex); } rewriter.replaceOp(op, mIdx); - return mlir::success(); + return success(); } }; -struct ConvertProcessLinearIndexOp - : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult - matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op, - mlir::PatternRewriter &rewriter) const override { - - // Finds a global named "static_mpi_rank" it will use that splat value. - // Otherwise it defaults to mpi.comm_rank. - - auto loc = op.getLoc(); - auto rankOpName = StringAttr::get(op->getContext(), "static_mpi_rank"); - if (auto globalOp = SymbolTable::lookupNearestSymbolFrom( - op, rankOpName)) { - if (auto initTnsr = globalOp.getInitialValueAttr()) { - auto val = cast(initTnsr).getSplatValue(); - rewriter.replaceOp(op, - rewriter.create(loc, val)); - return mlir::success(); - } +class ConvertProcessLinearIndexOp + : public OpConversionPattern { + int64_t worldRank; // rank in MPI_COMM_WORLD if available, else < 0 + +public: + using OpConversionPattern::OpConversionPattern; + + // Constructor accepting worldRank + ConvertProcessLinearIndexOp(const TypeConverter &typeConverter, + MLIRContext *context, int64_t worldRank = -1) + : OpConversionPattern(typeConverter, context), worldRank(worldRank) {} + + LogicalResult + matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + if (worldRank >= 0) { // if rank in MPI_COMM_WORLD is known -> use it + rewriter.replaceOpWithNewOp(op, worldRank); + return success(); } - auto rank = - rewriter - .create( - op.getLoc(), TypeRange{mpi::RetvalType::get(op->getContext()), + + // Otherwise create an mpi::CommRankOp call + auto rank = rewriter + .create( + loc, TypeRange{mpi::RetvalType::get(op->getContext()), rewriter.getI32Type()}) - .getRank(); + .getRank(); rewriter.replaceOpWithNewOp(op, rewriter.getIndexType(), rank); - return mlir::success(); + return success(); } }; struct ConvertNeighborsLinearIndicesOp - : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - mlir::LogicalResult - matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op, - mlir::PatternRewriter &rewriter) const override { + LogicalResult + matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { // Computes the neighbor indices along a split axis by simply // adding/subtracting 1 to the current index in that dimension.
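+ // E.g. on a 3x4x5 mesh, the neighbors of device (1, 1, 1) along split axis 1 are (1, 0, 1) and (1, 2, 1), i.e. linearized indices 21 and 31 (illustrative numbers, using the linearization above).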
// Assigns -1 if neighbor is out of bounds. - auto axes = op.getSplitAxes(); + auto axes = adaptor.getSplitAxes(); // For now only single axis sharding is supported - if (axes.size() != 1) { - return mlir::failure(); - } + if (axes.size() != 1) + return failure(); - auto loc = op.getLoc(); + Location loc = op.getLoc(); SymbolTableCollection symbolTableCollection; auto meshOp = getMesh(op, symbolTableCollection); - auto mIdx = op.getDevice(); + auto mIdx = adaptor.getDevice(); auto orgIdx = mIdx[axes[0]]; SmallVector dims; llvm::transform( meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { return rewriter.create(loc, i).getResult(); }); - auto dimSz = dims[axes[0]]; - auto one = rewriter.create(loc, 1).getResult(); - auto minus1 = rewriter.create(loc, -1).getResult(); - auto atBorder = rewriter.create( + Value dimSz = dims[axes[0]]; + Value one = rewriter.create(loc, 1); + Value minus1 = rewriter.create(loc, -1); + Value atBorder = rewriter.create( loc, arith::CmpIPredicate::sle, orgIdx, - rewriter.create(loc, 0).getResult()); + rewriter.create(loc, 0)); auto down = rewriter.create( loc, atBorder, [&](OpBuilder &builder, Location loc) { @@ -206,23 +378,160 @@ struct ConvertNeighborsLinearIndicesOp [&](OpBuilder &builder, Location loc) { SmallVector tmp = mIdx; tmp[axes[0]] = - rewriter.create(op.getLoc(), orgIdx, one) - .getResult(); + rewriter.create(op.getLoc(), orgIdx, one); builder.create( loc, multiToLinearIndex(loc, rewriter, tmp, dims)); }); rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)}); - return mlir::success(); + return success(); + } +}; + +struct ConvertShardShapeOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto sharding = op.getSharding().getDefiningOp(); + if (!sharding) { + return op->emitError() + << "Expected ShardingOp as defining op for sharding" + << " but found " << adaptor.getSharding()[0].getDefiningOp(); + } + + // Compute the sharded shape by applying the sharding to the input shape. + // If shardedDimsOffsets is not defined in the sharding, the shard shape is + // computed by dividing the dimension size by the number of shards in that + // dimension (which is given by the size of the mesh axes provided in + // split-axes). Odd elements get distributed to trailing shards. If a + // shardedDimsOffsets is provided, the shard shape is computed by + // subtracting the offset of the current shard from the offset of the next + // shard. + + Location loc = op.getLoc(); + Type index = rewriter.getIndexType(); + + // This is a 1:N conversion because the sharding op is a 1:3 conversion. + // The operands in the adaptor are vectors of values. For dims and device + // we have a 1:1 conversion. + // For simpler access fill a vector with the dynamic dims. + SmallVector dynDims, dynDevice; + for (auto dim : adaptor.getDimsDynamic()) { + // type conversion should be 1:1 for ints + assert(dim.size() == 1); + dynDims.emplace_back(dim[0]); + } + // same for device + for (auto device : adaptor.getDeviceDynamic()) { + assert(device.size() == 1); + dynDevice.emplace_back(device[0]); + } + + // To keep the code simple, convert dims/device to values when they are + // attributes. Count on canonicalization to fold static values.
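+ // Illustrative numbers for one dimension: without sharded_dims_offsets, dim = 10 over numShards = 4 yields shard sizes 2, 2, 3, 3 (the trailing dim % numShards shards get one extra element); with offsets [0, 3, 5, 7, 8], shard idx gets offsets[idx + 1] - offsets[idx], i.e. sizes 3, 2, 2, 1.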
+ SmallVector shape = + getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index); + SmallVector multiIdx = + getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index); + + // Get the MeshOp; the mesh shape is needed to compute the sharded shape. + SymbolTableCollection symbolTableCollection; + auto meshOp = getMesh(sharding, symbolTableCollection); + // For now we only support static mesh shapes + if (ShapedType::isDynamicShape(meshOp.getShape())) + return failure(); + + auto splitAxes = sharding.getSplitAxes().getAxes(); + // shardedDimsOffsets are optional and might be Values (not attributes). + // Also, the shardId might be dynamic which means the position in the + // shardedDimsOffsets is not statically known. Create a tensor of the + // shardedDimsOffsets and later extract the offsets for computing the + // local shard-size. + Value shardedDimsOffs; + { + SmallVector tmp = getMixedAsValues( + rewriter, loc, sharding.getStaticShardedDimsOffsets(), + sharding.getDynamicShardedDimsOffsets(), index); + if (!tmp.empty()) + shardedDimsOffs = rewriter.create( + loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp); + } + + // With static mesh shape the sizes of the split axes are known. + // Hence the start/pos for each split axis in shardedDimsOffsets can be + // computed statically. + int64_t pos = 0; + SmallVector shardShape; + Value zero = + rewriter.create(loc, rewriter.getZeroAttr(index)); + Value one = + rewriter.create(loc, rewriter.getOneAttr(index)); + + // Iterate over the dimensions of the tensor shape, get their split axes, + // and compute the sharded shape. + for (auto [i, dim] : llvm::enumerate(shape)) { + // Trailing dimensions might not be annotated. + if (i < splitAxes.size() && !splitAxes[i].empty()) { + auto axes = splitAxes[i]; + // The current dimension might not be sharded. + // Create a value from the static position in shardedDimsOffsets. + Value posVal = + rewriter.create(loc, rewriter.getIndexAttr(pos)); + // Get the index of the local shard in the mesh axis. + Value idx = multiIdx[axes[0]]; + auto numShards = + collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape()); + if (shardedDimsOffs) { + // If sharded dims offsets are provided, use them to compute the + // sharded shape. + if (axes.size() > 1) { + return op->emitError() << "Only single axis sharding is " + << "supported for each dimension."; + } + idx = rewriter.create(loc, posVal, idx); + // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx]. + Value off = + rewriter.create(loc, shardedDimsOffs, idx); + idx = rewriter.create(loc, idx, one); + Value nextOff = + rewriter.create(loc, shardedDimsOffs, idx); + Value sz = rewriter.create(loc, nextOff, off); + shardShape.emplace_back(sz); + } else { + Value numShardsVal = rewriter.create( + loc, rewriter.getIndexAttr(numShards)); + // Compute shard dim size by distributing odd elements to trailing + // shards: + // sz = dim / numShards + // + (idx >= (numShards - (dim % numShards)) ? 1 : 0) + Value sz = rewriter.create(loc, dim, numShardsVal); + Value sz1 = rewriter.create(loc, dim, numShardsVal); + sz1 = rewriter.create(loc, numShardsVal, sz1); + auto cond = rewriter.create( + loc, arith::CmpIPredicate::sge, idx, sz1); + Value odd = rewriter.create(loc, cond, one, zero); + sz = rewriter.create(loc, sz, odd); + shardShape.emplace_back(sz); + } + pos += numShards + 1; // add one for the total size. + } // else: no sharding if there is no split axis or it is empty + // If no size was added -> no sharding in this dimension.
+ if (shardShape.size() <= i) + shardShape.emplace_back(dim); + } + assert(shardShape.size() == shape.size()); + rewriter.replaceOp(op, shardShape); + return success(); } }; -struct ConvertUpdateHaloOp - : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ConvertUpdateHaloOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - mlir::LogicalResult - matchAndRewrite(mlir::mesh::UpdateHaloOp op, - mlir::PatternRewriter &rewriter) const override { + LogicalResult + matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { // The input/output memref is assumed to be in C memory order. // Halos are exchanged as 2 blocks per dimension (one for each side: down @@ -236,42 +545,47 @@ struct ConvertUpdateHaloOp // local data. Because subviews and halos can have mixed dynamic and static // shapes, OpFoldResults are used whenever possible. + auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(), + adaptor.getHaloSizes(), rewriter); + if (haloSizes.empty()) { + // no halos -> nothing to do + rewriter.replaceOp(op, adaptor.getDestination()); + return success(); + } + SymbolTableCollection symbolTableCollection; - auto loc = op.getLoc(); + Location loc = op.getLoc(); // Convert an OpFoldResult into a Value. auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value { if (auto value = dyn_cast(v)) return value; - return rewriter.create<::mlir::arith::ConstantOp>( + return rewriter.create( loc, rewriter.getIndexAttr( cast(cast(v)).getInt())); }; - auto dest = op.getDestination(); + auto dest = adaptor.getDestination(); auto dstShape = cast(dest.getType()).getShape(); Value array = dest; if (isa(array.getType())) { // If the destination is a tensor, cast it to a memref auto tensorType = MemRefType::get( dstShape, cast(array.getType()).getElementType()); - array = rewriter.create(loc, tensorType, array) - .getResult(); + array = + rewriter.create(loc, tensorType, array); } auto rank = cast(array.getType()).getRank(); - auto opSplitAxes = op.getSplitAxes().getAxes(); - auto mesh = op.getMesh(); + auto opSplitAxes = adaptor.getSplitAxes().getAxes(); + auto mesh = adaptor.getMesh(); auto meshOp = getMesh(op, symbolTableCollection); - auto haloSizes = - getMixedValues(op.getStaticHaloSizes(), op.getHaloSizes(), rewriter); // subviews need Index values for (auto &sz : haloSizes) { - if (auto value = dyn_cast(sz)) { + if (auto value = dyn_cast(sz)) sz = rewriter .create(loc, rewriter.getIndexType(), value) .getResult(); - } } // most of the offset/size/stride data is the same for all dims @@ -282,11 +596,10 @@ struct ConvertUpdateHaloOp // we need the actual shape to compute offsets and sizes for (auto i = 0; i < rank; ++i) { auto s = dstShape[i]; - if (ShapedType::isDynamic(s)) { + if (ShapedType::isDynamic(s)) shape[i] = rewriter.create(loc, array, s).getResult(); - } else { + else shape[i] = rewriter.getIndexAttr(s); - } if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) { ++currHaloDim; @@ -294,11 +607,9 @@ struct ConvertUpdateHaloOp offsets[i] = haloSizes[currHaloDim * 2]; // prepare shape and offsets of highest dim's halo exchange - auto _haloSz = - rewriter - .create(loc, toValue(haloSizes[currHaloDim * 2]), - toValue(haloSizes[currHaloDim * 2 + 1])) - .getResult(); + Value _haloSz = rewriter.create( + loc, toValue(haloSizes[currHaloDim * 2]), + toValue(haloSizes[currHaloDim * 2 + 1])); // the halo shape of lower dims excludes the halos dimSizes[i] = rewriter.create(loc,
toValue(shape[i]), _haloSz) .getResult(); @@ -309,9 +620,9 @@ } auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something - auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr); + auto tag = rewriter.create(loc, tagAttr); auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0 - auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr); + auto zero = rewriter.create(loc, zeroAttr); SmallVector indexResultTypes(meshOp.getShape().size(), rewriter.getIndexType()); @@ -321,9 +632,8 @@ // traverse all split axes from high to low dim for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) { auto splitAxes = opSplitAxes[dim]; - if (splitAxes.empty()) { + if (splitAxes.empty()) continue; - } assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2); // Get the linearized ids of the neighbors (down and up) for the // given split @@ -356,8 +666,8 @@ : haloSizes[currHaloDim * 2]; // Check if we need to send and/or receive // Processes on the mesh borders have only one neighbor - auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0]; - auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1]; + auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1]; + auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0]; auto hasFrom = rewriter.create( loc, arith::CmpIPredicate::sge, from, zero); auto hasTo = rewriter.create( @@ -390,8 +700,24 @@ offsets[dim] = orgOffset; }; - genSendRecv(false); - genSendRecv(true); + auto doSendRecv = [&](int upOrDown) { + OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown]; + Value haloSz = dyn_cast(v); + if (!haloSz) + haloSz = rewriter.create( + loc, rewriter.getI32IntegerAttr( + cast(cast(v)).getInt())); + auto hasSize = rewriter.create( + loc, arith::CmpIPredicate::sgt, haloSz, zero); + rewriter.create(loc, hasSize, + [&](OpBuilder &builder, Location loc) { + genSendRecv(upOrDown > 0); + builder.create(loc); + }); + }; + + doSendRecv(0); + doSendRecv(1); // the shape for lower dims includes higher dims' halos dimSizes[dim] = shape[dim]; @@ -409,7 +735,7 @@ loc, op.getResult().getType(), array, /*restrict=*/true, /*writable=*/true)); } - return mlir::success(); + return success(); } }; @@ -419,14 +745,95 @@ struct ConvertMeshToMPIPass /// Run the dialect converter on the module. void runOnOperation() override { - auto *ctx = &getContext(); - mlir::RewritePatternSet patterns(ctx); + int64_t worldRank = -1; + // Try to get DLTI attribute for MPI:comm_world_rank + // If found, set worldRank to the value of the attribute. + { + auto dltiAttr = + dlti::query(getOperation(), {"MPI:comm_world_rank"}, false); + if (succeeded(dltiAttr)) { + if (!isa(dltiAttr.value())) { + getOperation()->emitError() + << "Expected an integer attribute for MPI:comm_world_rank"; + return signalPassFailure(); + } + worldRank = cast(dltiAttr.value()).getInt(); + } + } - patterns.insert( - ctx); + auto *ctxt = &getContext(); + RewritePatternSet patterns(ctxt); + ConversionTarget target(getContext()); + + // Define a type converter to convert mesh::ShardingType, + // mostly for use in return operations.
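+ // A !mesh.sharding value becomes three ranked tensors: split axes (i16), halo sizes (i64) and sharded dims offsets (i64), mirroring what ConvertShardingOp materializes above.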
+ TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + + // convert mesh::ShardingType to a tuple of RankedTensorTypes + typeConverter.addConversion( + [](ShardingType type, + SmallVectorImpl &results) -> std::optional { + auto i16 = IntegerType::get(type.getContext(), 16); + auto i64 = IntegerType::get(type.getContext(), 64); + std::array shp = {ShapedType::kDynamic, + ShapedType::kDynamic}; + results.emplace_back(RankedTensorType::get(shp, i16)); + results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2 + results.emplace_back(RankedTensorType::get(shp, i64)); + return success(); + }); + + // To 'extract' components, an UnrealizedConversionCastOp is expected + // to define the input + typeConverter.addTargetMaterialization( + [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs, + Location loc) { + // Expecting a single input. + if (inputs.size() != 1 || !isa(inputs[0].getType())) + return SmallVector(); + auto castOp = inputs[0].getDefiningOp(); + // Expecting an UnrealizedConversionCastOp. + if (!castOp) + return SmallVector(); + // Fill a vector with elements of the tuple/castOp. + SmallVector results; + for (auto oprnd : castOp.getInputs()) { + if (!isa(oprnd.getType())) + return SmallVector(); + results.emplace_back(oprnd); + } + return results; + }); - (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns)); + // No mesh dialect should be left after conversion... + target.addIllegalDialect(); + // ...except the global MeshOp + target.addLegalOp(); + // Allow all the stuff that our patterns will convert to + target.addLegalDialect(); + // Make sure the function signature, calls etc. are legal + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()); + }); + target.addDynamicallyLegalOp( + [&](Operation *op) { return typeConverter.isLegal(op); }); + + patterns.add(typeConverter, ctxt); + // ConvertProcessLinearIndexOp accepts an optional worldRank + patterns.add(typeConverter, ctxt, worldRank); + + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); + + (void)applyPartialConversion(getOperation(), target, std::move(patterns)); } }; diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index 304ede195c762..3e9f86fde64f3 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -831,12 +831,19 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_, // mesh.shard_shape //===----------------------------------------------------------------------===// +void ShardShapeOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult()[0], "shard_shape"); +} + void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, - ::llvm::ArrayRef shape, - ::mlir::Value sharding, ::mlir::Value device) { - SmallVector resType(shape.size(), odsBuilder.getIndexType()); - build(odsBuilder, odsState, resType, shape, sharding, device); + ::llvm::ArrayRef dims, + ArrayRef dims_dyn, ::mlir::Value sharding, + ::mlir::ValueRange device) { + SmallVector resType(dims.size(), odsBuilder.getIndexType()); + build(odsBuilder, odsState, resType, dims, dims_dyn, sharding, + SmallVector(device.size(), ShapedType::kDynamic), device); } //===----------------------------------------------------------------------===//
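A minimal sketch of how the rank reaches the pass via DLTI (mesh and function names are illustrative; the attribute key matches the tests below). With this entry, mesh.process_linear_index folds to arith.constant 24 instead of an mpi.comm_rank call:

  module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
    mesh.mesh @mesh0(shape = 3x4x5)
    func.func @rank() -> index {
      %r = mesh.process_linear_index on @mesh0 : index
      return %r : index
    }
  }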
diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp index b2acbf20b3fb9..b3d69eb5e1a23 100644 --- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp +++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp @@ -50,10 +50,10 @@ struct CreatorOpShardingInterface IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder) const { - auto shardType = cast(mesh::shardType( - op->getResult(0).getType(), - mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable), - resultShardings[0])); + auto mesh = + mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable); + auto shardType = cast( + mesh::shardType(op->getResult(0).getType(), mesh, resultShardings[0])); Operation *newOp = nullptr; // if the sharding introduces a new dynamic dimension, we take it from // the dynamic sharding info. For now bail out if it's not @@ -66,18 +66,19 @@ struct CreatorOpShardingInterface assert(oldType.getRank() == shardType.getRank()); int currOldOprndNum = -1; mesh::ShardShapeOp shapeForDevice; - Value device; + ValueRange device; Operation *newSharding = nullptr; for (auto i = 0; i < oldType.getRank(); ++i) { if (!oldType.isDynamicDim(i) && shardType.isDynamicDim(i)) { if (!newSharding) { newSharding = builder.create(op->getLoc(), resultShardings[0]); - device = builder.create( - op->getLoc(), resultShardings[0].getMesh()); + device = + builder.create(op->getLoc(), mesh) + .getResults(); shapeForDevice = builder.create( - op->getLoc(), oldType.getShape(), newSharding->getResult(0), - device); + op->getLoc(), oldType.getShape(), spmdizedOperands, + newSharding->getResult(0), device); } newOperands.emplace_back(shapeForDevice.getResult()[i]); } else if (oldType.isDynamicDim(i)) { diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir index c1aef97438bd5..4e60c6f0d4e44 100644 --- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir +++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir @@ -60,23 +60,24 @@ func.func @neighbors_dim2(%arg0 : tensor<120x120x120xi8>) -> (index, index) { // ----- // CHECK: mesh.mesh @mesh0 -mesh.mesh @mesh0(shape = 3x4x5) -memref.global constant @static_mpi_rank : memref = dense<24> -func.func @process_multi_index() -> (index, index, index) { - // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index - // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index - %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index - // CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index - return %0#0, %0#1, %0#2 : index, index, index -} +module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { + mesh.mesh @mesh0(shape = 3x4x5) + func.func @process_multi_index() -> (index, index, index) { + // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index + // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index + %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index + // CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index + return %0#0, %0#1, %0#2 : index, index, index + } -// CHECK-LABEL: func @process_linear_index -func.func @process_linear_index() -> index { - // CHECK: %[[c24:.*]] = arith.constant 24 : index - %0 = mesh.process_linear_index on @mesh0 : index - // CHECK: return %[[c24]] : index - return %0 : index + // 
CHECK-LABEL: func @process_linear_index + func.func @process_linear_index() -> index { + // CHECK: %[[c24:.*]] = arith.constant 24 : index + %0 = mesh.process_linear_index on @mesh0 : index + // CHECK: return %[[c24]] : index + return %0 : index + } } // ----- @@ -103,106 +104,200 @@ func.func @update_halo_1d_first( } // ----- -mesh.mesh @mesh0(shape = 3x4x5) -memref.global constant @static_mpi_rank : memref = dense<24> -// CHECK-LABEL: func @update_halo_3d -func.func @update_halo_3d( - // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8> - %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> { - // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32 - // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32 - // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32 - // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32 - // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32 - // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8> - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> - // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8> - // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32 - // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32 - // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> - // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> - // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8> - // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8> - // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> - // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8> - // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32 - // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32 - // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> - // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> - // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8> - // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8> - // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32 - // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>> - // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>> - // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8> - // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8> - 
// CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> - // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8> - // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32 - // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8> - // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8> - // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[varg0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> - // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8> - // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32 - // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8> - // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8> - // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32 - // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>> - // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>> - // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8> - %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8> - // CHECK: return [[varg0]] : memref<120x120x120xi8> - return %res : memref<120x120x120xi8> +module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } { + mesh.mesh @mesh0(shape = 4) + // CHECK-LABEL: func @update_halo_1d_with_zero + func.func @update_halo_1d_with_zero ( + // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8> + %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> { + // CHECK: [[vc91_i32:%.*]] = arith.constant 91 : i32 + // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32 + // CHECK-NEXT: [[vc2_i32:%.*]] = arith.constant 2 : i32 + // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x120x120xi8> + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8 + // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8 + // CHECK-SAME: to memref<2x120x120xi8> + // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]]) : memref<2x120x120xi8>, i32, i32 + // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]]) : memref<2x120x120xi8>, i32, i32 + // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8 + // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8 + // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x120x120xi8> + %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8> + // CHECK: return [[res:%.*]] : memref<120x120x120xi8> + return %res : memref<120x120x120xi8> + } +} + +// ----- +module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { + mesh.mesh @mesh0(shape = 3x4x5) + // 
CHECK-LABEL: func @update_halo_3d + func.func @update_halo_3d( + // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8> + %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> { + // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32 + // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32 + // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32 + // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32 + // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32 + // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8> + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> + // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8> + // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32 + // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32 + // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> + // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> + // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8> + // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8> + // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> + // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8> + // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32 + // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32 + // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> + // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> + // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8> + // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8> + // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[varg0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> + // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8> + // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32 + // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8> + // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8> + // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32 + // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>> + // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], 
offset: 28320>> + // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8> + // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8> + // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32 + // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[varg0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>> + // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>> + // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8> + // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8> + // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[varg0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> + // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8> + // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32 + // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8> + %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8> + // CHECK: return [[varg0]] : memref<120x120x120xi8> + return %res : memref<120x120x120xi8> + } + + // CHECK-LABEL: func @update_halo_3d_tensor + func.func @update_halo_3d_tensor( + // CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8> + %arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> { + // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32 + // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32 + // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32 + // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32 + // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32 + // CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8> + // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8> + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> + // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8> + // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32 + // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32 + // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> + // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> + // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8> + // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8> + // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> + // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8> + // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32 + // 
CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32 + // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> + // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> + // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8> + // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8> + // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[v0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> + // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8> + // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32 + // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8> + // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8> + // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32 + // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>> + // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>> + // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8> + // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8> + // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32 + // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[v0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>> + // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>> + // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8> + // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8> + // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[v0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> + // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8> + // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32 + // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8> + // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8> + %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8> + // CHECK: return [[v1]] : tensor<120x120x120xi8> + return %res : tensor<120x120x120xi8> + } +} + +// ----- +mesh.mesh @mesh0(shape = 2x2x4) +// CHECK-LABEL: func.func @return_sharding( +// CHECK-SAME: [[varg0:%.*]]: tensor<2x4xf32>) -> (tensor<2x4xf32>, tensor, tensor, tensor) { +func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sharding) { + %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] : !mesh.sharding + // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16> + // 
CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16> + // CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16 + // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16> + // CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16> + // CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_0]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16> + // CHECK-NEXT: [[vinserted_slice_1:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16> + // CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64> + // CHECK-NEXT: [[v3:%.*]] = tensor.empty() : tensor<0x0xi64> + // CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_1]] : tensor<2x2xi16> to tensor + // CHECK-NEXT: [[vcast_2:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor + // CHECK-NEXT: [[vcast_3:%.*]] = tensor.cast [[v3]] : tensor<0x0xi64> to tensor + // CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_2]], [[vcast_3]] : tensor<2x4xf32>, tensor, tensor, tensor + return %arg0, %sharding : tensor<2x4xf32>, !mesh.sharding +} + +// CHECK-LABEL: func.func @return_sharding_halos( +// CHECK-SAME: [[varg0:%.*]]: tensor<6x8xf32>) -> (tensor<6x8xf32>, tensor, tensor, tensor) { +func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !mesh.sharding) { + %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !mesh.sharding + // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64> + // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16> + // CHECK-NEXT: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16> + // CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16 + // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16> + // CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16> + // CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_1]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16> + // CHECK-NEXT: [[vinserted_slice_2:%.*]] = tensor.insert_slice [[vcst_0]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16> + // CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64> + // CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_2]] : tensor<2x2xi16> to tensor + // CHECK-NEXT: [[vcast_3:%.*]] = tensor.cast [[vcst]] : tensor<2x2xi64> to tensor + // CHECK-NEXT: [[vcast_4:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor + // CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_3]], [[vcast_4]] : tensor<6x8xf32>, tensor, tensor, tensor + return %arg0, %sharding : tensor<6x8xf32>, !mesh.sharding } -// CHECK-LABEL: func @update_halo_3d_tensor -func.func @update_halo_3d_tensor( - // CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8> - %arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> { - // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32 - // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32 - // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32 - // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32 - // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32 - // CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8> - // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8> - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 
-// CHECK-LABEL: func @update_halo_3d_tensor
-func.func @update_halo_3d_tensor(
-    // CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8>
-    %arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> {
-  // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32
-  // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
-  // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
-  // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
-  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
-  // CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8>
-  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
-  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
-  // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
-  // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
-  // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
-  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
-  // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
-  // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
-  // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
-  // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
-  // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
-  // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
-  // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
-  // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
-  // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
-  // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
-  // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
-  // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
-  // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
-  // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
-  // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
-  // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
-  // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[v0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
-  // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8>
-  // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
-  // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8>
-  // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8>
-  // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[v0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
-  // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
-  // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
-  // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8>
-  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
-  // CHECK: return [[v1]] : tensor<120x120x120xi8>
-  return %res : tensor<120x120x120xi8>
+// CHECK-LABEL: func.func @return_sharding_offs(
+// CHECK-SAME: [[varg0:%.*]]: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
+func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !mesh.sharding) {
+  %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !mesh.sharding
+  // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64>
+  // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64>
+  // CHECK-NEXT: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64
+  // CHECK-NEXT: [[vcst_1:%.*]] = arith.constant dense<2> : tensor<1xi16>
+  // CHECK-NEXT: [[vcst_2:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+  // CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+  // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
+  // CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
+  // CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_2]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
+  // CHECK-NEXT: [[vinserted_slice_3:%.*]] = tensor.insert_slice [[vcst_1]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
+  // CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
+  // CHECK-NEXT: [[v3:%.*]] = tensor.empty() : tensor<2x5xi64>
+  // CHECK-NEXT: [[v4:%.*]] = linalg.fill ins([[vcm9223372036854775808_i64]] : i64) outs([[v3]] : tensor<2x5xi64>) -> tensor<2x5xi64>
+  // CHECK-NEXT: [[vinserted_slice_4:%.*]] = tensor.insert_slice [[vcst_0]] into [[v4]][0, 0] [1, 5] [1, 1] : tensor<5xi64> into tensor<2x5xi64>
+  // CHECK-NEXT: [[vinserted_slice_5:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice_4]][1, 0] [1, 5] [1, 1] : tensor<5xi64> into tensor<2x5xi64>
+  // CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_3]] : tensor<2x2xi16> to tensor<?x?xi16>
+  // CHECK-NEXT: [[vcast_6:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
+  // CHECK-NEXT: [[vcast_7:%.*]] = tensor.cast [[vinserted_slice_5]] : tensor<2x5xi64> to tensor<?x?xi64>
+  // CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_6]], [[vcast_7]] : tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
+  return %arg0, %sharding : tensor<?x?xf32>, !mesh.sharding
 }
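+
+// Note on the encoding checked above: per-dimension rows are padded to a
+// common width before casting to dynamically shaped tensors, split axes with
+// the i16 value -1 and sharded dims offsets with the minimal i64 value.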
diff --git a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
new file mode 100644
index 0000000000000..156bbfb54845b
--- /dev/null
+++ b/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt %s --convert-mesh-to-mpi -canonicalize | FileCheck %s
+
+module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
+
+  // CHECK: mesh.mesh @mesh0
+  mesh.mesh @mesh0(shape = 3x4x5)
+
+  // Note: comm_world_rank (the linear index 24) corresponds to the
+  // multi-index [1, 0, 4] in @mesh0.
+
+  // All shards are equal.
+  // CHECK-LABEL: func.func @shard_shape_equal() -> (index, index, index) {
+  func.func @shard_shape_equal() -> (index, index, index) {
+    %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+    %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+    %c9 = arith.constant 9 : index
+    %c12 = arith.constant 12 : index
+    // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+    %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+    // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
+    return %1#0, %1#1, %1#2 : index, index, index
+  }
+
+  // The last shard in the last dimension gets an extra element.
+  // CHECK-LABEL: func.func @shard_shape_odd_1() -> (index, index, index) {
+  func.func @shard_shape_odd_1() -> (index, index, index) {
+    %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+    %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+    %c9 = arith.constant 9 : index
+    %c12 = arith.constant 12 : index
+    // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
+    // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
+    %1:3 = mesh.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+    // CHECK: return [[vc3]], [[vc3]], [[vc4]] : index, index, index
+    return %1#0, %1#1, %1#2 : index, index, index
+  }
+
+  // In the second dimension the shard sizes are now [3, 4, 4, 4].
+  // CHECK-LABEL: func.func @shard_shape_odd_2() -> (index, index, index) {
+  func.func @shard_shape_odd_2() -> (index, index, index) {
+    %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+    %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+    %c9 = arith.constant 9 : index
+    // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+    %1:3 = mesh.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+    // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
+    return %1#0, %1#1, %1#2 : index, index, index
+  }
+
+  // In the first dimension the shard sizes are now [3, 4, 4].
+  // CHECK-LABEL: func.func @shard_shape_odd_3() -> (index, index, index) {
+  func.func @shard_shape_odd_3() -> (index, index, index) {
+    %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+    %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+    // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
+    // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
+    %1:3 = mesh.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+    // CHECK: return [[vc4]], [[vc3]], [[vc3]] : index, index, index
+    return %1#0, %1#1, %1#2 : index, index, index
+  }
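+
+  // The multi-index follows from delinearizing the rank in mesh shape 3x4x5:
+  // 24 = 1*(4*5) + 0*5 + 4, i.e. device [1, 0, 4]. In the odd cases above, a
+  // dimension of size s split over n shards leaves the trailing s mod n shards
+  // with one extra element, e.g. 15 over 4 shards gives [3, 4, 4, 4].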
+
+  // Extract shard sizes from sharded_dims_offsets.
+  // CHECK-LABEL: func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
+  func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
+    %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]]
+        sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 12, 0, 3, 6, 9, 12, 15]: !mesh.sharding
+    %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+    %c9 = arith.constant 9 : index
+    %c12 = arith.constant 12 : index
+    // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+    // CHECK: [[vc2:%.*]] = arith.constant 2 : index
+    %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+    // CHECK: return [[vc3]], [[vc2]], [[vc3]] : index, index, index
+    return %1#0, %1#1, %1#2 : index, index, index
+  }
+}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 43a75bf3d8040..3d133f2255772 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -157,10 +157,12 @@ func.func @mesh_shard_shape() {
   %c3 = arith.constant 3 : index
   // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding
   %s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding
-  // CHECK-NEXT: mesh.shard_shape 8x? %[[S]] %[[C3]] : index, index
-  %shp:2 = mesh.shard_shape 8x? %s %c3 : index, index
-  // CHECK-NEXT: mesh.shard_shape 8x4 %[[S]] %[[C3]] : index, index
-  %shp1:2 = mesh.shard_shape 8x4 %s %c3 : index, index
+  // CHECK-NEXT: mesh.shard_shape dims = [8, %[[C3]]
+  // CHECK-SAME: ] sharding = %[[S]] device = [%[[C3]]
+  // CHECK-SAME: ] : index, index
+  %shp:2 = mesh.shard_shape dims = [8, %c3] sharding = %s device = [%c3] : index, index
+  // CHECK-NEXT: mesh.shard_shape dims = [8, 4] sharding = %[[S]] device = [3] : index, index
+  %shp1:2 = mesh.shard_shape dims = [8, 4] sharding = %s device = [3] : index, index
   return
 }
diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
index 5443eea83aa2d..01cf5972177f4 100644
--- a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
@@ -10,8 +10,9 @@ func.func @tensor_empty_static_sharded_dims_offsets() -> () {
   %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
   %sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
   // CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
-  // CHECK: %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
-  // CHECK: %[[V0:.*]]:2 = mesh.shard_shape 8x16 %[[sharding]] %[[proc_linear_idx]] : index, index
+  // CHECK: %[[proc_multi_idx:.*]] = mesh.process_multi_index on @mesh_1d_4 : index
+  // CHECK: %[[V0:.*]]:2 = mesh.shard_shape dims = [8, 16] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
+  // CHECK-SAME: ] : index, index
   // CHECK: tensor.empty(%[[V0]]#0) : tensor<?x16xf32>

   return
@@ -24,8 +25,10 @@ func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () {
   %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
   %sharded= mesh.shard %b to %sharding : tensor<8x?xf32>
   // CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
-  // CHECK: %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
-  // CHECK: %[[V0:.*]]:2 = mesh.shard_shape 8x? %[[sharding]] %[[proc_linear_idx]] : index, index
+  // CHECK: %[[proc_multi_idx:.*]] = mesh.process_multi_index on @mesh_1d_4 : index
+  // CHECK: %[[V0:.*]]:2 = mesh.shard_shape dims = [8, %[[A0]]
+  // CHECK-SAME: ] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
+  // CHECK-SAME: ] : index, index
  // CHECK: tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32>

   return