diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index b496ee0114910..5a864865adffc 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -905,6 +905,8 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> { shard/partition sizes depend on the rank. }]; let dependentDialects = [ + "affine::AffineDialect", + "arith::ArithDialect", "memref::MemRefDialect", "mpi::MPIDialect", "scf::SCFDialect", diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.h b/mlir/include/mlir/Dialect/MPI/IR/MPI.h index f06b911ce3fe3..2b6743cd008c6 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPI.h +++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.h @@ -12,6 +12,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" //===----------------------------------------------------------------------===// // MPIDialect diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.td b/mlir/include/mlir/Dialect/MPI/IR/MPI.td index f2837e71df060..0c62a1794e19e 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPI.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.td @@ -230,7 +230,7 @@ def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">; def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">; def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">; -def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [ +def MPI_ReductionOpEnum : I32EnumAttr<"MPI_ReductionOpEnum", "MPI operation class", [ MPI_OpNull, MPI_OpMax, MPI_OpMin, diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index d78aa92d201e7..935e0f785ef0c 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -11,6 +11,7 @@ include "mlir/Dialect/MPI/IR/MPI.td" include "mlir/Dialect/MPI/IR/MPITypes.td" +include "mlir/Interfaces/SideEffectInterfaces.td" class MPI_Op traits = []> : Op; @@ -41,7 +42,7 @@ def MPI_InitOp : MPI_Op<"init", []> { // CommWorldOp //===----------------------------------------------------------------------===// -def MPI_CommWorldOp : MPI_Op<"comm_world", []> { +def MPI_CommWorldOp : MPI_Op<"comm_world", [Pure]> { let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`"; let description = [{ This operation returns the predefined MPI_COMM_WORLD communicator. 
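Note on the rename above: MPI_OpClassEnum becomes MPI_ReductionOpEnum, which also renames the TableGen-generated C++ enum and attribute classes; the individual cases (MPI_SUM, MPI_MAX, ...) keep their names. A minimal sketch of the one-line update an out-of-tree user of the old name would need (hypothetical caller; it assumes the generated classes stay reachable through MPI.h, as the lowering code later in this patch assumes):

  #include "mlir/Dialect/MPI/IR/MPI.h"

  // Build the reduction attribute for an mpi.allreduce.
  // Was: mlir::mpi::MPI_OpClassEnumAttr::get(ctx, mlir::mpi::MPI_OpClassEnum::MPI_SUM)
  static mlir::mpi::MPI_ReductionOpEnumAttr getSumReduction(mlir::MLIRContext *ctx) {
    return mlir::mpi::MPI_ReductionOpEnumAttr::get(
        ctx, mlir::mpi::MPI_ReductionOpEnum::MPI_SUM);
  }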
@@ -56,7 +57,7 @@ def MPI_CommWorldOp : MPI_Op<"comm_world", []> { // CommRankOp //===----------------------------------------------------------------------===// -def MPI_CommRankOp : MPI_Op<"comm_rank", []> { +def MPI_CommRankOp : MPI_Op<"comm_rank", [Pure]> { let summary = "Get the current rank, equivalent to " "`MPI_Comm_rank(comm, &rank)`"; let description = [{ @@ -72,13 +73,14 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> { ); let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)"; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// // CommSizeOp //===----------------------------------------------------------------------===// -def MPI_CommSizeOp : MPI_Op<"comm_size", []> { +def MPI_CommSizeOp : MPI_Op<"comm_size", [Pure]> { let summary = "Get the size of the group associated to the communicator, " "equivalent to `MPI_Comm_size(comm, &size)`"; let description = [{ @@ -100,7 +102,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> { // CommSplitOp //===----------------------------------------------------------------------===// -def MPI_CommSplitOp : MPI_Op<"comm_split", []> { +def MPI_CommSplitOp : MPI_Op<"comm_split", [Pure]> { let summary = "Partition the group associated with the given communicator into " "disjoint subgroups"; let description = [{ @@ -281,7 +283,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> { let arguments = ( ins AnyMemRef : $sendbuf, AnyMemRef : $recvbuf, - MPI_OpClassEnum : $op, + MPI_ReductionOpEnum : $op, MPI_Comm : $comm ); diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h index 3878505f8f93f..c4d512b60bc51 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h @@ -212,6 +212,11 @@ void maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder); +/// Converts a vector of OpFoldResults (ints) into vector of Values of the +/// provided type. +SmallVector getMixedAsValues(OpBuilder b, const Location &loc, + llvm::ArrayRef statics, + ValueRange dynamics, Type type = Type()); } // namespace mesh } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td index f59c4c4c67517..ac05ee243d7be 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td @@ -584,11 +584,11 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [ ``` }]; let arguments = !con(commonArgs, (ins - AnyRankedTensor:$input, + AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input, DefaultValuedAttr:$reduction )); let results = (outs - AnyRankedTensor:$result + AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result ); let assemblyFormat = [{ $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)? 
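The getMixedAsValues helper declared in MeshOps.h above (and defined in MeshOps.cpp further down) takes the static extents, with ShapedType::kDynamic holes, plus the matching dynamic size values, and materializes all of them as constants/SSA values of the requested i64 or index type; despite the doc comment's mention of OpFoldResults, its inputs are this split static/dynamic form. A minimal usage sketch (hypothetical caller, not part of the patch):

  #include "mlir/Dialect/Mesh/IR/MeshOps.h"

  // Expand a ranked type's shape into index-typed SSA values: static extents
  // become constants, each dynamic extent consumes the next value from
  // dynamicSizes.
  static llvm::SmallVector<mlir::Value>
  shapeToValues(mlir::OpBuilder &b, mlir::Location loc, mlir::ShapedType ty,
                mlir::ValueRange dynamicSizes) {
    return mlir::mesh::getMixedAsValues(b, loc, ty.getShape(), dynamicSizes,
                                        b.getIndexType());
  }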
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h index c64da29ca6412..3f1041cb25103 100644 --- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h +++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h @@ -62,9 +62,9 @@ void populateAllReduceEndomorphismSimplificationPatterns( auto isEndomorphismOp = [reduction](Operation *op, std::optional referenceOp) { auto allReduceOp = llvm::dyn_cast(op); - if (!allReduceOp || - allReduceOp.getInput().getType().getElementType() != - allReduceOp.getResult().getType().getElementType() || + auto inType = cast(allReduceOp.getInput().getType()); + auto outType = cast(allReduceOp.getResult().getType()); + if (!allReduceOp || inType.getElementType() != outType.getElementType() || allReduceOp.getReduction() != reduction) { return false; } @@ -83,9 +83,9 @@ void populateAllReduceEndomorphismSimplificationPatterns( } auto refAllReduceOp = llvm::dyn_cast(referenceOp.value()); + auto refType = cast(refAllReduceOp.getResult().getType()); return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() && - allReduceOp.getInput().getType().getElementType() == - refAllReduceOp.getInput().getType().getElementType(); + inType.getElementType() == refType.getElementType(); }; auto isAlgebraicOp = [](Operation *op) { return static_cast(llvm::dyn_cast(op)); diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h index be82e2af399dc..f46c0db846088 100644 --- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h @@ -42,6 +42,11 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef axes, TypedValue createProcessLinearIndex(StringRef mesh, ArrayRef meshAxes, ImplicitLocOpBuilder &builder); +// Get process linear index from a multi-index along the given mesh axes . +TypedValue +createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex, + ArrayRef meshAxes, + ImplicitLocOpBuilder &builder); } // namespace mesh } // namespace mlir diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp index 5575b295ae20a..d4deff5b88070 100644 --- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp +++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp @@ -116,7 +116,7 @@ class MPIImplTraits { /// enum value. 
virtual Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter, - mpi::MPI_OpClassEnum opAttr) = 0; + mpi::MPI_ReductionOpEnum opAttr) = 0; }; //===----------------------------------------------------------------------===// @@ -199,49 +199,49 @@ class MPICHImplTraits : public MPIImplTraits { } Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter, - mpi::MPI_OpClassEnum opAttr) override { + mpi::MPI_ReductionOpEnum opAttr) override { int32_t op = MPI_NO_OP; switch (opAttr) { - case mpi::MPI_OpClassEnum::MPI_OP_NULL: + case mpi::MPI_ReductionOpEnum::MPI_OP_NULL: op = MPI_NO_OP; break; - case mpi::MPI_OpClassEnum::MPI_MAX: + case mpi::MPI_ReductionOpEnum::MPI_MAX: op = MPI_MAX; break; - case mpi::MPI_OpClassEnum::MPI_MIN: + case mpi::MPI_ReductionOpEnum::MPI_MIN: op = MPI_MIN; break; - case mpi::MPI_OpClassEnum::MPI_SUM: + case mpi::MPI_ReductionOpEnum::MPI_SUM: op = MPI_SUM; break; - case mpi::MPI_OpClassEnum::MPI_PROD: + case mpi::MPI_ReductionOpEnum::MPI_PROD: op = MPI_PROD; break; - case mpi::MPI_OpClassEnum::MPI_LAND: + case mpi::MPI_ReductionOpEnum::MPI_LAND: op = MPI_LAND; break; - case mpi::MPI_OpClassEnum::MPI_BAND: + case mpi::MPI_ReductionOpEnum::MPI_BAND: op = MPI_BAND; break; - case mpi::MPI_OpClassEnum::MPI_LOR: + case mpi::MPI_ReductionOpEnum::MPI_LOR: op = MPI_LOR; break; - case mpi::MPI_OpClassEnum::MPI_BOR: + case mpi::MPI_ReductionOpEnum::MPI_BOR: op = MPI_BOR; break; - case mpi::MPI_OpClassEnum::MPI_LXOR: + case mpi::MPI_ReductionOpEnum::MPI_LXOR: op = MPI_LXOR; break; - case mpi::MPI_OpClassEnum::MPI_BXOR: + case mpi::MPI_ReductionOpEnum::MPI_BXOR: op = MPI_BXOR; break; - case mpi::MPI_OpClassEnum::MPI_MINLOC: + case mpi::MPI_ReductionOpEnum::MPI_MINLOC: op = MPI_MINLOC; break; - case mpi::MPI_OpClassEnum::MPI_MAXLOC: + case mpi::MPI_ReductionOpEnum::MPI_MAXLOC: op = MPI_MAXLOC; break; - case mpi::MPI_OpClassEnum::MPI_REPLACE: + case mpi::MPI_ReductionOpEnum::MPI_REPLACE: op = MPI_REPLACE; break; } @@ -336,49 +336,49 @@ class OMPIImplTraits : public MPIImplTraits { } Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter, - mpi::MPI_OpClassEnum opAttr) override { + mpi::MPI_ReductionOpEnum opAttr) override { StringRef op; switch (opAttr) { - case mpi::MPI_OpClassEnum::MPI_OP_NULL: + case mpi::MPI_ReductionOpEnum::MPI_OP_NULL: op = "ompi_mpi_no_op"; break; - case mpi::MPI_OpClassEnum::MPI_MAX: + case mpi::MPI_ReductionOpEnum::MPI_MAX: op = "ompi_mpi_max"; break; - case mpi::MPI_OpClassEnum::MPI_MIN: + case mpi::MPI_ReductionOpEnum::MPI_MIN: op = "ompi_mpi_min"; break; - case mpi::MPI_OpClassEnum::MPI_SUM: + case mpi::MPI_ReductionOpEnum::MPI_SUM: op = "ompi_mpi_sum"; break; - case mpi::MPI_OpClassEnum::MPI_PROD: + case mpi::MPI_ReductionOpEnum::MPI_PROD: op = "ompi_mpi_prod"; break; - case mpi::MPI_OpClassEnum::MPI_LAND: + case mpi::MPI_ReductionOpEnum::MPI_LAND: op = "ompi_mpi_land"; break; - case mpi::MPI_OpClassEnum::MPI_BAND: + case mpi::MPI_ReductionOpEnum::MPI_BAND: op = "ompi_mpi_band"; break; - case mpi::MPI_OpClassEnum::MPI_LOR: + case mpi::MPI_ReductionOpEnum::MPI_LOR: op = "ompi_mpi_lor"; break; - case mpi::MPI_OpClassEnum::MPI_BOR: + case mpi::MPI_ReductionOpEnum::MPI_BOR: op = "ompi_mpi_bor"; break; - case mpi::MPI_OpClassEnum::MPI_LXOR: + case mpi::MPI_ReductionOpEnum::MPI_LXOR: op = "ompi_mpi_lxor"; break; - case mpi::MPI_OpClassEnum::MPI_BXOR: + case mpi::MPI_ReductionOpEnum::MPI_BXOR: op = "ompi_mpi_bxor"; break; - case mpi::MPI_OpClassEnum::MPI_MINLOC: + case mpi::MPI_ReductionOpEnum::MPI_MINLOC: op = 
"ompi_mpi_minloc"; break; - case mpi::MPI_OpClassEnum::MPI_MAXLOC: + case mpi::MPI_ReductionOpEnum::MPI_MAXLOC: op = "ompi_mpi_maxloc"; break; - case mpi::MPI_OpClassEnum::MPI_REPLACE: + case mpi::MPI_ReductionOpEnum::MPI_REPLACE: op = "ompi_mpi_replace"; break; } diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp index 823d4d644f586..aaf1d39d48438 100644 --- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp +++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp @@ -12,9 +12,9 @@ #include "mlir/Conversion/MeshToMPI/MeshToMPI.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -22,6 +22,8 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Mesh/IR/MeshDialect.h" #include "mlir/Dialect/Mesh/IR/MeshOps.h" +#include "mlir/Dialect/Mesh/Transforms/Simplifications.h" +#include "mlir/Dialect/Mesh/Transforms/Transforms.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" @@ -289,27 +291,15 @@ struct ConvertProcessMultiIndexOp class ConvertProcessLinearIndexOp : public OpConversionPattern { - int64_t worldRank; // rank in MPI_COMM_WORLD if available, else < 0 public: using OpConversionPattern::OpConversionPattern; - // Constructor accepting worldRank - ConvertProcessLinearIndexOp(const TypeConverter &typeConverter, - MLIRContext *context, int64_t worldRank = -1) - : OpConversionPattern(typeConverter, context), worldRank(worldRank) {} - LogicalResult matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - + // Create mpi::CommRankOp Location loc = op.getLoc(); - if (worldRank >= 0) { // if rank in MPI_COMM_WORLD is known -> use it - rewriter.replaceOpWithNewOp(op, worldRank); - return success(); - } - - // Otherwise call create mpi::CommRankOp auto ctx = op.getContext(); Value commWorld = rewriter.create(loc, mpi::CommType::get(ctx)); @@ -529,6 +519,124 @@ struct ConvertShardShapeOp : public OpConversionPattern { } }; +static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) { + auto ctx = kind.getContext(); + switch (kind.getValue()) { + case ReductionKind::Sum: + return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_SUM); + case ReductionKind::Product: + return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_PROD); + case ReductionKind::Min: + return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_MIN); + case ReductionKind::Max: + return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_MAX); + case ReductionKind::BitwiseAnd: + return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_BAND); + case ReductionKind::BitwiseOr: + return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_BOR); + case ReductionKind::BitwiseXor: + return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_BXOR); + default: + assert(false && "Unknown/unsupported reduction kind"); + } +} + +struct ConvertAllReduceOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AllReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter 
&rewriter) const override { + SymbolTableCollection symbolTableCollection; + auto mesh = adaptor.getMesh(); + mlir::mesh::MeshOp meshOp = getMesh(op, symbolTableCollection); + if (!meshOp) + return op->emitError() << "No mesh found for AllReduceOp"; + if (ShapedType::isDynamicShape(meshOp.getShape())) + return op->emitError() + << "Dynamic mesh shape not supported in AllReduceOp"; + + ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter); + Value input = adaptor.getInput(); + auto inputShape = cast(input.getType()).getShape(); + + // If the source is a tensor, convert it to a memref (buffer) first. + if (isa(input.getType())) { + auto memrefType = MemRefType::get( + inputShape, cast(input.getType()).getElementType()); + input = iBuilder.create(memrefType, input); + } + MemRefType inType = cast(input.getType()); + + // Get the actual shape to allocate the buffer. + SmallVector shape(inType.getRank()); + for (auto i = 0; i < inType.getRank(); ++i) { + auto s = inputShape[i]; + if (ShapedType::isDynamic(s)) + shape[i] = iBuilder.create(input, i).getResult(); + else + shape[i] = iBuilder.getIndexAttr(s); + } + + // Allocate buffer and copy input to buffer. + Value buffer = iBuilder.create( + shape, cast(op.getType()).getElementType()); + iBuilder.create(input, buffer); + + // Get an MPI_Comm_split for the AllReduce operation. + // The color is the linear index of the process in the mesh along the + // non-reduced axes. The key is the linear index of the process in the mesh + // along the reduced axes. + SmallVector indexResultTypes(meshOp.getShape().size(), + iBuilder.getIndexType()); + SmallVector myMultiIndex = + iBuilder.create(indexResultTypes, mesh) + .getResult(); + Value zero = iBuilder.create(0); + SmallVector multiKey(myMultiIndex.size(), zero); + + auto redAxes = adaptor.getMeshAxes(); + for (auto axis : redAxes) { + multiKey[axis] = myMultiIndex[axis]; + myMultiIndex[axis] = zero; + } + + Value color = + createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder); + color = iBuilder.create(iBuilder.getI32Type(), color); + Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder); + key = iBuilder.create(iBuilder.getI32Type(), key); + + // Finally split the communicator + auto commType = mpi::CommType::get(op->getContext()); + Value commWorld = iBuilder.create(commType); + auto comm = + iBuilder.create(commType, commWorld, color, key) + .getNewcomm(); + + Value buffer1d = buffer; + // Collapse shape to 1d if needed + if (inType.getRank() > 1) { + ReassociationIndices reassociation(inType.getRank()); + std::iota(reassociation.begin(), reassociation.end(), 0); + buffer1d = iBuilder.create( + buffer, ArrayRef(reassociation)); + } + + // Create the MPI AllReduce operation.
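+    // Note: sendbuf and recvbuf are the same allocation, so the reduction
+    // happens in place on the copy made above; the rank-1 collapse only
+    // changes the view of that allocation, not the underlying data.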
+ iBuilder.create( + TypeRange(), buffer1d, buffer1d, + getMPIReductionOp(adaptor.getReductionAttr()), comm); + + // If the result type is a tensor, convert the buffer back to a tensor + if (isa(op.getType())) + buffer = iBuilder.create(buffer, true); + + rewriter.replaceOp(op, buffer); + return success(); + } +}; + struct ConvertUpdateHaloOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -573,10 +681,10 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { Value array = dest; if (isa(array.getType())) { // If the destination is a memref, we need to cast it to a tensor - auto tensorType = MemRefType::get( + auto memrefType = MemRefType::get( dstShape, cast(array.getType()).getElementType()); array = - rewriter.create(loc, tensorType, array); + rewriter.create(loc, memrefType, array); } auto rank = cast(array.getType()).getRank(); auto opSplitAxes = adaptor.getSplitAxes().getAxes(); @@ -753,22 +861,6 @@ struct ConvertMeshToMPIPass /// Run the dialect converter on the module. void runOnOperation() override { - uint64_t worldRank = -1; - // Try to get DLTI attribute for MPI:comm_world_rank - // If found, set worldRank to the value of the attribute. - { - auto dltiAttr = - dlti::query(getOperation(), {"MPI:comm_world_rank"}, false); - if (succeeded(dltiAttr)) { - if (!isa(dltiAttr.value())) { - getOperation()->emitError() - << "Expected an integer attribute for MPI:comm_world_rank"; - return signalPassFailure(); - } - worldRank = cast(dltiAttr.value()).getInt(); - } - } - auto *ctxt = &getContext(); RewritePatternSet patterns(ctxt); ConversionTarget target(getContext()); @@ -819,10 +911,10 @@ struct ConvertMeshToMPIPass // ...except the global MeshOp target.addLegalOp(); // Allow all the stuff that our patterns will convert to - target.addLegalDialect(); + target.addLegalDialect< + BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect, + tensor::TensorDialect, bufferization::BufferizationDialect, + linalg::LinalgDialect, memref::MemRefDialect, affine::AffineDialect>(); // Make sure the function signature, calls etc.
are legal target.addDynamicallyLegalOp([&](func::FuncOp op) { return typeConverter.isSignatureLegal(op.getFunctionType()); @@ -832,9 +924,10 @@ struct ConvertMeshToMPIPass patterns.add(typeConverter, ctxt); - // ConvertProcessLinearIndexOp accepts an optional worldRank - patterns.add(typeConverter, ctxt, worldRank); + ConvertShardingOp, ConvertShardShapeOp, ConvertAllReduceOp, + ConvertProcessLinearIndexOp>(typeConverter, ctxt); + SymbolTableCollection symbolTableCollection; + mlir::mesh::populateFoldingPatterns(patterns, symbolTableCollection); populateFunctionOpInterfaceTypeConversionPattern( patterns, typeConverter); diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp index 56d8edfbcc025..f2b6f97617c60 100644 --- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp +++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/MPI/IR/MPI.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" @@ -41,6 +42,35 @@ struct FoldCast final : public mlir::OpRewritePattern { return mlir::success(); } }; + +struct FoldRank final : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::mpi::CommRankOp op, + mlir::PatternRewriter &b) const override { + auto comm = op.getComm(); + if (!comm.getDefiningOp()) + return mlir::failure(); + + // Try to get DLTI attribute for MPI:comm_world_rank + // If found, fold this comm_rank to a constant with that value. + auto dltiAttr = dlti::query(op, {"MPI:comm_world_rank"}, false); + if (failed(dltiAttr)) + return mlir::failure(); + if (!isa(dltiAttr.value())) { + return op->emitError() + << "Expected an integer attribute for MPI:comm_world_rank"; + } + Value res = b.create( + op.getLoc(), cast(dltiAttr.value()).getInt()); + if (Value retVal = op.getRetval()) + b.replaceOp(op, {retVal, res}); + else + b.replaceOp(op, res); + return mlir::success(); + } +}; + } // namespace void mlir::mpi::SendOp::getCanonicalizationPatterns( mlir::RewritePatternSet &results, mlir::MLIRContext *context) { @@ -63,6 +93,11 @@ void mlir::mpi::IRecvOp::getCanonicalizationPatterns( results.add>(context); } +void mlir::mpi::CommRankOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &results, mlir::MLIRContext *context) { + results.add(context); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index a2c2d1a7470cc..b8cc91da722f0 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -75,6 +75,29 @@ static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) { return lhs.value() * rhs.value(); } +SmallVector mlir::mesh::getMixedAsValues(OpBuilder b, + const Location &loc, + llvm::ArrayRef statics, + ValueRange dynamics, + Type type) { + SmallVector values; + auto dyn = dynamics.begin(); + Type i64 = b.getI64Type(); + if (!type) + type = i64; + assert((i64 == type || b.getIndexType() == type) && + "expected an i64 or an index type"); + for (auto s : statics) { + if (s == ShapedType::kDynamic) { + values.emplace_back(*(dyn++)); + } else { + TypedAttr val = type == i64 ?
b.getI64IntegerAttr(s) : b.getIndexAttr(s); + values.emplace_back(b.create(loc, type, val)); + } + } + return values; +} + //===----------------------------------------------------------------------===// // Inliner //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp index 447668cc0ea50..f08ef75d8a004 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp @@ -207,17 +207,27 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef axes, builder.getIndexType())); } -TypedValue createProcessLinearIndex(StringRef mesh, - ArrayRef meshAxes, - ImplicitLocOpBuilder &builder) { - ResultRange processInGroupMultiIndex = - builder.create(mesh, meshAxes).getResults(); +TypedValue +createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex, + ArrayRef meshAxes, + ImplicitLocOpBuilder &builder) { Operation::result_range processGroupShape = builder.create(mesh, meshAxes).getResult(); OpFoldResult processInGroupLinearIndex = affine::linearizeIndex( llvm::to_vector_of(processInGroupMultiIndex), llvm::to_vector_of(processGroupShape), builder); - return cast>(cast(processInGroupLinearIndex)); + auto res = dyn_cast(processInGroupLinearIndex); + if (!res) + res = builder.create( + cast(cast(processInGroupLinearIndex)).getInt()); + return cast>(res); } +TypedValue createProcessLinearIndex(StringRef mesh, + ArrayRef meshAxes, + ImplicitLocOpBuilder &builder) { + return createProcessLinearIndex( + mesh, builder.create(mesh, meshAxes).getResults(), + meshAxes, builder); +} } // namespace mlir::mesh diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir index d314ad3ac30fd..d54d0034da5be 100644 --- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir +++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir @@ -80,6 +80,63 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { } } +// ----- +module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } { + mesh.mesh @mesh0(shape = 3x4x5) + // CHECK-LABEL: func.func @allreduce_tensor( + func.func @allreduce_tensor( + // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32> + %arg0 : tensor<3x4xf32>) -> tensor<3x4xf32> { + // CHECK-DAG: [[vc4_i32:%.*]] = arith.constant 4 : i32 + // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32 + // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32> + // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf32> + // CHECK: linalg.copy ins([[v0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf32>) + // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm + // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm + // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32> + // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32> + // CHECK: [[v2:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<3x4xf32> to tensor<3x4xf32> + %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : tensor<3x4xf32> -> tensor<3x4xf32> + // CHECK: return [[v2]] : tensor<3x4xf32> + return %0 : tensor<3x4xf32> + } + + // CHECK-LABEL: func.func @allreduce_memref( + func.func @allreduce_memref( + // CHECK-SAME: 
[[varg0:%.*]]: memref<3x4xf32> + %arg0 : memref<3x4xf32>) -> memref<3x4xf32> { + // CHECK: [[vc4_i32:%.*]] = arith.constant 4 : i32 + // CHECK: [[vc2_i32:%.*]] = arith.constant 2 : i32 + // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf32> + // CHECK: linalg.copy ins([[varg0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf32>) + // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm + // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm + // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32> + // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32> + %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf32> + // CHECK: return [[valloc]] : memref<3x4xf32> + return %0 : memref<3x4xf32> + } + + // CHECK-LABEL: func.func @allreduce_new_type( + func.func @allreduce_new_type( + // CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32> + %arg0 : memref<3x4xf32>) -> memref<3x4xf64> { + // CHECK: [[vc4_i32:%.*]] = arith.constant 4 : i32 + // CHECK: [[vc2_i32:%.*]] = arith.constant 2 : i32 + // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf64> + // CHECK: linalg.copy ins([[varg0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf64>) + // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm + // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm + // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf64> into memref<12xf64> + // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf64>, memref<12xf64> + %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf64> + // CHECK: return [[valloc]] : memref<3x4xf64> + return %0 : memref<3x4xf64> + } +} + // ----- mesh.mesh @mesh0(shape = 3x4x5) // CHECK-LABEL: func @update_halo_1d_first @@ -91,13 +148,13 @@ func.func @update_halo_1d_first( // CHECK-SAME: : memref<2x120x120xi8>, i32, i32 // CHECK: mpi.recv( // CHECK-SAME: : memref<2x120x120xi8>, i32, i32 - // CHECK-NEXT: memref.subview [[arg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8 + // CHECK: memref.subview [[arg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8 // CHECK: memref.subview [[arg0]][2, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8 // CHECK: mpi.send( // CHECK-SAME: : memref<3x120x120xi8>, i32, i32 // CHECK: mpi.recv( // CHECK-SAME: : memref<3x120x120xi8>, i32, i32 - // CHECK-NEXT: memref.subview [[arg0]][117, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8 + // CHECK: memref.subview [[arg0]][117, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8 %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 3] : memref<120x120x120xi8> // CHECK: return [[res:%.*]] : memref<120x120x120xi8> return %res : memref<120x120x120xi8> @@ -110,18 +167,18 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } { func.func @update_halo_1d_with_zero ( // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8> %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> { - // CHECK: [[vc91_i32:%.*]] = arith.constant 91 : i32 - // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32 - // CHECK-NEXT: [[vc2_i32:%.*]] = arith.constant 2 : i32 - 
// CHECK-NEXT: [[v0:%.*]] = mpi.comm_world : !mpi.comm - // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x120x120xi8> - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>> - // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>> to memref<2x120x120xi8> - // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]], [[v0]]) : memref<2x120x120xi8>, i32, i32 - // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]], [[v0]]) : memref<2x120x120xi8>, i32, i32 - // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>> - // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>> - // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x120x120xi8> + // CHECK-DAG: [[vc91_i32:%.*]] = arith.constant 91 : i32 + // CHECK-DAG: [[vc0_i32:%.*]] = arith.constant 0 : i32 + // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32 + // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm + // CHECK: [[valloc:%.*]] = memref.alloc() : memref<2x120x120xi8> + // CHECK: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>> + // CHECK: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>> to memref<2x120x120xi8> + // CHECK: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]], [[v0]]) : memref<2x120x120xi8>, i32, i32 + // CHECK: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]], [[v0]]) : memref<2x120x120xi8>, i32, i32 + // CHECK: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>> + // CHECK: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>> + // CHECK: memref.dealloc [[valloc]] : memref<2x120x120xi8> %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8> // CHECK: return [[varg0]] : memref<120x120x120xi8> return %res : memref<120x120x120xi8> @@ -135,50 +192,50 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { func.func @update_halo_3d( // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8> %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> { - // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32 - // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32 - // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32 - // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32 - // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32 - // CHECK-NEXT: [[v0:%.*]] = mpi.comm_world : !mpi.comm - // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8> - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> - // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8> - // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]], [[v0]]) : memref<117x113x5xi8>, i32, i32 - // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], 
[[vc4_i32]], [[v0]]) : memref<117x113x5xi8>, i32, i32 - // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> - // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> - // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8> - // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8> - // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> - // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8> - // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]], [[v0]]) : memref<117x113x6xi8>, i32, i32 - // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]], [[v0]]) : memref<117x113x6xi8>, i32, i32 - // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> - // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> - // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8> - // CHECK-NEXT: [[v1:%.*]] = mpi.comm_world : !mpi.comm - // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8> - // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[varg0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> - // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8> - // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]], [[v1]]) : memref<117x3x120xi8>, i32, i32 - // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8> - // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8> - // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]], [[v1]]) : memref<117x4x120xi8>, i32, i32 - // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>> - // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>> - // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8> - // CHECK-NEXT: [[v2:%.*]] = mpi.comm_world : !mpi.comm - // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8> - // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<1x120x120xi8>, i32, i32 - // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[varg0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>> - // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>> - // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8> - // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8> - // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[varg0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 
1], offset: 14400>> - // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8> - // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<2x120x120xi8>, i32, i32 - // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8> + // CHECK-DAG: [[vc23_i32:%.*]] = arith.constant 23 : i32 + // CHECK-DAG: [[vc29_i32:%.*]] = arith.constant 29 : i32 + // CHECK-DAG: [[vc91_i32:%.*]] = arith.constant 91 : i32 + // CHECK-DAG: [[vc4_i32:%.*]] = arith.constant 4 : i32 + // CHECK-DAG: [[vc44_i32:%.*]] = arith.constant 44 : i32 + // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm + // CHECK: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8> + // CHECK: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> + // CHECK: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8> + // CHECK: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]], [[v0]]) : memref<117x113x5xi8>, i32, i32 + // CHECK: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]], [[v0]]) : memref<117x113x5xi8>, i32, i32 + // CHECK: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> + // CHECK: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> + // CHECK: memref.dealloc [[valloc]] : memref<117x113x5xi8> + // CHECK: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8> + // CHECK: [[vsubview_2:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> + // CHECK: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8> + // CHECK: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]], [[v0]]) : memref<117x113x6xi8>, i32, i32 + // CHECK: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]], [[v0]]) : memref<117x113x6xi8>, i32, i32 + // CHECK: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> + // CHECK: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> + // CHECK: memref.dealloc [[valloc_1]] : memref<117x113x6xi8> + // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm + // CHECK: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8> + // CHECK: [[vsubview_5:%.*]] = memref.subview [[varg0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> + // CHECK: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8> + // CHECK: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]], [[v1]]) : memref<117x3x120xi8>, i32, i32 + // CHECK: memref.dealloc [[valloc_4]] : memref<117x3x120xi8> + // CHECK: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8> + // CHECK: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]], [[v1]]) : memref<117x4x120xi8>, i32, i32 + // CHECK: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to 
memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>> + // CHECK: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>> + // CHECK: memref.dealloc [[valloc_6]] : memref<117x4x120xi8> + // CHECK: [[v2:%.*]] = mpi.comm_world : !mpi.comm + // CHECK: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8> + // CHECK: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<1x120x120xi8>, i32, i32 + // CHECK: [[vsubview_9:%.*]] = memref.subview [[varg0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>> + // CHECK: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>> + // CHECK: memref.dealloc [[valloc_8]] : memref<1x120x120xi8> + // CHECK: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8> + // CHECK: [[vsubview_11:%.*]] = memref.subview [[varg0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> + // CHECK: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8> + // CHECK: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<2x120x120xi8>, i32, i32 + // CHECK: memref.dealloc [[valloc_10]] : memref<2x120x120xi8> %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8> // CHECK: return [[varg0]] : memref<120x120x120xi8> return %res : memref<120x120x120xi8> @@ -188,54 +245,54 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { func.func @update_halo_3d_tensor( // CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8> %arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> { - // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32 - // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32 - // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32 - // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32 - // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32 - // CHECK-NEXT: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8> - // CHECK-NEXT: [[v1:%.*]] = mpi.comm_world : !mpi.comm - // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8> - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> - // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8> - // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]], [[v1]]) : memref<117x113x5xi8>, i32, i32 - // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]], [[v1]]) : memref<117x113x5xi8>, i32, i32 - // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> - // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> - // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8> - // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8> - // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to 
memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> - // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8> - // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]], [[v1]]) : memref<117x113x6xi8>, i32, i32 - // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]], [[v1]]) : memref<117x113x6xi8>, i32, i32 - // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> - // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> - // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8> - // CHECK-NEXT: [[v2:%.*]] = mpi.comm_world : !mpi.comm - // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8> - // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[v0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> - // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8> - // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]], [[v2]]) : memref<117x3x120xi8>, i32, i32 - // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8> - // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8> - // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]], [[v2]]) : memref<117x4x120xi8>, i32, i32 - // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>> - // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>> - // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8> - // CHECK-NEXT: [[v3:%.*]] = mpi.comm_world : !mpi.comm - // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8> - // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<1x120x120xi8>, i32, i32 - // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[v0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>> - // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>> - // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8> - // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8> - // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[v0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> - // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8> - // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<2x120x120xi8>, i32, i32 - // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8> - // CHECK-NEXT: [[v4:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8> + // CHECK-DAG: [[vc23_i32:%.*]] = arith.constant 23 : i32 + // CHECK-DAG: [[vc29_i32:%.*]] = arith.constant 29 : i32 + // CHECK-DAG: [[vc44_i32:%.*]] = arith.constant 44 : i32 + // 
CHECK-DAG: [[vc4_i32:%.*]] = arith.constant 4 : i32 + // CHECK-DAG: [[vc91_i32:%.*]] = arith.constant 91 : i32 + // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8> + // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm + // CHECK: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8> + // CHECK: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> + // CHECK: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8> + // CHECK: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]], [[v1]]) : memref<117x113x5xi8>, i32, i32 + // CHECK: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]], [[v1]]) : memref<117x113x5xi8>, i32, i32 + // CHECK: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> + // CHECK: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> + // CHECK: memref.dealloc [[valloc]] : memref<117x113x5xi8> + // CHECK: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8> + // CHECK: [[vsubview_2:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> + // CHECK: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8> + // CHECK: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]], [[v1]]) : memref<117x113x6xi8>, i32, i32 + // CHECK: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]], [[v1]]) : memref<117x113x6xi8>, i32, i32 + // CHECK: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> + // CHECK: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> + // CHECK: memref.dealloc [[valloc_1]] : memref<117x113x6xi8> + // CHECK: [[v2:%.*]] = mpi.comm_world : !mpi.comm + // CHECK: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8> + // CHECK: [[vsubview_5:%.*]] = memref.subview [[v0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> + // CHECK: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8> + // CHECK: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]], [[v2]]) : memref<117x3x120xi8>, i32, i32 + // CHECK: memref.dealloc [[valloc_4]] : memref<117x3x120xi8> + // CHECK: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8> + // CHECK: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]], [[v2]]) : memref<117x4x120xi8>, i32, i32 + // CHECK: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>> + // CHECK: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>> + // CHECK: memref.dealloc [[valloc_6]] : memref<117x4x120xi8> + // CHECK: [[v3:%.*]] = mpi.comm_world : !mpi.comm + // CHECK: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8> + // CHECK: mpi.recv([[valloc_8]], [[vc91_i32]], 
[[vc23_i32]], [[v3]]) : memref<1x120x120xi8>, i32, i32 + // CHECK: [[vsubview_9:%.*]] = memref.subview [[v0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>> + // CHECK: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>> + // CHECK: memref.dealloc [[valloc_8]] : memref<1x120x120xi8> + // CHECK: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8> + // CHECK: [[vsubview_11:%.*]] = memref.subview [[v0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> + // CHECK: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8> + // CHECK: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<2x120x120xi8>, i32, i32 + // CHECK: memref.dealloc [[valloc_10]] : memref<2x120x120xi8> + // CHECK: [[v4:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8> %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8> - // CHECK-NEXT: return [[v4]] : tensor<120x120x120xi8> + // CHECK: return [[v4]] : tensor<120x120x120xi8> return %res : tensor<120x120x120xi8> } } @@ -246,19 +303,19 @@ mesh.mesh @mesh0(shape = 2x2x4) // CHECK-SAME: [[varg0:%.*]]: tensor<2x4xf32>) -> (tensor<2x4xf32>, tensor, tensor, tensor) { func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sharding) { %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] : !mesh.sharding - // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16> - // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16> - // CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16 - // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16> - // CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16> - // CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_0]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16> - // CHECK-NEXT: [[vinserted_slice_1:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16> - // CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64> - // CHECK-NEXT: [[v3:%.*]] = tensor.empty() : tensor<0x0xi64> - // CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_1]] : tensor<2x2xi16> to tensor - // CHECK-NEXT: [[vcast_2:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor - // CHECK-NEXT: [[vcast_3:%.*]] = tensor.cast [[v3]] : tensor<0x0xi64> to tensor - // CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_2]], [[vcast_3]] : tensor<2x4xf32>, tensor, tensor, tensor + // CHECK: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16> + // CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16> + // CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16 + // CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16> + // CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16> + // CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_0]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16> + // CHECK: [[vinserted_slice_1:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16> + // CHECK: [[v2:%.*]] = 
tensor.empty() : tensor<0x0xi64> + // CHECK: [[v3:%.*]] = tensor.empty() : tensor<0x0xi64> + // CHECK: [[vcast:%.*]] = tensor.cast [[vinserted_slice_1]] : tensor<2x2xi16> to tensor + // CHECK: [[vcast_2:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor + // CHECK: [[vcast_3:%.*]] = tensor.cast [[v3]] : tensor<0x0xi64> to tensor + // CHECK: return [[varg0]], [[vcast]], [[vcast_2]], [[vcast_3]] : tensor<2x4xf32>, tensor, tensor, tensor return %arg0, %sharding : tensor<2x4xf32>, !mesh.sharding } @@ -266,19 +323,19 @@ func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sh // CHECK-SAME: [[varg0:%.*]]: tensor<6x8xf32>) -> (tensor<6x8xf32>, tensor, tensor, tensor) { func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !mesh.sharding) { %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !mesh.sharding - // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64> - // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16> - // CHECK-NEXT: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16> - // CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16 - // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16> - // CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16> - // CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_1]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16> - // CHECK-NEXT: [[vinserted_slice_2:%.*]] = tensor.insert_slice [[vcst_0]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16> - // CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64> - // CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_2]] : tensor<2x2xi16> to tensor - // CHECK-NEXT: [[vcast_3:%.*]] = tensor.cast [[vcst]] : tensor<2x2xi64> to tensor - // CHECK-NEXT: [[vcast_4:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor - // CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_3]], [[vcast_4]] : tensor<6x8xf32>, tensor, tensor, tensor + // CHECK: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64> + // CHECK: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16> + // CHECK: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16> + // CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16 + // CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16> + // CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16> + // CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_1]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16> + // CHECK: [[vinserted_slice_2:%.*]] = tensor.insert_slice [[vcst_0]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16> + // CHECK: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64> + // CHECK: [[vcast:%.*]] = tensor.cast [[vinserted_slice_2]] : tensor<2x2xi16> to tensor + // CHECK: [[vcast_3:%.*]] = tensor.cast [[vcst]] : tensor<2x2xi64> to tensor + // CHECK: [[vcast_4:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor + // CHECK: return [[varg0]], [[vcast]], [[vcast_3]], [[vcast_4]] : tensor<6x8xf32>, tensor, tensor, tensor return %arg0, %sharding : tensor<6x8xf32>, !mesh.sharding } @@ -286,24 +343,24 @@ func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !m // CHECK-SAME: [[varg0:%.*]]: tensor) -> (tensor, tensor, tensor, tensor) { func.func 
@return_sharding_offs(%arg0: tensor) -> (tensor, !mesh.sharding) { %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !mesh.sharding - // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64> - // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64> - // CHECK-NEXT: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64 - // CHECK-NEXT: [[vcst_1:%.*]] = arith.constant dense<2> : tensor<1xi16> - // CHECK-NEXT: [[vcst_2:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16> - // CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16 - // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16> - // CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16> - // CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_2]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16> - // CHECK-NEXT: [[vinserted_slice_3:%.*]] = tensor.insert_slice [[vcst_1]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16> - // CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64> - // CHECK-NEXT: [[v3:%.*]] = tensor.empty() : tensor<2x5xi64> - // CHECK-NEXT: [[v4:%.*]] = linalg.fill ins([[vcm9223372036854775808_i64]] : i64) outs([[v3]] : tensor<2x5xi64>) -> tensor<2x5xi64> - // CHECK-NEXT: [[vinserted_slice_4:%.*]] = tensor.insert_slice [[vcst_0]] into [[v4]][0, 0] [1, 5] [1, 1] : tensor<5xi64> into tensor<2x5xi64> - // CHECK-NEXT: [[vinserted_slice_5:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice_4]][1, 0] [1, 5] [1, 1] : tensor<5xi64> into tensor<2x5xi64> - // CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_3]] : tensor<2x2xi16> to tensor - // CHECK-NEXT: [[vcast_6:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor - // CHECK-NEXT: [[vcast_7:%.*]] = tensor.cast [[vinserted_slice_5]] : tensor<2x5xi64> to tensor - // CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_6]], [[vcast_7]] : tensor, tensor, tensor, tensor + // CHECK: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64> + // CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64> + // CHECK: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64 + // CHECK: [[vcst_1:%.*]] = arith.constant dense<2> : tensor<1xi16> + // CHECK: [[vcst_2:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16> + // CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16 + // CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16> + // CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16> + // CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_2]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16> + // CHECK: [[vinserted_slice_3:%.*]] = tensor.insert_slice [[vcst_1]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16> + // CHECK: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64> + // CHECK: [[v3:%.*]] = tensor.empty() : tensor<2x5xi64> + // CHECK: [[v4:%.*]] = linalg.fill ins([[vcm9223372036854775808_i64]] : i64) outs([[v3]] : tensor<2x5xi64>) -> tensor<2x5xi64> + // CHECK: [[vinserted_slice_4:%.*]] = tensor.insert_slice [[vcst_0]] into [[v4]][0, 0] [1, 5] [1, 1] : tensor<5xi64> into tensor<2x5xi64> + // CHECK: [[vinserted_slice_5:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice_4]][1, 0] [1, 5] [1, 1] : tensor<5xi64> into 
tensor<2x5xi64> + // CHECK: [[vcast:%.*]] = tensor.cast [[vinserted_slice_3]] : tensor<2x2xi16> to tensor + // CHECK: [[vcast_6:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor + // CHECK: [[vcast_7:%.*]] = tensor.cast [[vinserted_slice_5]] : tensor<2x5xi64> to tensor + // CHECK: return [[varg0]], [[vcast]], [[vcast_6]], [[vcast_7]] : tensor, tensor, tensor, tensor return %arg0, %sharding : tensor, !mesh.sharding }
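Worked example for the comm_split color/key computed by ConvertAllReduceOp, matching the allreduce_* tests above: with MPI:comm_world_rank = 7 on the 3x4x5 mesh and reduction over mesh_axes [0, 1], row-major delinearization gives the process multi-index (0, 1, 2), since 7 = 0*20 + 1*5 + 2. Zeroing the reduced axes leaves (0, 0, 2), whose linearization is the color 2, so processes that agree on the non-reduced coordinate end up in the same communicator; the coordinates along the reduced axes, (0, 1, 0), yield the key (4 in the CHECK lines), which only fixes the ordering of ranks within that communicator. These are exactly the constants passed to mpi.comm_split in the allreduce_tensor and allreduce_memref tests.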