diff --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Mesh.md deleted file mode 100644 index 5eb6569c7044b..0000000000000 --- a/mlir/docs/Dialects/Mesh.md +++ /dev/null @@ -1,74 +0,0 @@ -# 'mesh' Dialect - -The `mesh` dialect contains a set of attributes, operations and interfaces that -are useful for representing sharding and communication on a device mesh -cluster. - -[TOC] - -## Collective Communication Operations -There are a number of operations in the Mesh dialect to facilitate -communication between devices in a mesh. -It is assumed that the user is familiar with collective operations. -[Wikipedia](https://en.wikipedia.org/wiki/Collective_operation) has a good -explanation. -The main addition is that the collectives in this dialect have mesh -semantics. - -### Device groups -The operation attributes `mesh` and `mesh_axes` specifies a list of device mesh -axes that partition the devices into disjoint groups. -The collective operation is performed between devices in the same group. -Devices that have the same coordinates outside of axes `mesh_axes` are in the -same group. -A group is described by its multi-index along the axes outside of `mesh_axes`. -For example if we have a device mesh of size `2x3x4x5` and the partition mesh -axes list is `[0, 1]` then devices are partitioned into the groups -`{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`. -The device groups would be `{ (k, m) | 0<=k<4, 0<=m<5 }`. -Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group. -Device (1, 0, 2, 4) will be in another group. -Some collective operations like all-to-all and all-gather care about the -order of devices. -The order of device in a device group is induced by the order of axes in -`mesh_axes`. -The axes are ordered from outer to inner. -If we have an axis list `[3, 1]` then device `(i, 1, k, 0)` will precede -both devices `(i, 0, k, 1)` and `(i, 2, k, 0)`. 
- -### In-group Device -Some operations like `broadcast`, `scatter` and `send` specify devices in each -device-group. -These devices are represented with their multi-index over the mesh axes that -are not constant within a device group. -These are the axes specified by `mesh_axes` attribute. - -For Example on a 3D mesh an operation with `mesh_axes = [0, 2]` would specify -an in-group device with `(i, j)`. Then for each group with index `g` on the -second axis, the in-group device would be `(i, g, j)`. -### Purity -Collectives that involve the whole device group to perform a single operation -are pure. The exceptions are `send` and `recv`. - -There is an assumption that the execution is SPMD. -Not only that each process runs the same program, but that at the point of -execution of a collective operation, all processes are in a coherent state. -All compiler transformations must be consistent. -Collective operations in the IR that may correspond to the same runtime -collective operation must be transformed in a consistent manner. -For example if a collective operation is optimized out, than it must also -not appear in any path of execution on any process. - -Having the operations as `Pure` implies that if an interpreter is to execute -the IR containing the `mesh` collectives, all processes would execute the same -line when they reach a pure collective operation. -This requirement stems from the need to be compatible with general optimization -passes like dead code and common sub-expression elimination. - -## Operations - -[include "Dialects/MeshOps.md"] - -## Attributes - -[include "Dialects/MeshAttrs.md"] diff --git a/mlir/docs/Dialects/Shard.md b/mlir/docs/Dialects/Shard.md new file mode 100644 index 0000000000000..eb6ff6150e474 --- /dev/null +++ b/mlir/docs/Dialects/Shard.md @@ -0,0 +1,92 @@ +# 'shard' Dialect + +The 'shard' dialect defines a set of attributes, operations, and interfaces for +working with tensor sharding and device communication. 
+ +It’s inspired by GSPMD (*General and Scalable Parallelization for ML Computation Graphs*). + +Originally, the dialect was called `mesh`, but it was renamed to better reflect +what it actually does. + +[TOC] + +## Collective Communication Operations + +The 'shard' dialect includes several collective operations that help coordinate +communication between devices arranged in a grid. + +If you’re not already familiar with collective operations, [this Wikipedia +article](https://en.wikipedia.org/wiki/Collective_operation) is a good starting +point. + +Unlike traditional collectives that are defined in terms of message-passing +between explicit buffers on each process, the collectives in this dialect work +at a higher level. They’re defined in terms of how data moves across the +dimensions of a tensor, and the participating processes are inferred from how +the tensor is sharded - not specified manually. + +### Device Groups + +Each collective operation runs within a group of devices. You define groups +using the `grid` and `grid_axes` attributes, which describe how to slice the +full device grid into smaller groups. + +Devices that have the same coordinates *outside* the listed `grid_axes` belong +to the same group. + +Example: Say your device grid is shaped `2×3×4×5`, and you set +`grid_axes = [0, 1]`. This splits the grid into groups by fixing axes 2 and 3. You’d get groups like: + +``` +{ { (i, j, k, m) | 0 ≤ i < 2, 0 ≤ j < 3 } | 0 ≤ k < 4, 0 ≤ m < 5 } +``` + +So the groups are identified by the coordinates `(k, m)`, and devices like +`(1, 0, 2, 3)` and `(1, 1, 2, 3)` are in the same group. But `(1, 0, 2, 4)` +is in a different group. + +For some collectives (like `all-to-all`), the order of devices in the group +matters. The device order is based on the order of axes in `grid_axes`, from +outermost to innermost. + +Example: If `grid_axes = [3, 1]`, then device `(i, 1, k, 0)` comes before +`(i, 0, k, 1)` and `(i, 2, k, 0)`. 
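The grouping and ordering rules above can be sketched in plain Python. This is an illustrative helper (`device_groups` is a hypothetical name, not part of MLIR or this patch) that partitions a device grid the way the 'shard' dialect describes: devices agreeing on all coordinates outside `grid_axes` form a group, and the in-group order walks `grid_axes` from outermost to innermost.

```python
from itertools import product

def device_groups(grid_shape, grid_axes):
    """Partition all device coordinates of a grid with shape `grid_shape`
    into collective groups. A group is keyed by its multi-index along the
    axes *not* in `grid_axes`; members are ordered by iterating `grid_axes`
    outermost-first."""
    group_axes = [a for a in range(len(grid_shape)) if a not in grid_axes]
    groups = {}
    for key in product(*(range(grid_shape[a]) for a in group_axes)):
        members = []
        # Enumerate in-group coordinates along grid_axes, outermost first.
        for in_group in product(*(range(grid_shape[a]) for a in grid_axes)):
            coord = [0] * len(grid_shape)
            for a, v in zip(group_axes, key):
                coord[a] = v
            for a, v in zip(grid_axes, in_group):
                coord[a] = v
            members.append(tuple(coord))
        groups[key] = members
    return groups

# The 2×3×4×5 example with grid_axes = [0, 1]: groups are keyed by (k, m).
groups = device_groups([2, 3, 4, 5], grid_axes=[0, 1])
print(groups[(2, 3)][:3])  # → [(0, 0, 2, 3), (0, 1, 2, 3), (0, 2, 2, 3)]
```

Running this reproduces the claims in the text: `(1, 0, 2, 3)` and `(1, 1, 2, 3)` land in the group keyed `(2, 3)`, `(1, 0, 2, 4)` in the group keyed `(2, 4)`, and with `grid_axes = [3, 1]` device `(i, 1, k, 0)` precedes both `(i, 0, k, 1)` and `(i, 2, k, 0)`.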
+ +### In-group Devices + +Some operations (like `broadcast`, `scatter`, and `send`) refer to a specific +device within each group. These in-group devices are identified using their +coordinates over the axes listed in `grid_axes`. + +Example: In a 3D grid with `grid_axes = [0, 2]`, an in-group device is specified +as `(i, j)`. If a group is fixed at coordinate `g` on axis 1, then the full +device index would be `(i, g, j)`. + +### Purity and Execution Model + +Collective operations involve all devices in a group (e.g. `all-gather`, +`all-to-all`) and are considered pure. Operations like `send` and `recv` are not +collective and are not pure. + +The execution model assumes SPMD (Single Program, Multiple Data): + +* Every process runs the same program. +* At any collective operation, all processes are in sync. + +This means compiler optimizations must treat collective ops carefully. For +example, if a collective is removed during optimization, it must be removed from +*every* path and *every* process that would have participated - otherwise, you’ll +get undefined behavior at runtime. + +Marking these ops as pure also helps with standard compiler passes like dead +code elimination and common subexpression elimination. It ensures that when the +program is executed, all processes reach the same collective operations in the +same order, and so avoid deadlocks. + +## Operations + +[include "Dialects/ShardOps.md"] + +## Attributes + +[include "Dialects/ShardAttrs.md"] diff --git a/mlir/docs/Passes.md b/mlir/docs/Passes.md index e9d22d1e3dfac..9df32666415bb 100644 --- a/mlir/docs/Passes.md +++ b/mlir/docs/Passes.md @@ -72,9 +72,9 @@ This document describes the available MLIR passes and their contracts. 
[include "MemRefPasses.md"] -## 'mesh' Dialect Passes +## 'shard' Dialect Passes -[include "MeshPasses.md"] +[include "ShardPasses.md"] ## 'ml\_program' Dialect Passes diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index d93fbefab74aa..3dc48b2201cf2 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -52,7 +52,6 @@ #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h" -#include "mlir/Conversion/MeshToMPI/MeshToMPI.h" #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h" #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h" @@ -66,6 +65,7 @@ #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h" #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h" #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" +#include "mlir/Conversion/ShardToMPI/ShardToMPI.h" #include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h" #include "mlir/Conversion/TosaToArith/TosaToArith.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 8183f355795a9..eb18160ea2eeb 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -903,13 +903,13 @@ def ConvertMemRefToSPIRVPass : Pass<"convert-memref-to-spirv"> { } //===----------------------------------------------------------------------===// -// MeshToMPI +// ShardToMPI //===----------------------------------------------------------------------===// -def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> { - let summary = "Convert Mesh dialect to MPI dialect."; +def ConvertShardToMPIPass : Pass<"convert-shard-to-mpi"> { + let summary = "Convert Shard dialect to MPI dialect."; let description = [{ - This pass converts 
communication operations from the Mesh dialect to the + This pass converts communication operations from the Shard dialect to the MPI dialect. If it finds the DLTI attribute "MPI:comm_world-rank" on the module it will use that integer value instead of calling MPI_Comm_rank. This allows diff --git a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h b/mlir/include/mlir/Conversion/ShardToMPI/ShardToMPI.h similarity index 64% rename from mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h rename to mlir/include/mlir/Conversion/ShardToMPI/ShardToMPI.h index bc64e7a3c1c8c..b1aa08c432249 100644 --- a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h +++ b/mlir/include/mlir/Conversion/ShardToMPI/ShardToMPI.h @@ -1,4 +1,4 @@ -//===- MeshToMPI.h - Convert Mesh to MPI dialect ----------------*- C++ -*-===// +//===- ShardToMPI.h - Convert Shard to MPI dialect --------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H -#define MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H +#ifndef MLIR_CONVERSION_SHARDTOMPI_SHARDTOMPI_H +#define MLIR_CONVERSION_SHARDTOMPI_SHARDTOMPI_H #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" @@ -15,9 +15,9 @@ namespace mlir { class Pass; -#define GEN_PASS_DECL_CONVERTMESHTOMPIPASS +#define GEN_PASS_DECL_CONVERTSHARDTOMPIPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir -#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H +#endif // MLIR_CONVERSION_SHARDTOMPI_SHARDTOMPI_H diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt index 56dc97282fa4a..e27b1679c2a52 100644 --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -19,7 +19,7 @@ add_subdirectory(Linalg) add_subdirectory(LLVMIR) add_subdirectory(Math) add_subdirectory(MemRef) -add_subdirectory(Mesh) +add_subdirectory(Shard) add_subdirectory(MLProgram) add_subdirectory(MPI) add_subdirectory(NVGPU) diff --git a/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h b/mlir/include/mlir/Dialect/Func/Extensions/ShardingExtensions.h similarity index 88% rename from mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h rename to mlir/include/mlir/Dialect/Func/Extensions/ShardingExtensions.h index 30d3033209d21..e22b24b3446bb 100644 --- a/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h +++ b/mlir/include/mlir/Dialect/Func/Extensions/ShardingExtensions.h @@ -1,4 +1,4 @@ -//===- MeshShardingExtensions.h - -----------------------------------------===// +//===- ShardingExtensions.h - -----------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h similarity index 54% rename from mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h rename to mlir/include/mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h index c57501ea86b7e..dc21bc05a2dc1 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h @@ -1,4 +1,4 @@ -//===- MeshShardingInterfaceImpl.h ----------------------------------------===// +//===- ShardingInterfaceImpl.h ----------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,15 +6,15 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H -#define MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H +#ifndef MLIR_DIALECT_LINALG_SHARDSHARDINGINTERFACEIMPL_H +#define MLIR_DIALECT_LINALG_SHARDSHARDINGINTERFACEIMPL_H namespace mlir { class DialectRegistry; namespace linalg { -void registerMeshShardingInterfaceExternalModels(DialectRegistry ®istry); +void registerShardingInterfaceExternalModels(DialectRegistry ®istry); } // namespace linalg } // namespace mlir -#endif // MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H +#endif // MLIR_DIALECT_LINALG_SHARDSHARDINGINTERFACEIMPL_H diff --git a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt deleted file mode 100644 index f26c6285efd89..0000000000000 --- a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt +++ /dev/null @@ -1,25 +0,0 @@ -add_mlir_doc(MeshOps MeshOps Dialects/ -gen-op-doc -dialect=mesh) -add_mlir_doc(MeshOps MeshAttrs Dialects/ -gen-attrdef-doc -dialect=mesh) - -set(LLVM_TARGET_DEFINITIONS MeshOps.td) 
-mlir_tablegen(MeshDialect.cpp.inc -gen-dialect-defs -dialect=mesh) -mlir_tablegen(MeshDialect.h.inc -gen-dialect-decls -dialect=mesh) - -set(LLVM_TARGET_DEFINITIONS MeshBase.td) -mlir_tablegen(MeshAttributes.h.inc -gen-attrdef-decls) -mlir_tablegen(MeshAttributes.cpp.inc -gen-attrdef-defs) - -set(LLVM_TARGET_DEFINITIONS MeshBase.td) -mlir_tablegen(MeshEnums.h.inc -gen-enum-decls) -mlir_tablegen(MeshEnums.cpp.inc -gen-enum-defs) - -set(LLVM_TARGET_DEFINITIONS MeshBase.td) -mlir_tablegen(MeshTypes.h.inc -gen-typedef-decls) -mlir_tablegen(MeshTypes.cpp.inc -gen-typedef-defs) - -set(LLVM_TARGET_DEFINITIONS MeshOps.td) -mlir_tablegen(MeshOps.h.inc -gen-op-decls) -mlir_tablegen(MeshOps.cpp.inc -gen-op-defs) - -add_public_tablegen_target(MLIRMeshIncGen) -add_dependencies(mlir-headers MLIRMeshIncGen) diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt deleted file mode 100644 index 8d768485103b6..0000000000000 --- a/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt +++ /dev/null @@ -1,6 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls -name Mesh) -add_public_tablegen_target(MLIRMeshPassIncGen) -add_dependencies(mlir-headers MLIRMeshPassIncGen) - -add_mlir_doc(Passes MeshPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt b/mlir/include/mlir/Dialect/Shard/CMakeLists.txt similarity index 100% rename from mlir/include/mlir/Dialect/Mesh/CMakeLists.txt rename to mlir/include/mlir/Dialect/Shard/CMakeLists.txt diff --git a/mlir/include/mlir/Dialect/Shard/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Shard/IR/CMakeLists.txt new file mode 100644 index 0000000000000..a2495af135899 --- /dev/null +++ b/mlir/include/mlir/Dialect/Shard/IR/CMakeLists.txt @@ -0,0 +1,25 @@ +add_mlir_doc(ShardOps ShardOps Dialects/ -gen-op-doc -dialect=shard) +add_mlir_doc(ShardOps ShardAttrs Dialects/ -gen-attrdef-doc -dialect=shard) + 
+set(LLVM_TARGET_DEFINITIONS ShardOps.td) +mlir_tablegen(ShardDialect.cpp.inc -gen-dialect-defs -dialect=shard) +mlir_tablegen(ShardDialect.h.inc -gen-dialect-decls -dialect=shard) + +set(LLVM_TARGET_DEFINITIONS ShardBase.td) +mlir_tablegen(ShardAttributes.h.inc -gen-attrdef-decls) +mlir_tablegen(ShardAttributes.cpp.inc -gen-attrdef-defs) + +set(LLVM_TARGET_DEFINITIONS ShardBase.td) +mlir_tablegen(ShardEnums.h.inc -gen-enum-decls) +mlir_tablegen(ShardEnums.cpp.inc -gen-enum-defs) + +set(LLVM_TARGET_DEFINITIONS ShardBase.td) +mlir_tablegen(ShardTypes.h.inc -gen-typedef-decls) +mlir_tablegen(ShardTypes.cpp.inc -gen-typedef-defs) + +set(LLVM_TARGET_DEFINITIONS ShardOps.td) +mlir_tablegen(ShardOps.h.inc -gen-op-decls) +mlir_tablegen(ShardOps.cpp.inc -gen-op-defs) + +add_public_tablegen_target(MLIRShardIncGen) +add_dependencies(mlir-headers MLIRShardIncGen) diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Shard/IR/ShardBase.td similarity index 64% rename from mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td rename to mlir/include/mlir/Dialect/Shard/IR/ShardBase.td index 61403ac178980..41ae31807c825 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td +++ b/mlir/include/mlir/Dialect/Shard/IR/ShardBase.td @@ -1,4 +1,4 @@ -//===- MeshBase.td - Mesh Dialect --------------------------*- tablegen -*-===// +//===- ShardBase.td - Shard Dialect ------------------------*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_MESH_IR_MESHBASE_TD -#define MLIR_DIALECT_MESH_IR_MESHBASE_TD +#ifndef MLIR_DIALECT_SHARD_IR_SHARDBASE_TD +#define MLIR_DIALECT_SHARD_IR_SHARDBASE_TD include "mlir/IR/OpBase.td" include "mlir/IR/AttrTypeBase.td" @@ -16,15 +16,15 @@ include "mlir/IR/CommonAttrConstraints.td" include "mlir/IR/EnumAttr.td" //===----------------------------------------------------------------------===// -// Mesh Dialect +// Shard Dialect //===----------------------------------------------------------------------===// -def Mesh_Dialect : Dialect { - let name = "mesh"; - let cppNamespace = "::mlir::mesh"; +def Shard_Dialect : Dialect { + let name = "shard"; + let cppNamespace = "::mlir::shard"; let description = [{ - See [Mesh dialect documentation](mlir/docs/Dialects/Mesh.md). + See [Shard dialect documentation](mlir/docs/Dialects/Shard.md). }]; let dependentDialects = [ @@ -36,16 +36,16 @@ def Mesh_Dialect : Dialect { let hasConstantMaterializer = 1; } -def Mesh_MeshAxis : I<16>; -def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">; -def Mesh_ShardShapeAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">; +def Shard_GridAxis : I<16>; +def Shard_GridAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">; +def Shard_ShardShapeAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">; //===----------------------------------------------------------------------===// -// Mesh Enums. +// Shard Enums. 
//===----------------------------------------------------------------------===// -def Mesh_ReductionKind : I32EnumAttr<"ReductionKind", - "Reduction of an iterator/mesh dimension.", [ +def Shard_ReductionKind : I32EnumAttr<"ReductionKind", + "Reduction of an iterator/grid dimension.", [ I32EnumAttrCase<"Sum", 1, "sum">, I32EnumAttrCase<"Max", 2, "max">, I32EnumAttrCase<"Min", 3, "min">, @@ -58,31 +58,31 @@ def Mesh_ReductionKind : I32EnumAttr<"ReductionKind", I32EnumAttrCase<"Generic", 100, "generic"> ]> { let genSpecializedAttr = 0; - let cppNamespace = "::mlir::mesh"; + let cppNamespace = "::mlir::shard"; } -def Mesh_ReductionKindAttr : EnumAttr { +def Shard_ReductionKindAttr : EnumAttr { let assemblyFormat = "$value"; } -class Mesh_Type traits = [], +class Shard_Type traits = [], string baseCppClass = "::mlir::Type"> - : TypeDef { + : TypeDef { let mnemonic = typeMnemonic; } -def Mesh_Sharding : Mesh_Type<"Sharding", "sharding"> { +def Shard_Sharding : Shard_Type<"Sharding", "sharding"> { let summary = "sharding definition"; let assemblyFormat = ""; } //===----------------------------------------------------------------------===// -// Mesh Attribute +// Shard Attribute //===----------------------------------------------------------------------===// -def Mesh_MeshAxesArrayAttr : AttrDef { +def Shard_GridAxesArrayAttr : AttrDef { let mnemonic = "axisarray"; - let parameters = (ins ArrayRefParameter<"MeshAxesAttr">:$axes); + let parameters = (ins ArrayRefParameter<"GridAxesAttr">:$axes); let assemblyFormat = "`[` $axes `]`"; let extraClassDeclaration = [{ size_t size() const { return getAxes().size(); } @@ -91,4 +91,4 @@ def Mesh_MeshAxesArrayAttr : AttrDef { }]; } -#endif // MLIR_DIALECT_MESH_IR_MESHBASE_TD +#endif // MLIR_DIALECT_SHARD_IR_SHARDBASE_TD diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h b/mlir/include/mlir/Dialect/Shard/IR/ShardDialect.h similarity index 57% rename from mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h rename to 
mlir/include/mlir/Dialect/Shard/IR/ShardDialect.h index a30cf91e851fe..4113a668d4b76 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h +++ b/mlir/include/mlir/Dialect/Shard/IR/ShardDialect.h @@ -1,4 +1,4 @@ -//===- MeshOps.h - Mesh Dialect ---------------------------------*- C++ -*-===// +//===- ShardOps.h - Shard Dialect -------------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,11 +6,11 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_MESH_IR_MESHDIALECT_H -#define MLIR_DIALECT_MESH_IR_MESHDIALECT_H +#ifndef MLIR_DIALECT_SHARD_IR_SHARDDIALECT_H +#define MLIR_DIALECT_SHARD_IR_SHARDDIALECT_H #include "mlir/IR/Dialect.h" -#include "mlir/Dialect/Mesh/IR/MeshDialect.h.inc" +#include "mlir/Dialect/Shard/IR/ShardDialect.h.inc" -#endif // MLIR_DIALECT_MESH_IR_MESHDIALECT_H +#endif // MLIR_DIALECT_SHARD_IR_SHARDDIALECT_H diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.h similarity index 52% rename from mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h rename to mlir/include/mlir/Dialect/Shard/IR/ShardOps.h index 7cfe59dd957ca..457fe6f6b8d0a 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h +++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.h @@ -1,4 +1,4 @@ -//===- MeshOps.h - Mesh Dialect Operations ----------------------*- C++ -*-===// +//===- ShardOps.h - Shard Dialect Operations --------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_MESH_IR_MESHOPS_H -#define MLIR_DIALECT_MESH_IR_MESHOPS_H +#ifndef MLIR_DIALECT_SHARD_IR_SHARDOPS_H +#define MLIR_DIALECT_SHARD_IR_SHARDOPS_H #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" @@ -21,45 +21,45 @@ #include "llvm/Support/MathExtras.h" namespace mlir { -namespace mesh { +namespace shard { -using MeshAxis = int16_t; -using MeshAxesAttr = DenseI16ArrayAttr; +using GridAxis = int16_t; +using GridAxesAttr = DenseI16ArrayAttr; using ShardShapeAttr = DenseI64ArrayAttr; using HaloSizePairAttr = DenseI64ArrayAttr; -} // namespace mesh +} // namespace shard } // namespace mlir -#include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc" +#include "mlir/Dialect/Shard/IR/ShardEnums.h.inc" #define GET_ATTRDEF_CLASSES -#include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc" +#include "mlir/Dialect/Shard/IR/ShardAttributes.h.inc" namespace mlir { -namespace mesh { +namespace shard { -class MeshSharding { +class Sharding { private: - ::mlir::FlatSymbolRefAttr mesh; - SmallVector split_axes; + ::mlir::FlatSymbolRefAttr grid; + SmallVector split_axes; SmallVector static_halo_sizes; SmallVector static_sharded_dims_offsets; SmallVector dynamic_halo_sizes; SmallVector dynamic_sharded_dims_offsets; public: - MeshSharding(::mlir::FlatSymbolRefAttr mesh_ = nullptr); - MeshSharding(Value rhs); - static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, - ArrayRef split_axes_, - ArrayRef static_halo_sizes_ = {}, - ArrayRef static_sharded_dims_offsets_ = {}, - ArrayRef dynamic_halo_sizes_ = {}, - ArrayRef dynamic_sharded_dims_offsets_ = {}); - ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; } - ::llvm::StringRef getMesh() const { return mesh ? 
mesh.getValue() : ""; } - ArrayRef getSplitAxes() const { return split_axes; } + Sharding(::mlir::FlatSymbolRefAttr grid_ = nullptr); + Sharding(Value rhs); + static Sharding get(::mlir::FlatSymbolRefAttr grid_, + ArrayRef split_axes_, + ArrayRef static_halo_sizes_ = {}, + ArrayRef static_sharded_dims_offsets_ = {}, + ArrayRef dynamic_halo_sizes_ = {}, + ArrayRef dynamic_sharded_dims_offsets_ = {}); + ::mlir::FlatSymbolRefAttr getGridAttr() const { return grid; } + ::llvm::StringRef getGrid() const { return grid ? grid.getValue() : ""; } + ArrayRef getSplitAxes() const { return split_axes; } ArrayRef getStaticHaloSizes() const { return static_halo_sizes; } ArrayRef getStaticShardedDimsOffsets() const { return static_sharded_dims_offsets; @@ -68,28 +68,28 @@ class MeshSharding { ArrayRef getDynamicShardedDimsOffsets() const { return dynamic_sharded_dims_offsets; } - operator bool() const { return (!mesh) == false; } + operator bool() const { return (!grid) == false; } bool operator==(Value rhs) const; bool operator!=(Value rhs) const; - bool operator==(const MeshSharding &rhs) const; - bool operator!=(const MeshSharding &rhs) const; - bool equalSplitAxes(const MeshSharding &rhs) const; - bool equalHaloAndShardSizes(const MeshSharding &rhs) const; - bool equalHaloSizes(const MeshSharding &rhs) const; - bool equalShardSizes(const MeshSharding &rhs) const; + bool operator==(const Sharding &rhs) const; + bool operator!=(const Sharding &rhs) const; + bool equalSplitAxes(const Sharding &rhs) const; + bool equalHaloAndShardSizes(const Sharding &rhs) const; + bool equalHaloSizes(const Sharding &rhs) const; + bool equalShardSizes(const Sharding &rhs) const; }; -} // namespace mesh +} // namespace shard } // namespace mlir #define GET_TYPEDEF_CLASSES -#include "mlir/Dialect/Mesh/IR/MeshTypes.h.inc" +#include "mlir/Dialect/Shard/IR/ShardTypes.h.inc" #define GET_OP_CLASSES -#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc" +#include "mlir/Dialect/Shard/IR/ShardOps.h.inc" namespace 
mlir { -namespace mesh { +namespace shard { inline bool isReductionLoop(utils::IteratorType iType) { return iType == utils::IteratorType::reduction; @@ -103,52 +103,52 @@ void removeTrailingEmptySubArray(SmallVector> &array) { } // Is the same tensor replicated on all processes. -inline bool isFullReplication(MeshSharding sharding) { - return llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) { +inline bool isFullReplication(Sharding sharding) { + return llvm::all_of(sharding.getSplitAxes(), [](GridAxesAttr axes) { return axes.asArrayRef().empty(); }); } -inline mesh::MeshOp -getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, +inline shard::GridOp +getGridOrNull(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection) { - if (!meshSymbol) + if (!gridSymbol) return nullptr; - return symbolTableCollection.lookupNearestSymbolFrom( - op, meshSymbol); + return symbolTableCollection.lookupNearestSymbolFrom( + op, gridSymbol); } -inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, - SymbolTableCollection &symbolTableCollection) { - mesh::MeshOp meshOp = getMeshOrNull(op, meshSymbol, symbolTableCollection); - assert(meshOp); - return meshOp; +inline shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, + SymbolTableCollection &symbolTableCollection) { + shard::GridOp gridOp = getGridOrNull(op, gridSymbol, symbolTableCollection); + assert(gridOp); + return gridOp; } -// Get the corresponding mesh op using the standard attribute nomenclature. +// Get the corresponding grid op using the standard attribute nomenclature. 
template -mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) { - return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection); +shard::GridOp getGrid(Op op, SymbolTableCollection &symbolTableCollection) { + return getGrid(op.getOperation(), op.getGridAttr(), symbolTableCollection); } template <> -inline mesh::MeshOp -getMesh(ShardOp op, SymbolTableCollection &symbolTableCollection) { - return getMesh( +inline shard::GridOp +getGrid(ShardOp op, SymbolTableCollection &symbolTableCollection) { + return getGrid( op.getOperation(), - cast(op.getSharding().getDefiningOp()).getMeshAttr(), + cast(op.getSharding().getDefiningOp()).getGridAttr(), symbolTableCollection); } // Get the number of processes that participate in each group -// induced by `meshAxes`. -template -int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, - MeshShapeRange &&meshShape) { +// induced by `gridAxes`. +template +int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, + GridShapeRange &&gridShape) { int64_t res = 1; - for (MeshAxis axis : meshAxes) { - auto axisSize = *(std::begin(meshShape) + axis); + for (GridAxis axis : gridAxes) { + auto axisSize = *(std::begin(gridShape) + axis); if (ShapedType::isDynamic(axisSize)) { return ShapedType::kDynamic; } @@ -158,10 +158,10 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, return res; } -template -int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) { - return collectiveProcessGroupSize(std::forward(meshAxes), - mesh.getShape()); +template +int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridOp grid) { + return collectiveProcessGroupSize(std::forward(gridAxes), + grid.getShape()); } // Get the size of a sharded dimension. @@ -182,27 +182,25 @@ inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) { } // Return the sharded shape `shape` according ot sharding `sharding`. -// The shape for the tensor on each device in the mesh. 
+// The shape for the tensor on each device in the grid. // Example: -// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would +// On a 2x4x? grid with split axes = [[0], [1], [2]] the shape ?x5x1 would // result in a shape for each shard of ?x2x?. -ShapedType shardShapedType(ShapedType shape, MeshOp mesh, - MeshSharding sharding); +ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding); // If ranked tensor type return its sharded counterpart. // // If not ranked tensor type return `type`. // `sharding` in that case must be null. -Type shardType(Type type, MeshOp mesh, MeshSharding sharding); +Type shardType(Type type, GridOp grid, Sharding sharding); // Insert shard op if there is not one that already has the same sharding. // Use newShardOp if it is not null. Otherwise create a new one. // May insert resharding if required. // Potentially updates newShardOp. -void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result, +void maybeInsertTargetShardingAnnotation(Sharding sharding, OpResult result, OpBuilder &builder); -void maybeInsertSourceShardingAnnotation(MeshSharding sharding, - OpOperand &operand, +void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand, OpBuilder &builder); /// Converts a vector of OpFoldResults (ints) into vector of Values of the @@ -210,7 +208,7 @@ void maybeInsertSourceShardingAnnotation(MeshSharding sharding, SmallVector getMixedAsValues(OpBuilder b, const Location &loc, llvm::ArrayRef statics, ValueRange dynamics, Type type = Type()); -} // namespace mesh +} // namespace shard } // namespace mlir -#endif // MLIR_DIALECT_MESH_IR_MESHOPS_H +#endif // MLIR_DIALECT_SHARD_IR_SHARDOPS_H diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td similarity index 70% rename from mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td rename to mlir/include/mlir/Dialect/Shard/IR/ShardOps.td index 1662885c161e6..29b384f401876 
100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td +++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td @@ -1,4 +1,4 @@ -//===-- MeshOps.td - Mesh dialect operation definitions ----*- tablegen -*-===// +//===-- ShardOps.td - Shard dialect operation definitions ----*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,10 +6,10 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_MESH_IR_MESHOPS_TD -#define MLIR_DIALECT_MESH_IR_MESHOPS_TD +#ifndef MLIR_DIALECT_SHARD_IR_SHARDOPS_TD +#define MLIR_DIALECT_SHARD_IR_SHARDOPS_TD -include "mlir/Dialect/Mesh/IR/MeshBase.td" +include "mlir/Dialect/Shard/IR/ShardBase.td" include "mlir/Dialect/Shape/IR/ShapeBase.td" include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" @@ -21,24 +21,24 @@ include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" //===----------------------------------------------------------------------===// -// Mesh operations. +// Shard operations. //===----------------------------------------------------------------------===// -class Mesh_Op traits = []> : - Op { +class Shard_Op traits = []> : + Op { } -def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> { - let summary = "Description of a device/process mesh."; +def Shard_GridOp : Shard_Op<"grid", [Symbol, Pure]> { + let summary = "Description of a device/process grid."; let description = [{ - The mesh.mesh operation is a symbol operation that identifies a specific - mesh. The operation has three attributes: + The shard.grid operation is a symbol operation that identifies a specific + grid. The operation has three attributes: - 1. `sym_name`: This attribute uniquely identifies the name of the mesh. - This name serves as a symbolic reference to the mesh throughout + 1. 
`sym_name`: This attribute uniquely identifies the name of the grid. + This name serves as a symbolic reference to the grid throughout the MLIR module, allowing for consistent referencing and easier debugging. - 2. `shape`: This attribute represents the shape of the device mesh. + 2. `shape`: This attribute represents the shape of the device grid. It uses the same notation as a tensor shape. Also allowing for dynamic dimensions. This flexibility allows for dynamic device assignment or configurations @@ -48,21 +48,21 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> { Example: ``` - // A device mesh with 3 axes, the total device number is 4 * 8 * 12 + // A device grid with 3 axes, the total device number is 4 * 8 * 12 // The dimension sizes are 4, 8, 12 - mesh.mesh @mesh0(shape = 4x8x12) + shard.grid @grid0(shape = 4x8x12) - // A device mesh with 2 axes, the total device number is unknown + // A device grid with 2 axes, the total device number is unknown // The first dimension size is 4 and the second is unknown - mesh.mesh @mesh1(shape = 4x?) + shard.grid @grid1(shape = 4x?) - // A device mesh with 2 axes, the total device number is unknown + // A device grid with 2 axes, the total device number is unknown // The first dimension size is unknown and the second is 4 - mesh.mesh @mesh2(shape = ?x4) + shard.grid @grid2(shape = ?x4) - // A device mesh with 2 axes, the number of devices along both axes + // A device grid with 2 axes, the number of devices along both axes // is unknown - mesh.mesh @mesh3(shape = ?x?) + shard.grid @grid3(shape = ?x?) 
``` }]; let arguments = (ins @@ -79,15 +79,15 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> { let hasVerifier = 1; } -def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [ +def Shard_GridShapeOp : Shard_Op<"grid_shape", [ Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ]> { - let summary = "Get the shape of the mesh."; + let summary = "Get the shape of the grid."; let arguments = (ins - FlatSymbolRefAttr:$mesh, - DefaultValuedAttr:$axes + FlatSymbolRefAttr:$grid, + DefaultValuedAttr:$axes ); let results = (outs @@ -95,46 +95,46 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [ ); let assemblyFormat = [{ - $mesh (`axes` `=` $axes^)? + $grid (`axes` `=` $axes^)? attr-dict `:` type($result) }]; let builders = [ - OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>, - OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh, "ArrayRef":$axes)>, - OpBuilder<(ins "StringRef":$mesh, "ArrayRef":$axes)> + OpBuilder<(ins "::mlir::shard::GridOp":$grid)>, + OpBuilder<(ins "::mlir::shard::GridOp":$grid, "ArrayRef":$axes)>, + OpBuilder<(ins "StringRef":$grid, "ArrayRef":$axes)> ]; } -def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [ +def Shard_ProcessMultiIndexOp : Shard_Op<"process_multi_index", [ Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ]> { - let summary = "Get the multi index of current device along specified mesh axes."; + let summary = "Get the multi-index of the current device along the specified grid axes."; let description = [{ It is used in the SPMD format of IR. - The `axes` mush be non-negative and less than the total number of mesh axes. + The `axes` must be non-negative and less than the total number of grid axes. If the axes are empty then get the index along all axes. }]; let arguments = (ins - FlatSymbolRefAttr:$mesh, - DefaultValuedAttr:$axes + FlatSymbolRefAttr:$grid, - DefaultValuedAttr:$axes + DefaultValuedAttr:$axes ); let results = (outs Variadic:$result ); let assemblyFormat = [{ - `on` $mesh (`axes` `=` $axes^)? + `on` $grid (`axes` `=` $axes^)?
attr-dict `:` type($result) }]; let builders = [ - OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>, - OpBuilder<(ins "StringRef":$mesh, "ArrayRef":$axes)> + OpBuilder<(ins "::mlir::shard::GridOp":$grid)>, + OpBuilder<(ins "StringRef":$grid, "ArrayRef":$axes)> ]; } -def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [ +def Shard_ProcessLinearIndexOp : Shard_Op<"process_linear_index", [ Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods @@ -143,34 +143,34 @@ def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [ let description = [{ Example: ``` - %idx = mesh.process_linear_index on @mesh : index + %idx = shard.process_linear_index on @grid : index ``` - if `@mesh` has shape `(10, 20, 30)`, a device with multi + if `@grid` has shape `(10, 20, 30)`, a device with multi index `(1, 2, 3)` will have linear index `3 + 30*2 + 20*30*1`. }]; - let arguments = (ins FlatSymbolRefAttr:$mesh); + let arguments = (ins FlatSymbolRefAttr:$grid); let results = (outs Index:$result); - let assemblyFormat = "`on` $mesh attr-dict `:` type($result)"; + let assemblyFormat = "`on` $grid attr-dict `:` type($result)"; let builders = [ - OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)> + OpBuilder<(ins "::mlir::shard::GridOp":$grid)> ]; } -def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [ +def Shard_NeighborsLinearIndicesOp : Shard_Op<"neighbors_linear_indices", [ Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ]> { let summary = - "For given mesh index get the linear indices of the direct neighbor processes along the given split."; + "For given grid index get the linear indices of the direct neighbor processes along the given split."; let description = [{ Example: ``` - mesh.mesh @mesh0(shape = 10x20x30) + shard.grid @grid0(shape = 10x20x30) %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %c3 = arith.constant 3 : index - %idx = mesh.neighbors_linear_indices on @mesh[%c1, %c2, %c3] split_axes = [1] : index + %idx = 
shard.neighbors_linear_indices on @grid[%c1, %c2, %c3] split_axes = [1] : index ``` The above returns two indices, `633` and `693`, which correspond to the index of the previous process `(1, 1, 3)`, and the next process @@ -179,12 +179,12 @@ def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [ A negative value is returned if there is no neighbor in the respective direction along the given `split_axes`. }]; - let arguments = (ins FlatSymbolRefAttr:$mesh, + let arguments = (ins FlatSymbolRefAttr:$grid, Variadic:$device, - Mesh_MeshAxesAttr:$split_axes); + Shard_GridAxesAttr:$split_axes); let results = (outs Index:$neighbor_down, Index:$neighbor_up); let assemblyFormat = [{ - `on` $mesh `[` $device `]` + `on` $grid `[` $device `]` `split_axes` `=` $split_axes attr-dict `:` type(results) }]; @@ -194,7 +194,7 @@ def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [ // Sharding operations. //===----------------------------------------------------------------------===// -def Mesh_ShardingOp : Mesh_Op<"sharding", [ +def Shard_ShardingOp : Shard_Op<"sharding", [ Pure, AttrSizedOperandSegments, DeclareOpInterfaceMethods, @@ -202,18 +202,18 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [ ]> { let summary = "Define a sharding of a tensor."; let description = [{ - The MeshSharding specifies how a tensor is sharded and distributed across the - process mesh. It is typically used in a `mesh.shard` operation. + The Sharding specifies how a tensor is sharded and distributed across the + process grid. It is typically used in a `shard.shard` operation. The operation has the following attributes and operands: 1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device - mesh where the distributed tensor is placed. The symbol must resolve to a - `mesh.mesh` operation. + 1. `grid`: this attribute is a FlatSymbolRefAttr that refers to the device + grid where the distributed tensor is placed.
The symbol must resolve to a + `shard.grid` operation. 2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's maximum size is the `rank` of the related tensor. For the i-th sub-array, if its value is [x, y], it indicates that the tensor's i-th dimension is splitted - along the x and y axes of the device mesh. + along the x and y axes of the device grid. 3. [Optional] Sizes of halos to be added for each sharded tensor dimension. `halo_sizes` is provided as a flattened 1d array of i64s, 2 values for each @@ -233,7 +233,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [ Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded, `sharded_dims_offsets` = [0, 24, 32, 0, 20, 32] means that the first device of - the device-mesh will get a shard of shape 24x20x32 and the second device will get + the device-grid will get a shard of shape 24x20x32 and the second device will get a shard of shape 8x12x32. `?` indicates dynamic shard dimensions. `halo_sizes` and `sharded_dims_offsets` are mutually exclusive. @@ -241,101 +241,101 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [ Examples: ``` - mesh.mesh @mesh0(shape = 2x2x4) - mesh.mesh @mesh1d_4(shape = 4) + shard.grid @grid0(shape = 2x2x4) + shard.grid @grid1d_4(shape = 4) - // The tensor is fully replicated on @mesh0. + // The tensor is fully replicated on @grid0. // Currently, there must be at least one sub-array present in axes, even // if it's empty. Otherwise, a parsing error will occur. 
- %sharding0 = mesh.sharding @mesh0 split_axes = [[]] + %sharding0 = shard.sharding @grid0 split_axes = [[]] - // The tensor is sharded on the first dimension along axis 0 of @mesh0 - %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] + // The tensor is sharded on the first dimension along axis 0 of @grid0 + %sharding1 = shard.sharding @grid0 split_axes = [[0]] - // Could be used for a mesh.shard op - %sharded0 = mesh.shard %arg0 to %sharding3 : tensor<4x8xf32> + // Could be used for a shard.shard op + %sharded0 = shard.shard %arg0 to %sharding3 : tensor<4x8xf32> - // The tensor is sharded on its first dimension along axis 0 of @mesh0 and + // The tensor is sharded on its first dimension along axis 0 of @grid0 and // and it has halo-sizes of 1 and 2 on the sharded dim. - %halo_sharding = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] - %sharded1 = mesh.shard %arg0 to %halo_sharding : tensor<4x8xf32> + %halo_sharding = shard.sharding @grid0 split_axes = [[0]] halo_sizes = [1, 2] + %sharded1 = shard.shard %arg0 to %halo_sharding : tensor<4x8xf32> - // The tensor is sharded on its second dimension along axis 0 of @mesh1d_4 + // The tensor is sharded on its second dimension along axis 0 of @grid1d_4 // and it has pre-defined shard sizes. 
The shards of the devices will have // the following shapes: [4x2, 4x3, 4x4, 4x5] - %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14] - %sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32> + %sharding4 = shard.sharding @grid1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14] + %sharded2 = shard.shard %arg0 to %sharding4 : tensor<4x14xf32> ``` }]; let arguments = (ins - FlatSymbolRefAttr:$mesh, - Mesh_MeshAxesArrayAttr:$split_axes, + FlatSymbolRefAttr:$grid, + Shard_GridAxesArrayAttr:$split_axes, DefaultValuedAttr:$static_sharded_dims_offsets, Variadic:$dynamic_sharded_dims_offsets, DefaultValuedAttr:$static_halo_sizes, Variadic:$dynamic_halo_sizes ); let results = (outs - Mesh_Sharding:$result + Shard_Sharding:$result ); let assemblyFormat = [{ - $mesh + $grid `split_axes` `=` $split_axes (`halo_sizes` `=` custom($dynamic_halo_sizes, $static_halo_sizes)^)? (`sharded_dims_offsets` `=` custom($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)? 
attr-dict `:` type($result) }]; let builders = [ - OpBuilder<(ins "FlatSymbolRefAttr":$mesh, - "ArrayRef":$split_axes, + OpBuilder<(ins "FlatSymbolRefAttr":$grid, + "ArrayRef":$split_axes, CArg<"ArrayRef", "{}">:$static_halo_sizes, CArg<"ArrayRef", "{}">:$static_sharded_dims_offsets)>, - OpBuilder<(ins "FlatSymbolRefAttr":$mesh, - "ArrayRef":$split_axes, + OpBuilder<(ins "FlatSymbolRefAttr":$grid, + "ArrayRef":$split_axes, "::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes, "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>, - OpBuilder<(ins "llvm::StringRef":$mesh, - "ArrayRef":$split_axes, + OpBuilder<(ins "llvm::StringRef":$grid, + "ArrayRef":$split_axes, CArg<"ArrayRef", "{}">:$static_halo_sizes, CArg<"ArrayRef", "{}">:$static_sharded_dims_offsets )>, - OpBuilder<(ins "mlir::mesh::MeshSharding":$from)> + OpBuilder<(ins "mlir::shard::Sharding":$from)> ]; let hasVerifier = 1; let hasCanonicalizer = 1; } -def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> { +def Shard_GetShardingOp : Shard_Op<"get_sharding", [Pure]> { let summary = "Get the sharding of the given tensor."; let description = [{ - This operation returns the sharding of the given tensor as a MeshSharding. + This operation returns the sharding of the given tensor as a Sharding. }]; let arguments = (ins AnyRankedTensor:$source ); let results = (outs - Mesh_Sharding:$result + Shard_Sharding:$result ); let assemblyFormat = [{ $source attr-dict `:` type($source) `->` type($result) }]; } -def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [ +def Shard_ShardShapeOp : Shard_Op<"shard_shape", [ Pure, AttrSizedOperandSegments, DeclareOpInterfaceMethods ]> { let summary = "Get the shard shape for a given process/device."; let description = [{ - The device/process id is a multi-index of the device/process in the mesh. - This operation might be used during spmdization when the shard shape depends - on (non-constant) values used in `mesh.sharding`. 
+ The device/process id is a multi-index of the device/process in the grid. + This operation might be used during partitioning when the shard shape depends + on (non-constant) values used in `shard.sharding`. }]; let arguments = (ins DenseI64ArrayAttr:$dims, Variadic:$dims_dynamic, - Mesh_Sharding:$sharding, + Shard_Sharding:$sharding, DenseI64ArrayAttr:$device, Variadic:$device_dynamic ); @@ -351,23 +351,23 @@ def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [ ]; } -def Mesh_ShardOp : Mesh_Op<"shard", [ +def Shard_ShardOp : Shard_Op<"shard", [ Pure, AllTypesMatch<["result", "src"]>, DeclareOpInterfaceMethods ]> { - let summary = "Annotate on how a tensor is sharded across a mesh."; + let summary = "Annotate on how a tensor is sharded across a grid."; let description = [{ - The mesh.shard operation is designed to specify and guide the sharding - behavior of a tensor value across a mesh topology. This operation has two + The shard.shard operation is designed to specify and guide the sharding + behavior of a tensor value across a grid topology. This operation has two operands and two optional attributes: 1. `input`: This operand represents the tensor value that needs to be annotated for sharding. - 2. `sharding`: This attribute is type of `MeshShardingType`, which is the core data - structure to represent distribution of a tensor on a mesh. it is typically defiend - by an `mesh.sharding` operation. + 2. `sharding`: This attribute is of type `ShardingType`, which is the core data + structure to represent distribution of a tensor on a grid. It is typically defined + by a `shard.sharding` operation. 3.
`annotate_for_users`: A unit attribute addressing the scenario when a tensor's sharding annotation differs based on its context of use (either as @@ -378,36 +378,36 @@ def Mesh_ShardOp : Mesh_Op<"shard", [ Example: ``` - func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () { - %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32> + func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () { + %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding + %0 = shard.shard %arg0 to %sharding : tensor<4x8xf32> ... } func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () { - %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32> + %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding + %0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32> ... } func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () { - %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32> - %1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32> + %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding + %0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32> + %1 = shard.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32> ... 
} - // The first mesh.shard op applies to %arg0, the second mesh.shard op - // applies for the operand of op0, the third mesh.shard op applies for the + // The first shard.shard op applies to %arg0, the second shard.shard op + // applies for the operand of op0, the third shard.shard op applies for the // operand of op2 func.func @both_result_and_multi_operands_annotated( %arg0 : tensor<4x8xf32>) -> () { - %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32> - %sharding1 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding - %1 = mesh.shard %0 to %sharding1 annotate_for_users : tensor<4x8xf32> - %sharding2 = mesh.sharding @mesh0 split_axes = [[2]] : !mesh.sharding - %2 = mesh.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32> + %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding + %0 = shard.shard %arg0 to %sharding : tensor<4x8xf32> + %sharding1 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding + %1 = shard.shard %0 to %sharding1 annotate_for_users : tensor<4x8xf32> + %sharding2 = shard.sharding @grid0 split_axes = [[2]] : !shard.sharding + %2 = shard.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32> "op0"(%1) : ... "op1"(%2) : ... ... @@ -418,44 +418,44 @@ def Mesh_ShardOp : Mesh_Op<"shard", [ ``` func.func @annotate_on_same_result_with_different_sharding( %arg0 : tensor<4x8xf32>) -> () { - %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding - %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding - %0 = mesh.shard %arg0 to $sharding1 : tensor<4x8xf32> - %1 = mesh.shard %0 to sharding2 : tensor<4x8xf32> + %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding + %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding + %0 = shard.shard %arg0 to $sharding1 : tensor<4x8xf32> + %1 = shard.shard %0 to sharding2 : tensor<4x8xf32> ... 
} func.func @annotate_on_same_result_same_value_with_different_sharding( %arg0 : tensor<4x8xf32>) -> () { - %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding - %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding - %0 = mesh.shard %arg0 to %sharding1 : tensor<4x8xf32> - %1 = mesh.shard %arg0 to %sharding2 : tensor<4x8xf32> + %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding + %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding + %0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32> + %1 = shard.shard %arg0 to %sharding2 : tensor<4x8xf32> ... } func.func @annotate_on_same_operand_with_different_sharding( %arg0 : tensor<4x8xf32>) -> () { - %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding - %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding - %0 = mesh.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32> - %1 = mesh.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32> + %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding + %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding + %0 = shard.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32> + %1 = shard.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32> ... } func.func @result_annotated_after_operand( %arg0 : tensor<4x8xf32>) -> () { - %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding - %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding - %0 = mesh.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32> - %1 = mesh.shard %0 to %sharding2 : tensor<4x8xf32> + %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding + %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding + %0 = shard.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32> + %1 = shard.shard %0 to %sharding2 : tensor<4x8xf32> ... 
} ``` }]; let arguments = (ins AnyRankedTensor:$src, - Mesh_Sharding:$sharding, + Shard_Sharding:$sharding, UnitAttr:$annotate_for_users ); let results = (outs @@ -473,34 +473,34 @@ def Mesh_ShardOp : Mesh_Op<"shard", [ // collective communication ops //===----------------------------------------------------------------------===// -class Mesh_CollectiveCommunicationOpBase< +class Shard_CollectiveCommunicationOpBase< string mnemonic, list traits = []> : - Mesh_Op, DeclareOpInterfaceMethods ])> { dag commonArgs = (ins - FlatSymbolRefAttr:$mesh, - DefaultValuedAttr:$mesh_axes + FlatSymbolRefAttr:$grid, + DefaultValuedAttr:$grid_axes ); } -def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [ +def Shard_AllGatherOp : Shard_CollectiveCommunicationOpBase<"all_gather", [ Pure, SameOperandsAndResultElementType, SameOperandsAndResultRank, ]> { - let summary = "All-gather over a device mesh."; + let summary = "All-gather over a device grid."; let description = [{ Gathers along the `gather_axis` tensor axis. Example: ```mlir - mesh.mesh @mesh0(shape = 2x2) + shard.grid @grid0(shape = 2x2) ... - %1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1 + %1 = shard.all_gather %0 on @grid0 grid_axes = [1] gather_axis = 1 : tensor<2x2xi8> -> tensor<2x4xi8> ``` Input: @@ -535,16 +535,16 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [ AnyNon0RankedTensor:$result ); let assemblyFormat = [{ - $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `gather_axis` `=` $gather_axis + $input `on` $grid (`grid_axes` `=` $grid_axes^)? 
`gather_axis` `=` $gather_axis attr-dict `:` type($input) `->` type($result) }]; let hasCanonicalizer = 1; } -def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [ +def Shard_AllReduceOp : Shard_CollectiveCommunicationOpBase<"all_reduce", [ Pure, SameOperandsAndResultShape]> { - let summary = "All-reduce over a device mesh."; + let summary = "All-reduce over a device grid."; let description = [{ The accumulation element type is specified by the result type and it does not need to match the input element type. @@ -556,34 +556,34 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [ Example: ``` - %1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = + %1 = shard.all_reduce %0 on @grid0 grid_axes = [1, 0] reduction = : tensor<3x4xf32> -> tensor<3x4xf64> ``` }]; let arguments = !con(commonArgs, (ins AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input, - DefaultValuedAttr:$reduction + DefaultValuedAttr:$reduction )); let results = (outs AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result ); let assemblyFormat = [{ - $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)? + $input `on` $grid (`grid_axes` `=` $grid_axes^)? (`reduction` `=` $reduction^)? attr-dict `:` type($input) `->` type($result) }]; let hasCanonicalizer = 1; let builders = [ - OpBuilder<(ins "Value":$input, "StringRef":$mesh, - "ArrayRef":$meshAxes, "ReductionKind":$reduction)> + OpBuilder<(ins "Value":$input, "StringRef":$grid, + "ArrayRef":$gridAxes, "ReductionKind":$reduction)> ]; } -def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [ +def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [ Pure, SameOperandsAndResultElementType, SameOperandsAndResultRank ]> { - let summary = "All-slice over a device mesh. This is the inverse of all-gather."; + let summary = "All-slice over a device grid. This is the inverse of all-gather."; let description = [{ Slice along the `slice_axis` tensor axis. 
This operation can be thought of as the inverse of all-gather. @@ -593,9 +593,9 @@ def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [ Example: ```mlir - mesh.mesh @mesh0(shape = 2x2) + shard.grid @grid0(shape = 2x2) ... - %1 = mesh.all_slice %0 on @mesh0 mesh_axes = [1] slice_axis = 1 + %1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1 : tensor<2x4xi8> -> tensor<2x2xi8> ``` Input: @@ -630,30 +630,30 @@ def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [ AnyNon0RankedTensor:$result ); let assemblyFormat = [{ - $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `slice_axis` `=` $slice_axis + $input `on` $grid (`grid_axes` `=` $grid_axes^)? `slice_axis` `=` $slice_axis attr-dict `:` type($input) `->` type($result) }]; let hasCanonicalizer = 1; let builders = [ - OpBuilder<(ins "Value":$input, "MeshOp":$mesh, "ArrayRef":$meshAxes, "int64_t":$sliceAxis)>, - OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$mesh, "ArrayRef":$meshAxes, "int64_t":$sliceAxis)> + OpBuilder<(ins "Value":$input, "GridOp":$grid, "ArrayRef":$gridAxes, "int64_t":$sliceAxis)>, + OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$grid, "ArrayRef":$gridAxes, "int64_t":$sliceAxis)> ]; } -def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [ +def Shard_AllToAllOp : Shard_CollectiveCommunicationOpBase<"all_to_all", [ Pure, SameOperandsAndResultElementType, SameOperandsAndResultRank]> { - let summary = "All-to-all over a device mesh."; + let summary = "All-to-all over a device grid."; let description = [{ Performs an all-to-all on tensor pieces split along `split_axis`. The resulting pieces are concatenated along `concat_axis` on ech device. Example: ``` - mesh.mesh @mesh0(shape = 3) + shard.grid @grid0(shape = 3) ... 
- %1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0] + %1 = shard.all_to_all %0 on @grid0 grid_axes = [0] split_axis = 0 concat_axis = 0 : tensor<3x2xi8> -> tensor<3x2xi8> ``` @@ -687,7 +687,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [ AnyNon0RankedTensor:$result ); let assemblyFormat = [{ - $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? + $input `on` $grid (`grid_axes` `=` $grid_axes^)? `split_axis` `=` $split_axis `concat_axis` `=` $concat_axis attr-dict `:` type($input) `->` type($result) @@ -695,24 +695,24 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [ let hasCanonicalizer = 1; } -def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [ +def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [ Pure, AllShapesMatch<["input", "result"]>, AllElementTypesMatch<["input", "result"]> ]> { - let summary = "Broadcast over a device mesh."; + let summary = "Broadcast over a device grid."; let description = [{ Broadcast the tensor on `root` to all devices in each respective group. - The operation broadcasts along mesh axes `mesh_axes`. + The operation broadcasts along grid axes `grid_axes`. The `root` device specifies the in-group multi-index that is broadcast to all other devices in the group. Example: ``` - mesh.mesh @mesh0(shape = 2x2) + shard.grid @grid0(shape = 2x2) - %1 = mesh.broadcast %0 on @mesh0 - mesh_axes = [0] + %1 = shard.broadcast %0 on @grid0 + grid_axes = [0] root = [0] : (tensor<2xi8>) -> tensor<2xi8> ``` @@ -744,31 +744,31 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [ AnyRankedTensor:$result ); let assemblyFormat = [{ - $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? + $input `on` $grid (`grid_axes` `=` $grid_axes^)? 
`root` `=` custom($root_dynamic, $root) attr-dict `:` functional-type(operands, results) }]; let hasCanonicalizer = 1; } -def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [ +def Shard_GatherOp : Shard_CollectiveCommunicationOpBase<"gather", [ Pure, AllRanksMatch<["input", "result"]>, AllElementTypesMatch<["input", "result"]> ]> { - let summary = "Gather over a device mesh."; + let summary = "Gather over a device grid."; let description = [{ Gathers on device `root` along the `gather_axis` tensor axis. - `root` specifies the coordinates of a device along `mesh_axes`. + `root` specifies the coordinates of a device along `grid_axes`. It uniquely identifies the root device for each device group. The result tensor on non-root devices is undefined. Using it will result in undefined behavior. Example: ```mlir - mesh.mesh @mesh0(shape = 2x2) + shard.grid @grid0(shape = 2x2) ... - %1 = mesh.gather %0 on @mesh0 mesh_axes = [1] + %1 = shard.gather %0 on @grid0 grid_axes = [1] gather_axis = 1 root = [1] : (tensor<2x2xi8>) -> tensor<2x4xi8> ``` @@ -807,7 +807,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [ AnyNon0RankedTensor:$result ); let assemblyFormat = [{ - $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? + $input `on` $grid (`grid_axes` `=` $grid_axes^)? `gather_axis` `=` $gather_axis `root` `=` custom($root_dynamic, $root) attr-dict `:` functional-type(operands, results) @@ -815,11 +815,11 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [ let hasCanonicalizer = 1; } -def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [ +def Shard_RecvOp : Shard_CollectiveCommunicationOpBase<"recv", [ AllShapesMatch<["input", "result"]>, AllElementTypesMatch<["input", "result"]> ]> { - let summary = "Send over a device mesh."; + let summary = "Receive over a device grid."; let description = [{ Receive from a device within a device group.
}]; @@ -832,21 +832,21 @@ def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [ AnyRankedTensor:$result ); let assemblyFormat = [{ - $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? + $input `on` $grid (`grid_axes` `=` $grid_axes^)? (`source` `=` custom($source_dynamic, $source)^)? attr-dict `:` functional-type(operands, results) }]; let hasCanonicalizer = 1; } -def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [ +def Shard_ReduceOp : Shard_CollectiveCommunicationOpBase<"reduce", [ Pure, AllShapesMatch<["input", "result"]> ]> { - let summary = "Reduce over a device mesh."; + let summary = "Reduce over a device grid."; let description = [{ Reduces on device `root` within each device group. - `root` specifies the coordinates of a device along `mesh_axes`. + `root` specifies the coordinates of a device along `grid_axes`. It uniquely identifies the root device within its device group. The accumulation element type is specified by the result type and it does not need to match the input element type. @@ -858,14 +858,14 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [ Example: ``` - %1 = mesh.reduce %0 on @mesh0 mesh_axes = [1, 0] + %1 = shard.reduce %0 on @grid0 grid_axes = [1, 0] reduction = root = [2, 3] : (tensor<3x4xf32>) -> tensor<3x4xf64> ``` }]; let arguments = !con(commonArgs, (ins AnyRankedTensor:$input, - DefaultValuedAttr:$reduction, + DefaultValuedAttr:$reduction, DenseI64ArrayAttr:$root, Variadic:$root_dynamic )); @@ -873,7 +873,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [ AnyRankedTensor:$result ); let assemblyFormat = [{ - $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? + $input `on` $grid (`grid_axes` `=` $grid_axes^)? (`reduction` `=` $reduction^)? 
`root` `=` custom($root_dynamic, $root) attr-dict `:` functional-type(operands, results) @@ -881,19 +881,19 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [ let hasCanonicalizer = 1; } -def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [ +def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter", [ Pure, SameOperandsAndResultRank]> { - let summary = "Reduce-scatter over a device mesh."; + let summary = "Reduce-scatter over a device grid."; let description = [{ After the reduction, the result is scattered within each device group. The tensor is split along `scatter_axis` and the pieces distributed across the device group. Example: ``` - mesh.mesh @mesh0(shape = 2x2) + shard.grid @grid0(shape = 2x2) ... - %1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1] + %1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1] reduction = scatter_axis = 0 : tensor<3x4xf32> -> tensor<1x4xf64> ``` @@ -928,14 +928,14 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", }]; let arguments = !con(commonArgs, (ins AnyNon0RankedTensor:$input, - DefaultValuedAttr:$reduction, + DefaultValuedAttr:$reduction, IndexAttr:$scatter_axis )); let results = (outs AnyRankedTensor:$result ); let assemblyFormat = [{ - $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? + $input `on` $grid (`grid_axes` `=` $grid_axes^)? (`reduction` `=` $reduction^)? 
`scatter_axis` `=` $scatter_axis attr-dict `:` type($input) `->` type($result) @@ -943,20 +943,20 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", let hasCanonicalizer = 1; } -def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [ +def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [ Pure, AllRanksMatch<["input", "result"]>, AllElementTypesMatch<["input", "result"]> ]> { - let summary = "Scatter over a device mesh."; + let summary = "Scatter over a device grid."; let description = [{ For each device group split the input tensor on the `root` device along axis `scatter_axis` and scatter the parts across the group devices. Example: ``` - mesh.mesh @mesh0(shape = 2x2) - %1 = mesh.scatter %0 on @mesh0 mesh_axes = [0] + shard.grid @grid0(shape = 2x2) + %1 = shard.scatter %0 on @grid0 grid_axes = [0] scatter_axis = 0 root = [1] : (tensor<2x2xi8>) -> tensor<1x2xi8> @@ -1004,7 +1004,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [ AnyRankedTensor:$result ); let assemblyFormat = [{ - $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? + $input `on` $grid (`grid_axes` `=` $grid_axes^)? `scatter_axis` `=` $scatter_axis `root` `=` custom($root_dynamic, $root) attr-dict `:` functional-type(operands, results) @@ -1012,11 +1012,11 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [ let hasCanonicalizer = 1; } -def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [ +def Shard_SendOp : Shard_CollectiveCommunicationOpBase<"send", [ AllShapesMatch<["input", "result"]>, AllElementTypesMatch<["input", "result"]> ]> { - let summary = "Send over a device mesh."; + let summary = "Send over a device grid."; let description = [{ Send from one device to another within a device group. }]; @@ -1029,38 +1029,38 @@ def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [ AnyRankedTensor:$result ); let assemblyFormat = [{ - $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? 
+ $input `on` $grid (`grid_axes` `=` $grid_axes^)? `destination` `=` custom($destination_dynamic, $destination) attr-dict `:` functional-type(operands, results) }]; let hasCanonicalizer = 1; } -def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [ +def Shard_ShiftOp : Shard_CollectiveCommunicationOpBase<"shift", [ Pure, SameOperandsAndResultElementType, SameOperandsAndResultShape ]> { - let summary = "Shift over a device mesh."; + let summary = "Shift over a device grid."; let description = [{ - Within each device group shift along mesh axis `shift_axis` by an offset + Within each device group shift along grid axis `shift_axis` by an offset `offset`. The result on devices that do not have a corresponding source is undefined. - `shift_axis` must be one of `mesh_axes`. + `shift_axis` must be one of `grid_axes`. If the `rotate` attribute is present, instead of a shift a rotation is done. Example: ``` - mesh.mesh @mesh0(shape = 2x4) - %1 = mesh.shift on @mesh0 mesh_axes = [1] + shard.grid @grid0(shape = 2x4) + %1 = shard.shift on @grid0 grid_axes = [1] shift_axis = 1 offset = 2 rotate : tensor<2xi8> -> tensor<2xi8> ``` Input: ``` - mesh axis 1 + grid axis 1 -----------> +----+----+----+----+ @@ -1089,7 +1089,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [ AnyRankedTensor:$result ); let assemblyFormat = [{ - $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? + $input `on` $grid (`grid_axes` `=` $grid_axes^)? `shift_axis` `=` $shift_axis `offset` `=` $offset (`rotate` $rotate^)? @@ -1098,7 +1098,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [ let hasCanonicalizer = 1; } -def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [ +def Shard_UpdateHaloOp : Shard_Op<"update_halo", [ Pure, DestinationStyleOpInterface, TypesMatchWith< @@ -1120,14 +1120,14 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [ `destination_halo_sizes/static_destination_halo_sizes` in source shard and destination/result shard. 
- `split_axes` specifies for each tensor axis along which mesh axes its halo + `split_axes` specifies for each tensor axis along which grid axes its halo data is updated. }]; let arguments = (ins AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination, - FlatSymbolRefAttr:$mesh, - Mesh_MeshAxesArrayAttr:$split_axes, + FlatSymbolRefAttr:$grid, + Shard_GridAxesArrayAttr:$split_axes, Variadic:$halo_sizes, DefaultValuedAttr:$static_halo_sizes ); @@ -1136,7 +1136,7 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [ ); let assemblyFormat = [{ $destination - `on` $mesh + `on` $grid `split_axes` `=` $split_axes (`halo_sizes` `=` custom($halo_sizes, $static_halo_sizes)^)? attr-dict `:` type($result) @@ -1145,4 +1145,4 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [ MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); } }]; } -#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD +#endif // MLIR_DIALECT_SHARD_IR_SHARDOPS_TD diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/include/mlir/Dialect/Shard/Interfaces/CMakeLists.txt similarity index 100% rename from mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt rename to mlir/include/mlir/Dialect/Shard/Interfaces/CMakeLists.txt diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterface.h similarity index 52% rename from mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h rename to mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterface.h index 14aad7f9f6783..46f6ed410ebed 100644 --- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h +++ b/mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterface.h @@ -6,10 +6,10 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_ -#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_ +#ifndef 
MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_H_ +#define MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_H_ -#include "mlir/Dialect/Mesh/IR/MeshOps.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" @@ -20,24 +20,24 @@ class Operation; class IRMapping; class SymbolTableCollection; -namespace mesh { +namespace shard { -using ShardingArray = SmallVector>; -using ShardingArrayRef = ArrayRef>; +using ShardingArray = SmallVector>; +using ShardingArrayRef = ArrayRef>; struct ShardingOption { // An array of int array. The sub-array at the i-th position signifies the - // mesh axes the i-th loop will be sharded on. + // grid axes the i-th loop will be sharded on. ShardingArray shardingArray = {}; - FlatSymbolRefAttr mesh = nullptr; + FlatSymbolRefAttr grid = nullptr; // `empty` being true indicates that no sharding information can be inferred // at present. Note that it is different from the case where an operation is // not sharded. bool empty = false; ShardingOption() = default; - ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh) - : shardingArray(std::move(shardingArray)), mesh(mesh) { - assert(this->mesh); + ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr grid) + : shardingArray(std::move(shardingArray)), grid(grid) { + assert(this->grid); } static ShardingOption makeEmpty() { auto res = ShardingOption(); @@ -46,21 +46,21 @@ struct ShardingOption { } }; -// This method retrieves the 'MeshSharding' from a given operation +// This method retrieves the 'Sharding' from a given operation // result and includes the 'annotate_for_users' information. -FailureOr> getMeshSharding(OpResult result); +FailureOr> getSharding(OpResult result); -// This method retrieves the 'MeshSharding' from a given operation +// This method retrieves the 'Sharding' from a given operation // operand and includes the 'annotate_for_users' information. 
-FailureOr> getMeshSharding(OpOperand &opOperand); +FailureOr> getSharding(OpOperand &opOperand); namespace detail { FailureOr -defaultGetShardingOption(Operation *op, ArrayRef operandShardings, - ArrayRef resultShardings); +defaultGetShardingOption(Operation *op, ArrayRef operandShardings, + ArrayRef resultShardings); -FailureOr> +FailureOr> defaultGetShardingAnnotations(Operation *op, const ShardingOption &shardingOption); @@ -71,18 +71,18 @@ defaultAddShardingAnnotations(Operation *op, OpBuilder &b, } // namespace detail // Assumes full replication on all ranked tensor arguments and results. -void spmdizeFullyReplicatedOperation(Operation &op, - ArrayRef spmdizedOperands, - ArrayRef operandShardings, - ArrayRef resultShardings, - IRMapping &spmdizationMap, - SymbolTableCollection &symbolTable, - OpBuilder &builder); - -} // namespace mesh +void partitionFullyReplicatedOperation(Operation &op, + ArrayRef partitionedOperands, + ArrayRef operandShardings, + ArrayRef resultShardings, + IRMapping &partitionMap, + SymbolTableCollection &symbolTable, + OpBuilder &builder); + +} // namespace shard } // namespace mlir /// Include the ODS generated interface header files. 
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h.inc" -#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_ +#endif // MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_H_ diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterface.td similarity index 80% rename from mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td rename to mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterface.td index a70d2c3e03851..8f5332b41ca72 100644 --- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td +++ b/mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterface.td @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD -#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD +#ifndef MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_TD +#define MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_TD include "mlir/IR/OpBase.td" @@ -16,7 +16,7 @@ def ShardingInterface : OpInterface<"ShardingInterface"> { Interface for allowing operations to expose information needed to shard them. }]; - let cppNamespace = "::mlir::mesh"; + let cppNamespace = "::mlir::shard"; let methods = [ InterfaceMethod< @@ -84,8 +84,8 @@ def ShardingInterface : OpInterface<"ShardingInterface"> { /*retTy=*/"FailureOr", /*methodName=*/"getShardingOption", /*args=*/(ins - "ArrayRef": $operandShardings, - "ArrayRef": $resultShardings + "ArrayRef": $operandShardings, + "ArrayRef": $resultShardings ), /*methodBody=*/"", /*defaultImplementation=*/[{ @@ -100,7 +100,7 @@ def ShardingInterface : OpInterface<"ShardingInterface"> { This is what shardings the operands and results need to have in order to shard the op according to shardingOption. 
}], - /*retTy=*/"FailureOr>", + /*retTy=*/"FailureOr>", /*methodName=*/"getShardingAnnotations", /*args=*/(ins "const ShardingOption &":$shardingOption @@ -113,7 +113,7 @@ def ShardingInterface : OpInterface<"ShardingInterface"> { >, InterfaceMethod< /*desc=*/[{ - Based on a given ShardingOption, this method adds `mesh.shard` + Based on a given ShardingOption, this method adds `shard.shard` operations for the operands and results that previously lacked sharding annotations. }], @@ -132,21 +132,21 @@ def ShardingInterface : OpInterface<"ShardingInterface"> { InterfaceMethod< /*desc=*/[{ Convert self to SPMD form. - This method is used during the spmdization pass of a program fully + This method is used during the partition pass of a program fully annotated with shardings. - The spmdization algorithm would read the surrounding sharding + The partition algorithm would read the surrounding sharding annotations from the IR for each argument/result and prepare `operandShardings` and `resultShardings`. Values that are not ranked tensors do not have sharding annotations. - In this case their corresponding MeshSharding is null. + In this case their corresponding Sharding is null. - For convenience it will also prepare `spmdizedOperands`, although - they can be retrieved from the `spmdizationMap`. + For convenience it will also prepare `partitionedOperands`, although + they can be retrieved from the `partitionMap`. - The `spmdizationMap` contains a mapping from unsharded to - sharded/spmdized values that are constructed during the spmdization - pass. The interface implementation must populate `spmdizationMap` + The `partitionMap` contains a mapping from unsharded to + sharded/partitioned values that are constructed during the partition + pass. The interface implementation must populate `partitionMap` with the mapping for this op's results. `builder` is set to insert new operations in the appropriate point. 
@@ -158,20 +158,20 @@ def ShardingInterface : OpInterface<"ShardingInterface"> { This assumes that all sharding annotations are for full replication. }], /*retTy=*/"LogicalResult", - /*methodName=*/"spmdize", + /*methodName=*/"partition", /*args=*/(ins - "ArrayRef": $spmdizedOperands, - "ArrayRef": $operandShardings, - "ArrayRef": $resultShardings, - "IRMapping&": $spmdizationMap, + "ArrayRef": $partitionedOperands, + "ArrayRef": $operandShardings, + "ArrayRef": $resultShardings, + "IRMapping&": $partitionMap, "SymbolTableCollection &": $symbolTableCollection, "OpBuilder &":$builder ), /*methodBody=*/"", /*defaultImplementation=*/[{ - spmdizeFullyReplicatedOperation( - *$_op.getOperation(), spmdizedOperands, operandShardings, - resultShardings, spmdizationMap, symbolTableCollection, builder); + partitionFullyReplicatedOperation( + *$_op.getOperation(), partitionedOperands, operandShardings, + resultShardings, partitionMap, symbolTableCollection, builder); return success(); }]> ]; @@ -184,4 +184,4 @@ def ShardingInterface : OpInterface<"ShardingInterface"> { } -#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD +#endif // MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_TD diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h similarity index 58% rename from mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h rename to mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h index 2af8b2bd1d906..d34ba79257ff8 100644 --- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h @@ -6,11 +6,11 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_ -#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_ +#ifndef 
MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACEIMPL_H_ +#define MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACEIMPL_H_ -#include "mlir/Dialect/Mesh/IR/MeshOps.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Value.h" @@ -20,35 +20,34 @@ class Operation; class IRMapping; class SymbolTableCollection; -namespace mesh { +namespace shard { -// Retrieve the mesh axes corresponding to each operation loop iterator based +// Retrieve the grid axes corresponding to each operation loop iterator based // on the provided shardings for the op's operands and results. // Assumes that the indexingMaps are projected permutations. -ShardingArray getMeshAxisAssignmentForLoopIterators( - ArrayRef operandShardings, - ArrayRef resultShardings, +ShardingArray getGridAxisAssignmentForLoopIterators( + ArrayRef operandShardings, ArrayRef resultShardings, ArrayRef loopIteratorTypes, ArrayRef indexingMaps); bool isAtLeastOneReductionIteratorSharded( ArrayRef loopIteratorTypes, - ArrayRef> meshAxisAssignmentForLoopIterators); + ArrayRef> gridAxisAssignmentForLoopIterators); -// Get the set of mesh axes that correspond to reduction loop iterators. -SmallVector getReductionMeshAxes( +// Get the set of grid axes that correspond to reduction loop iterators. +SmallVector getReductionGridAxes( ArrayRef loopIteratorTypes, - ArrayRef> meshAxisAssignmentForLoopIterators); + ArrayRef> gridAxisAssignmentForLoopIterators); // Inserts a clone of the operation that has all ranked tensor // arguments/results sharded. 
-void spmdizeTriviallyShardableOperation(Operation &op, - ArrayRef spmdizedOperands, - ArrayRef operandShardings, - ArrayRef resultShardings, - IRMapping &spmdizationMap, - SymbolTableCollection &symbolTable, - OpBuilder &builder); +void partitionTriviallyShardableOperation(Operation &op, + ArrayRef partitionedOperands, + ArrayRef operandShardings, + ArrayRef resultShardings, + IRMapping &partitionMap, + SymbolTableCollection &symbolTable, + OpBuilder &builder); // All ranked tensor argument and result dimensions have // independent parallel loop iterators. @@ -73,15 +72,15 @@ struct IndependentParallelIteratorDomainShardingInterface return SmallVector(); } - LogicalResult spmdize(Operation *op, ArrayRef spmdizedOperands, - ArrayRef operandShardings, - ArrayRef resultShardings, - IRMapping &spmdizationMap, - SymbolTableCollection &symbolTable, - OpBuilder &builder) const { - spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings, - resultShardings, spmdizationMap, - symbolTable, builder); + LogicalResult partition(Operation *op, ArrayRef partitionedOperands, + ArrayRef operandShardings, + ArrayRef resultShardings, + IRMapping &partitionMap, + SymbolTableCollection &symbolTable, + OpBuilder &builder) const { + partitionTriviallyShardableOperation(*op, partitionedOperands, + operandShardings, resultShardings, + partitionMap, symbolTable, builder); return success(); } @@ -129,20 +128,20 @@ struct ElementwiseShardingInterface return maps; } - LogicalResult spmdize(Operation *op, ArrayRef spmdizedOperands, - ArrayRef operandShardings, - ArrayRef resultShardings, - IRMapping &spmdizationMap, - SymbolTableCollection &symbolTable, - OpBuilder &builder) const { - spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings, - resultShardings, spmdizationMap, - symbolTable, builder); + LogicalResult partition(Operation *op, ArrayRef partitionedOperands, + ArrayRef operandShardings, + ArrayRef resultShardings, + IRMapping &partitionMap, + 
SymbolTableCollection &symbolTable, + OpBuilder &builder) const { + partitionTriviallyShardableOperation(*op, partitionedOperands, + operandShardings, resultShardings, + partitionMap, symbolTable, builder); return success(); } }; -} // namespace mesh +} // namespace shard } // namespace mlir -#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_ +#endif // MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACEIMPL_H_ diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Shard/Transforms/CMakeLists.txt new file mode 100644 index 0000000000000..9e2c8d00b27f5 --- /dev/null +++ b/mlir/include/mlir/Dialect/Shard/Transforms/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Shard) +add_public_tablegen_target(MLIRShardPassIncGen) +add_dependencies(mlir-headers MLIRShardPassIncGen) + +add_mlir_doc(Passes ShardPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h b/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h similarity index 61% rename from mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h rename to mlir/include/mlir/Dialect/Shard/Transforms/Partition.h index 2f6de3e134319..37903765903db 100644 --- a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h +++ b/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h @@ -1,4 +1,4 @@ -//===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===// +//===- Simplifications.h - Shard Simplifications ----------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,35 +6,35 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H -#define MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H +#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_PARTITION_H +#define MLIR_DIALECT_SHARD_TRANSFORMS_PARTITION_H -#include "mlir/Dialect/Mesh/IR/MeshOps.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" #include "mlir/IR/DialectRegistry.h" namespace mlir { -namespace mesh { +namespace shard { -// Insert resharding spmdization of the value `sourceShardValue` +// Insert resharding partition of the value `sourceShardValue` // from sharding `source` to sharding `target`. // `sourceShardValue` is the already sharded value according to `source`. // // Example // // ```mlir -// mesh.mesh @mesh_1d(shape = 2) +// shard.grid @grid_1d(shape = 2) // ... -// %1 = mesh.shard %0 to <@mesh_1d, [[0]]> : tensor<2xi8> -// %2 = mesh.shard %1 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8> +// %1 = shard.shard %0 to <@grid_1d, [[0]]> : tensor<2xi8> +// %2 = shard.shard %1 to <@grid_1d, [[]]> annotate_for_users: tensor<2xi8> // ``` // // Will result in // // ```mlir -// %1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 : +// %1 = shard.all_gather %0 on @grid_1d grid_axes = [0] gather_axis = 0 : // tensor<1xi8> -> tensor<2xi8> // ``` -TypedValue reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, +TypedValue reshard(OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue sourceShardValue); TypedValue reshard(OpBuilder &builder, ShardOp source, @@ -44,7 +44,7 @@ TypedValue reshard(OpBuilder &builder, ShardOp source, void reshardingRegisterDependentDialects(DialectRegistry ®istry); -} // namespace mesh +} // namespace shard } // namespace mlir -#endif // MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H +#endif // MLIR_DIALECT_SHARD_TRANSFORMS_PARTITION_H diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h 
b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.h similarity index 75% rename from mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h rename to mlir/include/mlir/Dialect/Shard/Transforms/Passes.h index a2424d43a8ba9..88bb460255728 100644 --- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.h @@ -1,4 +1,4 @@ -//===- Passes.h - Mesh Passes -----------------------------------*- C++ -*-===// +//===- Passes.h - Shard Passes ----------------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H -#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H +#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_PASSES_H +#define MLIR_DIALECT_SHARD_TRANSFORMS_PASSES_H #include "mlir/Pass/Pass.h" @@ -17,7 +17,7 @@ namespace func { class FuncOp; } -namespace mesh { +namespace shard { /// This enum controls the traversal order for the sharding propagation. 
enum class TraversalOrder { @@ -36,16 +36,16 @@ enum class TraversalOrder { //===----------------------------------------------------------------------===// #define GEN_PASS_DECL -#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc" +#include "mlir/Dialect/Shard/Transforms/Passes.h.inc" //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// #define GEN_PASS_REGISTRATION -#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc" +#include "mlir/Dialect/Shard/Transforms/Passes.h.inc" -} // namespace mesh +} // namespace shard } // namespace mlir -#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H +#endif // MLIR_DIALECT_SHARD_TRANSFORMS_PASSES_H diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td similarity index 65% rename from mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td rename to mlir/include/mlir/Dialect/Shard/Transforms/Passes.td index 11ec7e78cd5e6..bbc6a1977b13e 100644 --- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td @@ -1,4 +1,4 @@ -//===-- Passes.td - Mesh transformation definition file ----*- tablegen -*-===// +//===-- Passes.td - Shard transformation definition file ---*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -7,8 +7,8 @@ //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD -#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD +#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_PASSES_TD +#define MLIR_DIALECT_SHARD_TRANSFORMS_PASSES_TD include "mlir/Pass/PassBase.td" @@ -20,31 +20,31 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO let summary = "sharding propagation"; let description = [{ Propagates sharding information throughout the graph. After this pass, each - of the operations' operands and results is annotated with a `mesh.shard` + of the operations' operands and results is annotated with a `shard.shard` operation, and the operations themselves are added with sharding option attributes. }]; let options = [ Option<"traversal", "traversal", - "mlir::mesh::TraversalOrder", /*default=*/"mlir::mesh::TraversalOrder::BackwardForward", + "mlir::shard::TraversalOrder", /*default=*/"mlir::shard::TraversalOrder::BackwardForward", "Traversal order to use for sharding propagation:", [{::llvm::cl::values( - clEnumValN(mlir::mesh::TraversalOrder::Forward, "forward", + clEnumValN(mlir::shard::TraversalOrder::Forward, "forward", "Forward only traversal."), - clEnumValN(mlir::mesh::TraversalOrder::Backward, "backward", + clEnumValN(mlir::shard::TraversalOrder::Backward, "backward", "backward only traversal."), - clEnumValN(mlir::mesh::TraversalOrder::ForwardBackward, "forward-backward", + clEnumValN(mlir::shard::TraversalOrder::ForwardBackward, "forward-backward", "forward-backward traversal."), - clEnumValN(mlir::mesh::TraversalOrder::BackwardForward, "backward-forward", + clEnumValN(mlir::shard::TraversalOrder::BackwardForward, "backward-forward", "backward-forward traversal.") )}]>, ]; let dependentDialects = [ - "mesh::MeshDialect" + "shard::ShardDialect" ]; } -def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface"> { +def Partition : 
InterfacePass<"shard-partition", "mlir::FunctionOpInterface"> { let summary = "Partition a function into SPMD form."; let description = [{ This pass fits in right after a pass that annotates the function with @@ -52,15 +52,15 @@ def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface"> It operates on a fully annotated IR. A fully annotated IR required that all ranked tensor operands, results and - block arguments are annotated with the `mesh.shard` operation. + block arguments are annotated with the `shard.shard` operation. All direct descendant operations in the function must implement the `ShardingInterface` interface or all their ranked tensor operands and results must have full replication sharding. The input IR must have sharding annotations such that each operation - that implements `ShardingInterface` can handle during spmdization with - its `spmdize` method. + that implements `ShardingInterface` can handle during partition with + its `partition` method. This can be achieved with the `ShardingPropagation` pass. 
If the function has multiple terminating blocks, @@ -70,36 +70,36 @@ def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface"> Example: ```mlir - mesh.mesh @mesh_1d(shape = 2) + shard.grid @grid_1d(shape = 2) func.func @f( %arg0: tensor<2xi8> ) -> tensor<2xi8> { - %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xi8> - %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8> + %0 = shard.shard %arg0 to <@grid_1d, [[0]]> : tensor<2xi8> + %1 = shard.shard %0 to <@grid_1d, [[0]]> annotate_for_users: tensor<2xi8> %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8> - %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8> - %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8> + %3 = shard.shard %2 to <@grid_1d, [[0]]> : tensor<2xi8> + %4 = shard.shard %3 to <@grid_1d, [[]]> annotate_for_users: tensor<2xi8> return %4 : tensor<2xi8> } ``` - Spmdizing the above would result in + Partitioning the above would result in * Performing the element-wise `abs` operation on each device. * Resharding to full replication with an all-gather. 
```mlir - mesh.mesh @mesh_1d(shape = 2) + shard.grid @grid_1d(shape = 2) func.func @f(%arg0: tensor<1xi8>) -> tensor<2xi8> { %0 = tosa.abs %arg0 : (tensor<1xi8>) -> tensor<1xi8> - %1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8> + %1 = shard.all_gather %0 on @grid_1d grid_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8> return %1 : tensor<2xi8> } ``` }]; let dependentDialects = [ - "mesh::MeshDialect" + "shard::ShardDialect" ]; } -#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD +#endif // MLIR_DIALECT_SHARD_TRANSFORMS_PASSES_TD diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc.md b/mlir/include/mlir/Dialect/Shard/Transforms/ReshardingPartitionDoc.md similarity index 87% rename from mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc.md rename to mlir/include/mlir/Dialect/Shard/Transforms/ReshardingPartitionDoc.md index 6368931cf6e07..cf5ae12b54b2c 100644 --- a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc.md +++ b/mlir/include/mlir/Dialect/Shard/Transforms/ReshardingPartitionDoc.md @@ -1,6 +1,6 @@ -# Resharding Spmdization Examples +# Resharding Partition Examples -Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[0, 1]]` on a `2x3` mesh. +Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[0, 1]]` on a `2x3` grid.
unsharded `2x3` tensor ``` @@ -8,16 +8,16 @@ unsharded `2x3` tensor 21 22 23 ``` -sharded on a `2x3` mesh +sharded on a `2x3` grid sharding = `[[0, 1]]` -mesh contents: +grid contents: ``` -mesh axis 1 +grid axis 1 -----------> -+----+----+----+ mesh axis 0 | ++----+----+----+ grid axis 0 | | 11 | 12 | 13 | | +----+----+----+ | | 21 | 22 | 23 | | @@ -27,9 +27,9 @@ mesh axis 1 Transform into sharding = `[[1, 0]]` ``` -mesh axis 1 +grid axis 1 -----------> -+----+----+----+ mesh axis 0 | ++----+----+----+ grid axis 0 | | 11 | 13 | 22 | | +----+----+----+ | | 12 | 21 | 23 | | @@ -40,7 +40,7 @@ Swap contents on devices that have the same linear index in the 2 shardings. -------------------------------------------------------------- -Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[1]]` on a `2x3` mesh. +Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[1]]` on a `2x3` grid. unsharded `2x3` tensor ``` @@ -48,15 +48,15 @@ unsharded `2x3` tensor 21 22 23 ``` -sharded on a `2x3` mesh +sharded on a `2x3` grid sharding = `[[0, 1]]` -mesh contents: +grid contents: ``` -mesh axis 1 +grid axis 1 -----------> -+----+----+----+ mesh axis 0 | ++----+----+----+ grid axis 0 | | 11 | 12 | 13 | | +----+----+----+ | | 21 | 22 | 23 | | @@ -66,9 +66,9 @@ mesh axis 1 Transform into sharding = `[[1]]` ``` -mesh axis 1 +grid axis 1 -----------> -+----+----+----+ mesh axis 0 | ++----+----+----+ grid axis 0 | | 11 | 12 | 13 | | | 21 | 22 | 23 | | +----+----+----+ | @@ -77,11 +77,11 @@ mesh axis 1 +----+----+----+ ↓ ``` Algorithm: -All-gather along mesh axis 0. +All-gather along grid axis 0. -------------------------------------------------------------- -Reshard `4x6` tensor from sharding `[[], [0, 1]]` to sharding `[[], [0]]` on a `2x3` mesh. +Reshard `4x6` tensor from sharding `[[], [0, 1]]` to sharding `[[], [0]]` on a `2x3` grid.
unsharded `4x6` tensor ``` @@ -89,15 +89,15 @@ unsharded `4x6` tensor 21 22 23 24 25 26 ``` -sharded on a `2x3` mesh +sharded on a `2x3` grid sharding = `[[], [0, 1]]` -mesh contents: +grid contents: ``` -mesh axis 1 +grid axis 1 -----------> -+----+----+----+ mesh axis 0 | ++----+----+----+ grid axis 0 | | 11 | 12 | 13 | | | 21 | 22 | 23 | | +----+----+----+ | @@ -108,9 +108,9 @@ mesh axis 1 Transform into sharding = `[[], [0]]` ``` -mesh axis 1 +grid axis 1 -----------> -+----------+----------+ mesh axis 0 | ++----------+----------+ grid axis 0 | | 11 12 13 | 11 12 13 | | | 21 22 23 | 21 22 23 | | +----------+----------+ | @@ -119,11 +119,11 @@ mesh axis 1 +----------+----------+ ↓ ``` Algorithm: -All-gather along mesh axis 1. +All-gather along grid axis 1. -------------------------------------------------------------- -Reshard `4x8` tensor from sharding `[[0], [1, 2]]` to sharding `[[0], [2]]` on a `2x2x2` mesh. +Reshard `4x8` tensor from sharding `[[0], [1, 2]]` to sharding `[[0], [2]]` on a `2x2x2` grid. unsharded `4x8` tensor ``` @@ -132,15 +132,15 @@ unsharded `4x8` tensor 31 32 33 34 35 36 37 38 41 42 43 44 45 46 47 48 ``` -sharded on a `2x2x2` mesh +sharded on a `2x2x2` grid sharding = `[[0], [1, 2]]` -mesh contents: +grid contents: ``` -mesh axis 2 +grid axis 2 -----------> -+-------+-------+ mesh axis 1 | mesh axis 0 | ++-------+-------+ grid axis 1 | grid axis 0 | | 11 12 | 13 14 | | | | 21 22 | 23 24 | | | +-------+-------+ | | @@ -158,9 +158,9 @@ mesh axis 2 Transform into sharding = `[[0], [2]]` ``` -mesh axis 2 +grid axis 2 -----------> -+-------------+-------------+ mesh axis 1 | mesh axis 0 | ++-------------+-------------+ grid axis 1 | grid axis 0 | | 11 12 13 14 | 15 16 17 18 | | | | 21 22 23 24 | 25 26 27 28 | | | +-------------+-------------+ | | @@ -177,13 +177,13 @@ mesh axis 2 ``` Algorithm: -Can't be done with just an all-gather along mesh axis 1. +Can't be done with just an all-gather along grid axis 1.
Can be handled by multiple resharding transformations `[[0], [1, 2]] -> [[0], [2, 1]] -> [[0], [2]]` -------------------------------------------------------------- -Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` mesh. +Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` grid. unsharded `6x6` tensor ``` @@ -194,13 +194,13 @@ unsharded `6x6` tensor 51 52 53 54 55 56 61 62 63 64 65 66 ``` -sharded on a `2x3` mesh +sharded on a `2x3` grid sharding = `[[0], [1]]` ``` -mesh axis 1 +grid axis 1 -----------> -+-------+-------+-------+ mesh axis 0 | ++-------+-------+-------+ grid axis 0 | | 11 12 | 13 14 | 15 16 | | | 21 22 | 23 24 | 25 26 | | | 31 32 | 33 34 | 35 36 | | @@ -213,9 +213,9 @@ mesh axis 1 transform to sharding = `[[1], [0]]` ``` -mesh axis 1 +grid axis 1 -----------> -+----------+----------+----------+ mesh axis 0 | ++----------+----------+----------+ grid axis 0 | | 11 12 13 | 31 32 33 | 51 52 53 | | | 21 22 23 | 41 42 43 | 61 62 63 | | +----------+----------+----------+ | @@ -223,9 +223,9 @@ mesh axis 1 | 24 25 26 | 44 45 46 | 64 65 66 | | +----------+----------+----------+ ↓ -mesh axis 0 +grid axis 0 -----------> -+----------+----------+ mesh axis 1 | ++----------+----------+ grid axis 1 | | 11 12 13 | 14 15 16 | | | 21 22 23 | 24 25 26 | | +----------+----------+ | @@ -240,7 +240,7 @@ Algorithm: TODO -------------------------------------------------------------- -Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x6` mesh. +Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x6` grid.
unsharded 6x6 tensor ``` @@ -251,13 +251,13 @@ unsharded 6x6 tensor 51 52 53 54 55 56 61 62 63 64 65 66 ``` -shard on `2x6` mesh +shard on `2x6` grid sharding = `[[0], [1]]` ``` -mesh axis 1 +grid axis 1 -----------> -+----+----+----+----+----+----+ mesh axis 0 | ++----+----+----+----+----+----+ grid axis 0 | | 11 | 12 | 13 ‖ 14 | 15 | 16 | | | 21 | 22 | 23 ‖ 24 | 23 | 26 | | | 31 | 32 | 33 ‖ 34 | 35 | 36 | | @@ -270,9 +270,9 @@ mesh axis 1 transform to sharding = `[[1], [0]]` ``` -mesh axis 0 +grid axis 0 -----------> -+----------+----------+ mesh axis 1 | ++----------+----------+ grid axis 1 | | 11 12 13 | 14 15 16 | | +----------+----------+ | | 21 22 23 | 24 25 26 | | @@ -290,9 +290,9 @@ Algorithm: TODO -------------------------------------------------------------- -Reshard KxL tensor from `[[0], [1]]` to `[[1], [0]]` on `MxN` mesh. +Reshard KxL tensor from `[[0], [1]]` to `[[1], [0]]` on `MxN` grid. -`M x N` mesh. +`M x N` grid. `K x L` tensor `t`. `d(m, n)` the tensor on device `(m, n)`. @@ -433,9 +433,9 @@ TODO -------------------------------------------------------------- -Reshard `KxL` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` mesh. +Reshard `KxL` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` grid. -Device placement on a `2x3` mesh +Device placement on a `2x3` grid ``` 11 12 13 <- devices 21 22 23 @@ -512,7 +512,7 @@ TODO -------------------------------------------------------------- -Reshard `6x6` tensor from sharding `[[0], []]` to sharding `[[], [0]]` on a `3` mesh. +Reshard `6x6` tensor from sharding `[[0], []]` to sharding `[[], [0]]` on a `3` grid.
unsharded `6x6` tensor ``` @@ -523,11 +523,11 @@ unsharded `6x6` tensor 51 52 53 54 55 56 61 62 63 64 65 66 ``` -sharded on a `3` mesh +sharded on a `3` grid sharding = `[[0], []]` ``` -+-------------------+ mesh axis 0 | ++-------------------+ grid axis 0 | | 11 12 13 14 15 16 | | | 21 22 23 24 25 26 | | +-------------------+ | @@ -541,7 +541,7 @@ sharding = `[[0], []]` transform to sharding = `[[], [0]]` ``` -mesh axis 0 +grid axis 0 -----------> +-------+-------+-------+ | 11 12 | 13 14 | 15 16 | @@ -554,11 +554,11 @@ mesh axis 0 ``` Algorithm: ```mlir -%1 = all_to_all %0 on @mesh mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<2x6xi8> -> tensor<6x2xi8> +%1 = all_to_all %0 on @grid grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor<2x6xi8> -> tensor<6x2xi8> ``` -------------------------------------------------------------- -Reshard `4x4` tensor from sharding `[[0], [1, 2]]` to sharding `[[0, 1], [2]]` on a `2x2x2` mesh. +Reshard `4x4` tensor from sharding `[[0], [1, 2]]` to sharding `[[0, 1], [2]]` on a `2x2x2` grid. unsharded `4x4` tensor ``` @@ -567,13 +567,13 @@ unsharded `4x4` tensor 31 32 33 34 41 42 43 44 ``` -sharded on a `2x2x2` mesh +sharded on a `2x2x2` grid sharding = `[[0], [1, 2]]` ``` -mesh axis 2 +grid axis 2 -----------> -+----+----+ mesh axis 1 | mesh axis 0 | ++----+----+ grid axis 1 | grid axis 0 | | 11 | 12 | | | | 21 | 22 | | | +----+----+ | | @@ -591,9 +591,9 @@ mesh axis 2 transform to sharding = `[[0, 1], [2]]` ``` -mesh axis 2 +grid axis 2 -----------> -+-------+-------+ mesh axis 1 | mesh axis 0 | ++-------+-------+ grid axis 1 | grid axis 0 | | 11 12 | 13 41 | | | +-------+-------+ | | | 21 22 | 23 24 | | | @@ -606,7 +606,7 @@ mesh axis 2 ``` Algorithm: ```mlir -%1 = all_to_all %0 on @mesh mesh_axes = [2] split_axis = 1 concat_axis = 0 : tensor<2x1xi8> -> tensor<1x2xi8> +%1 = all_to_all %0 on @grid grid_axes = [2] split_axis = 1 concat_axis = 0 : tensor<2x1xi8> -> tensor<1x2xi8> ``` is not enough.
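The `all_to_all` steps in the examples above all follow the same pattern: each device in the group splits its local shard along `split_axis` into one piece per group member, sends piece `i` to device `i`, and concatenates the received pieces along `concat_axis` in group order. A minimal pure-Python model can sanity-check this (an illustration only, not the MLIR lowering; all helper names here are invented), reproducing the `[[0], []] -> [[], [0]]` example with `tensor<2x6> -> tensor<6x2>`:

```python
# Toy model of all_to_all with split_axis/concat_axis semantics,
# operating on 2-D tensors represented as lists of row lists.

def split(mat, axis, n):
    """Split a 2-D matrix into n equal blocks along the given axis."""
    if axis == 0:
        step = len(mat) // n
        return [mat[i * step:(i + 1) * step] for i in range(n)]
    step = len(mat[0]) // n
    return [[row[i * step:(i + 1) * step] for row in mat] for i in range(n)]

def concat(blocks, axis):
    """Concatenate 2-D blocks along the given axis."""
    if axis == 0:
        return [row for b in blocks for row in b]
    return [sum(rows, []) for rows in zip(*blocks)]

def all_to_all(shards, split_axis, concat_axis):
    """Each device splits its shard; piece i goes to device i."""
    n = len(shards)
    pieces = [split(s, split_axis, n) for s in shards]
    return [concat([pieces[src][dst] for src in range(n)], concat_axis)
            for dst in range(n)]

# 6x6 tensor sharded [[0], []] on a size-3 grid: device d holds rows 2d..2d+1.
t = [[10 * r + c for c in range(1, 7)] for r in range(1, 7)]
shards = [t[0:2], t[2:4], t[4:6]]
# Model of: all_to_all split_axis = 1 concat_axis = 0 : tensor<2x6> -> tensor<6x2>
result = all_to_all(shards, split_axis=1, concat_axis=0)
# Device 0 now holds the first two columns of every row, i.e. sharding [[], [0]].
assert result[0] == [[11, 12], [21, 22], [31, 32], [41, 42], [51, 52], [61, 62]]
```

The group order in which pieces are concatenated models the device order induced by the grid axes; with more than one grid axis in `grid_axes`, the outer-to-inner axis ordering would determine it.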
@@ -639,15 +639,15 @@ Basis: [[0]] -> [[]] [[0, 1]] -> [[1]] ``` - All-gather along mesh axis 0. + All-gather along grid axis 0. -* Swap mesh axes order when assigned to the same tensor axis. +* Swap grid axes order when assigned to the same tensor axis. ``` [[0, 1]] -> [[1, 0]] ``` Swap contents on devices with the same linear index. -* Move mesh axis to different tensor dimension. +* Move grid axis to different tensor dimension. ``` [[0], []] -> [[], [0]] ``` @@ -661,9 +661,9 @@ Example decomposition of ``` into ``` -[[0], [1]] -> all-gather along mesh axis 1 -> -[[0], []] -> all-to-all along mesh axis 0 -> -[[], [0]] -> extract slice along mesh axis 1 -> +[[0], [1]] -> all-gather along grid axis 1 -> +[[0], []] -> all-to-all along grid axis 0 -> +[[], [0]] -> extract slice along grid axis 1 -> [[1], [0]] ``` @@ -675,9 +675,9 @@ Example decomposition of ``` into ``` -[[3, 2], [], [0, 1]] -> all-to-all along mesh axis 1 -> -[[3, 2], [1], [0]] -> all-to-all along mesh axis 2 -> -[[3], [1, 2], [0]] -> all-gather along mesh axis 3 -> -[[], [1, 2], [0]] -> all-to-all along mesh axis 0 -> +[[3, 2], [], [0, 1]] -> all-to-all along grid axis 1 -> +[[3, 2], [1], [0]] -> all-to-all along grid axis 2 -> +[[3], [1, 2], [0]] -> all-gather along grid axis 3 -> +[[], [1, 2], [0]] -> all-to-all along grid axis 0 -> [[0], [1, 2], []] ``` diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h similarity index 93% rename from mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h rename to mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h index 243dbf081b999..452d4f6b4ed61 100644 --- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h +++ b/mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h @@ -1,4 +1,4 @@ -//===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===// +//===- Simplifications.h - Shard Simplifications ----------------*- C++ -*-===// 
// // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,10 +6,10 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H -#define MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H +#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H +#define MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H -#include "mlir/Dialect/Mesh/IR/MeshOps.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/EndomorphismSimplification.h" #include "llvm/Support/Casting.h" @@ -22,7 +22,7 @@ namespace mlir { class SymbolTableCollection; -namespace mesh { +namespace shard { // If we have an algebraic op like "+" and a summing all-reduce, // `all_reduce_sum(x) + all_reduce_sum(y)` will be transformed to @@ -112,7 +112,7 @@ void populateSimplificationPatterns( void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection); -} // namespace mesh +} // namespace shard } // namespace mlir -#endif // MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H +#endif // MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h similarity index 65% rename from mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h rename to mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h index f46c0db846088..57d65e687ea35 100644 --- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h @@ -1,4 +1,4 @@ -//===- Transforms.h - Mesh Transforms ---------------------------*- C++ -*-===// +//===- Transforms.h - Shard Transforms --------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
// See https://llvm.org/LICENSE.txt for license information. @@ -6,10 +6,10 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H -#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H +#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMS_H +#define MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMS_H -#include "mlir/Dialect/Mesh/IR/MeshOps.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" @@ -20,7 +20,7 @@ class RewritePatternSet; class SymbolTableCollection; class DialectRegistry; class ImplicitLocOpBuilder; -namespace mesh { +namespace shard { void populateProcessMultiIndexOpLoweringPatterns( RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection); @@ -35,20 +35,20 @@ void populateAllOpLoweringPatterns( void registerAllOpLoweringDialects(DialectRegistry &registry); TypedValue -createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef axes, +createCollectiveProcessGroupSize(GridOp grid, ArrayRef axes, ImplicitLocOpBuilder &builder); -// Get process linear index along the given mesh axes. -TypedValue createProcessLinearIndex(StringRef mesh, - ArrayRef meshAxes, +// Get process linear index along the given grid axes. +TypedValue createProcessLinearIndex(StringRef grid, + ArrayRef gridAxes, ImplicitLocOpBuilder &builder); -// Get process linear index from a multi-index along the given mesh axes . +// Get process linear index from a multi-index along the given grid axes.
TypedValue -createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex, - ArrayRef meshAxes, +createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex, + ArrayRef gridAxes, ImplicitLocOpBuilder &builder); -} // namespace mesh +} // namespace shard } // namespace mlir -#endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H +#endif // MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMS_H diff --git a/mlir/include/mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h b/mlir/include/mlir/Dialect/Tensor/Extensions/ShardingExtensions.h similarity index 88% rename from mlir/include/mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h rename to mlir/include/mlir/Dialect/Tensor/Extensions/ShardingExtensions.h index cfac485b807f2..895e7e5939935 100644 --- a/mlir/include/mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h +++ b/mlir/include/mlir/Dialect/Tensor/Extensions/ShardingExtensions.h @@ -1,4 +1,4 @@ -//===- MeshShardingExtensions.h - -------------------------------*- C++ -*-===// +//===- ShardingExtensions.h - -------------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index c6fcf1a0d510b..856170e9308da 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -60,7 +60,6 @@ #include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" #include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h" #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h" -#include "mlir/Dialect/Mesh/IR/MeshDialect.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" @@ -77,6 +76,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Shard/IR/ShardDialect.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -131,7 +131,7 @@ inline void registerAllDialects(DialectRegistry &registry) { LLVM::LLVMDialect, math::MathDialect, memref::MemRefDialect, - mesh::MeshDialect, + shard::ShardDialect, ml_program::MLProgramDialect, mpi::MPIDialect, nvgpu::NVGPUDialect, diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h index dd8b292a87344..002ff61fb87dd 100644 --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -32,13 +32,13 @@ #include "mlir/Dialect/MLProgram/Transforms/Passes.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/Dialect/Mesh/Transforms/Passes.h" #include "mlir/Dialect/NVGPU/Transforms/Passes.h" #include "mlir/Dialect/OpenACC/Transforms/Passes.h" #include "mlir/Dialect/Quant/Transforms/Passes.h" #include "mlir/Dialect/SCF/Transforms/Passes.h" #include "mlir/Dialect/SPIRV/Transforms/Passes.h" #include
"mlir/Dialect/Shape/Transforms/Passes.h" +#include "mlir/Dialect/Shard/Transforms/Passes.h" #include "mlir/Dialect/SparseTensor/Pipelines/Passes.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Dialect/Tensor/Transforms/Passes.h" @@ -81,7 +81,7 @@ inline void registerAllPasses() { LLVM::registerLLVMPasses(); math::registerMathPasses(); memref::registerMemRefPasses(); - mesh::registerMeshPasses(); + shard::registerShardPasses(); ml_program::registerMLProgramPasses(); quant::registerQuantPasses(); registerSCFPasses(); diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index f84375b6b8d6a..785cb8293810c 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -43,7 +43,7 @@ add_subdirectory(MathToSPIRV) add_subdirectory(MemRefToEmitC) add_subdirectory(MemRefToLLVM) add_subdirectory(MemRefToSPIRV) -add_subdirectory(MeshToMPI) +add_subdirectory(ShardToMPI) add_subdirectory(MPIToLLVM) add_subdirectory(NVGPUToNVVM) add_subdirectory(NVVMToLLVM) diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt similarity index 65% rename from mlir/lib/Conversion/MeshToMPI/CMakeLists.txt rename to mlir/lib/Conversion/ShardToMPI/CMakeLists.txt index 15560aa61e145..564f36fd20abb 100644 --- a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt +++ b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt @@ -1,8 +1,8 @@ -add_mlir_conversion_library(MLIRMeshToMPI - MeshToMPI.cpp +add_mlir_conversion_library(MLIRShardToMPI + ShardToMPI.cpp ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MeshToMPI + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ShardToMPI DEPENDS MLIRConversionPassIncGen @@ -17,7 +17,7 @@ add_mlir_conversion_library(MLIRMeshToMPI MLIRLinalgTransforms MLIRMemRefDialect MLIRPass - MLIRMeshDialect + MLIRShardDialect MLIRMPIDialect MLIRTransforms ) diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp 
b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp similarity index 92% rename from mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp rename to mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp index 63b1fdabaf407..8525543760d99 100644 --- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp +++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp @@ -1,4 +1,4 @@ -//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===// +//===- ShardToMPI.cpp - Shard to MPI dialect conversion -----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,11 +6,11 @@ // //===----------------------------------------------------------------------===// // -// This file implements a translation of Mesh communication ops tp MPI ops. +// This file implements a translation of Shard communication ops to MPI ops. // //===----------------------------------------------------------------------===// -#include "mlir/Conversion/MeshToMPI/MeshToMPI.h" +#include "mlir/Conversion/ShardToMPI/ShardToMPI.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -20,11 +20,11 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MPI/IR/MPI.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Mesh/IR/MeshDialect.h" -#include "mlir/Dialect/Mesh/IR/MeshOps.h" -#include "mlir/Dialect/Mesh/Transforms/Simplifications.h" -#include "mlir/Dialect/Mesh/Transforms/Transforms.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Shard/IR/ShardDialect.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" +#include "mlir/Dialect/Shard/Transforms/Simplifications.h" +#include "mlir/Dialect/Shard/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Builders.h" @@ -35,16 +35,16 @@ #include "mlir/Transforms/DialectConversion.h" #include 
"mlir/Transforms/GreedyPatternRewriteDriver.h" -#define DEBUG_TYPE "mesh-to-mpi" +#define DEBUG_TYPE "shard-to-mpi" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") namespace mlir { -#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS +#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; -using namespace mesh; +using namespace shard; namespace { /// Converts a vector of OpFoldResults (ints) into vector of Values of the @@ -188,18 +188,18 @@ struct ConvertShardingOp : public OpConversionPattern { // maxSplitSize+1}. Store the offsets in the tensor but set trailing // elements for smaller split-groups to -1. Computing the max size of the // split groups needs using collectiveProcessGroupSize (which needs the - // MeshOp) + // GridOp) Value resOffsets; if (adaptor.getStaticShardedDimsOffsets().empty()) { resOffsets = tensor::EmptyOp::create(rewriter, loc, std::array{0, 0}, i64); } else { SymbolTableCollection symbolTableCollection; - auto meshOp = getMesh(op, symbolTableCollection); + auto gridOp = getGrid(op, symbolTableCollection); int64_t maxSplitSize = 0; for (auto axes : splitAxes) { int64_t splitSize = - collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape()); + collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape()); assert(splitSize != ShapedType::kDynamic); maxSplitSize = std::max(maxSplitSize, splitSize); } @@ -218,7 +218,7 @@ struct ConvertShardingOp : public OpConversionPattern { int64_t curr = 0; for (auto [i, axes] : llvm::enumerate(splitAxes)) { int64_t splitSize = - collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape()); + collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape()); assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize); ++splitSize; // add one for the total size ArrayRef values(&offsets[curr], splitSize); @@ -264,20 +264,20 @@ struct ConvertProcessMultiIndexOp SymbolTableCollection symbolTableCollection; Location loc = 
op.getLoc(); - auto meshOp = getMesh(op, symbolTableCollection); - // For now we only support static mesh shapes - if (ShapedType::isDynamicShape(meshOp.getShape())) + auto gridOp = getGrid(op, symbolTableCollection); + // For now we only support static grid shapes + if (ShapedType::isDynamicShape(gridOp.getShape())) return failure(); SmallVector dims; llvm::transform( - meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { + gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) { return arith::ConstantIndexOp::create(rewriter, loc, i).getResult(); }); - Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), meshOp); + Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), gridOp); auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims); - // optionally extract subset of mesh axes + // optionally extract subset of grid axes auto axes = adaptor.getAxes(); if (!axes.empty()) { SmallVector subIndex; @@ -338,12 +338,12 @@ struct ConvertNeighborsLinearIndicesOp Location loc = op.getLoc(); SymbolTableCollection symbolTableCollection; - auto meshOp = getMesh(op, symbolTableCollection); + auto gridOp = getGrid(op, symbolTableCollection); auto mIdx = adaptor.getDevice(); auto orgIdx = mIdx[axes[0]]; SmallVector dims; llvm::transform( - meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { + gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) { return arith::ConstantIndexOp::create(rewriter, loc, i).getResult(); }); Value dimSz = dims[axes[0]]; @@ -394,14 +394,14 @@ struct ConvertShardShapeOp : public OpConversionPattern { auto sharding = op.getSharding().getDefiningOp(); if (!sharding) { return op->emitError() - << "Expected SharingOp as defining op for sharding" + << "Expected ShardingOp as defining op for sharding" << " but found " << adaptor.getSharding()[0].getDefiningOp(); } // Compute the sharded shape by applying the sharding to the input shape. 
// If shardedDimsOffsets is not defined in the sharding, the shard shape is // computed by dividing the dimension size by the number of shards in that - // dimension (which is given by the size of the mesh axes provided in + // dimension (which is given by the size of the grid axes provided in // split-axes). Odd elements get distributed to trailing shards. If a // shardedDimsOffsets is provided, the shard shape is computed by // subtracting the offset of the current shard from the offset of the next @@ -431,11 +431,11 @@ struct ConvertShardShapeOp : public OpConversionPattern { SmallVector multiIdx = getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index); - // Get the MeshOp, the mesh shape is needed to compute the sharded shape. + // Get the GridOp, the grid shape is needed to compute the sharded shape. SymbolTableCollection symbolTableCollection; - auto meshOp = getMesh(sharding, symbolTableCollection); - // For now we only support static mesh shapes - if (ShapedType::isDynamicShape(meshOp.getShape())) + auto gridOp = getGrid(sharding, symbolTableCollection); + // For now we only support static grid shapes + if (ShapedType::isDynamicShape(gridOp.getShape())) return failure(); auto splitAxes = sharding.getSplitAxes().getAxes(); @@ -455,7 +455,7 @@ struct ConvertShardShapeOp : public OpConversionPattern { tmp); } - // With static mesh shape the sizes of the split axes are known. + // With static grid shape the sizes of the split axes are known. // Hence the start/pos for each split axes in shardDimsOffsets can be // computed statically. int64_t pos = 0; @@ -475,10 +475,10 @@ struct ConvertShardShapeOp : public OpConversionPattern { // Create a value from the static position in shardDimsOffsets. Value posVal = arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(pos)); - // Get the index of the local shard in the mesh axis. + // Get the index of the local shard in the grid axis. 
Value idx = multiIdx[axes[0]]; auto numShards = - collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape()); + collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape()); if (shardedDimsOffs) { // If sharded dims offsets are provided, use them to compute the // sharded shape. @@ -556,13 +556,13 @@ struct ConvertAllReduceOp : public OpConversionPattern { matchAndRewrite(AllReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SymbolTableCollection symbolTableCollection; - auto mesh = adaptor.getMesh(); - mlir::mesh::MeshOp meshOp = getMesh(op, symbolTableCollection); - if (!meshOp) - return op->emitError() << "No mesh found for AllReduceOp"; - if (ShapedType::isDynamicShape(meshOp.getShape())) + auto grid = adaptor.getGrid(); + mlir::shard::GridOp gridOp = getGrid(op, symbolTableCollection); + if (!gridOp) + return op->emitError() << "No grid found for AllReduceOp"; + if (ShapedType::isDynamicShape(gridOp.getShape())) return op->emitError() - << "Dynamic mesh shape not supported in AllReduceOp"; + << "Dynamic grid shape not supported in AllReduceOp"; ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter); Value input = adaptor.getInput(); @@ -592,27 +592,27 @@ struct ConvertAllReduceOp : public OpConversionPattern { linalg::CopyOp::create(iBuilder, input, buffer); // Get an MPI_Comm_split for the AllReduce operation. - // The color is the linear index of the process in the mesh along the - // non-reduced axes. The key is the linear index of the process in the mesh + // The color is the linear index of the process in the grid along the + // non-reduced axes. The key is the linear index of the process in the grid // along the reduced axes. 
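The color/key scheme in the comment above determines the communicator split: processes that differ only along the reduced axes must land in the same sub-communicator (same color), and their rank within it (the key) comes from their coordinates along those reduced axes. A short hypothetical Python sketch of that arithmetic (the actual pass builds `arith` ops via `createProcessLinearIndex`; all names below are invented):

```python
# Sketch of the MPI_Comm_split color/key computation for a process grid
# of static shape, following the description in the pass comment.

def linear_index(multi_index, shape):
    """Row-major linearization, outer grid axes first."""
    idx = 0
    for i, n in zip(multi_index, shape):
        idx = idx * n + i
    return idx

def comm_split_color_key(my_index, grid_shape, reduced_axes):
    color_idx = list(my_index)       # reduced axes zeroed -> group id
    key_idx = [0] * len(grid_shape)  # only reduced axes kept -> in-group rank
    for a in reduced_axes:
        key_idx[a] = my_index[a]
        color_idx[a] = 0
    return (linear_index(color_idx, grid_shape),
            linear_index(key_idx, grid_shape))

# 2x3 grid, all-reduce along grid axis 1: each row forms one group.
assert [comm_split_color_key((1, j), (2, 3), [1]) for j in range(3)] \
       == [(3, 0), (3, 1), (3, 2)]
```

All three processes in row 1 share color 3 and are ordered by their axis-1 coordinate, which is exactly what `MPI_Comm_split` needs to form the reduction group.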
- SmallVector indexResultTypes(meshOp.getShape().size(), + SmallVector indexResultTypes(gridOp.getShape().size(), iBuilder.getIndexType()); SmallVector myMultiIndex = - ProcessMultiIndexOp::create(iBuilder, indexResultTypes, mesh) + ProcessMultiIndexOp::create(iBuilder, indexResultTypes, grid) .getResult(); Value zero = arith::ConstantIndexOp::create(iBuilder, 0); SmallVector multiKey(myMultiIndex.size(), zero); - auto redAxes = adaptor.getMeshAxes(); + auto redAxes = adaptor.getGridAxes(); for (auto axis : redAxes) { multiKey[axis] = myMultiIndex[axis]; myMultiIndex[axis] = zero; } Value color = - createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder); + createProcessLinearIndex(grid, myMultiIndex, redAxes, iBuilder); color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color); - Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder); + Value key = createProcessLinearIndex(grid, multiKey, redAxes, iBuilder); key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key); // Finally split the communicator @@ -698,8 +698,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { } auto rank = cast(array.getType()).getRank(); auto opSplitAxes = adaptor.getSplitAxes().getAxes(); - auto mesh = adaptor.getMesh(); - auto meshOp = getMesh(op, symbolTableCollection); + auto grid = adaptor.getGrid(); + auto gridOp = getGrid(op, symbolTableCollection); // subviews need Index values for (auto &sz : haloSizes) { if (auto value = dyn_cast(sz)) @@ -745,10 +745,10 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0 auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr); - SmallVector indexResultTypes(meshOp.getShape().size(), + SmallVector indexResultTypes(gridOp.getShape().size(), rewriter.getIndexType()); auto myMultiIndex = - ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, mesh) + ProcessMultiIndexOp::create(rewriter, loc, 
indexResultTypes, grid) .getResult(); // traverse all split axes from high to low dim for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) { @@ -759,7 +759,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { // Get the linearized ids of the neighbors (down and up) for the // given split auto tmp = rewriter - .create(loc, mesh, myMultiIndex, + .create(loc, grid, myMultiIndex, splitAxes) .getResults(); // MPI operates on i32... @@ -791,7 +791,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1] : haloSizes[currHaloDim * 2]; // Check if we need to send and/or receive - // Processes on the mesh borders have only one neighbor + // Processes on the grid borders have only one neighbor auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1]; auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0]; auto hasFrom = arith::CmpIOp::create( @@ -869,8 +869,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { } }; -struct ConvertMeshToMPIPass - : public impl::ConvertMeshToMPIPassBase { +struct ConvertShardToMPIPass + : public impl::ConvertShardToMPIPassBase { using Base::Base; /// Run the dialect converter on the module. @@ -879,12 +879,12 @@ struct ConvertMeshToMPIPass RewritePatternSet patterns(ctxt); ConversionTarget target(getContext()); - // Define a type converter to convert mesh::ShardingType, + // Define a type converter to convert shard::ShardingType, // mostly for use in return operations. TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); - // convert mesh::ShardingType to a tuple of RankedTensorTypes + // convert shard::ShardingType to a tuple of RankedTensorTypes typeConverter.addConversion( [](ShardingType type, SmallVectorImpl &results) -> std::optional { @@ -920,10 +920,10 @@ struct ConvertMeshToMPIPass return results; }); - // No mesh dialect should left after conversion... 
- target.addIllegalDialect(); - // ...except the global MeshOp. MeshShapeOp which will get folded later. - target.addLegalOp(); + // No shard dialect should be left after conversion... + target.addIllegalDialect(); + // ...except the global GridOp and GridShapeOp, which will get folded later. + target.addLegalOp(); // Allow all the stuff that our patterns will convert to target.addLegalDialect< BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect, @@ -951,7 +951,7 @@ struct ConvertShardToMPIPass // Folding patterns cannot be mixed with conversion patterns -> extra pass. patterns.clear(); SymbolTableCollection symbolTableCollection; - mlir::mesh::populateFoldingPatterns(patterns, symbolTableCollection); + mlir::shard::populateFoldingPatterns(patterns, symbolTableCollection); (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt index f96bda603baa6..93682a9375dac 100644 --- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt @@ -27,7 +27,7 @@ add_mlir_dialect_library(MLIRArithTransforms MLIRInferIntRangeInterface MLIRIR MLIRMemRefDialect - MLIRMeshDialect + MLIRShardDialect MLIRPass MLIRShardingInterface MLIRTensorDialect diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp index dd6efe6d6bc31..3e34246f66f2c 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp @@ -6,22 +6,22 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h" -#include
"mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h" #include "mlir/IR/DialectRegistry.h" using namespace mlir; using namespace mlir::arith; -using namespace mlir::mesh; +using namespace mlir::shard; namespace { // Sharding of arith.constant // RankedTensor constants can be sharded like any other tensor. // %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> -// %sharding = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding +// %sharding = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding // Scalar constants are always replicated and need no sharding annotation. struct ConstantShardingInterface @@ -48,8 +48,8 @@ struct ConstantShardingInterface // Otherwise mirror result sharding if it is a tensor constant. // Otherwise return replication option. FailureOr - getShardingOption(Operation *op, ArrayRef operandShardings, - ArrayRef resultShardings) const { + getShardingOption(Operation *op, ArrayRef operandShardings, + ArrayRef resultShardings) const { assert(resultShardings.size() == 1 && "Expecting exactly one result sharding for arith.constant"); auto resultSharding = resultShardings[0]; @@ -61,17 +61,17 @@ struct ConstantShardingInterface for (auto [i, axes] : llvm::enumerate(resultSharding.getSplitAxes())) { axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end()); } - return ShardingOption(axesArray, resultSharding.getMeshAttr()); + return ShardingOption(axesArray, resultSharding.getGridAttr()); } - return ShardingOption({}, resultSharding.getMeshAttr()); + return ShardingOption({}, resultSharding.getGridAttr()); } - LogicalResult spmdize(Operation *op, ArrayRef spmdizedOperands, - ArrayRef operandShardings, - ArrayRef resultShardings, - IRMapping &spmdizationMap, - SymbolTableCollection &symbolTable, - OpBuilder &builder) const { + LogicalResult partition(Operation *op, ArrayRef partitionedOperands, + ArrayRef operandShardings, + ArrayRef resultShardings, +
IRMapping &partitionMap, + SymbolTableCollection &symbolTable, + OpBuilder &builder) const { auto cOp = cast(op); if (auto value = dyn_cast(cOp.getValue())) { if (!value.isSplat() || !resultShardings[0]) { @@ -80,15 +80,15 @@ struct ConstantShardingInterface } auto sharding = resultShardings[0]; auto newType = cast(shardType( - cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable), + cOp.getType(), getGrid(op, sharding.getGridAttr(), symbolTable), sharding)); auto newValue = value.resizeSplat(newType); auto newOp = ConstantOp::create(builder, op->getLoc(), newType, newValue); - spmdizationMap.map(op->getResult(0), newOp.getResult()); - spmdizationMap.map(op, newOp.getOperation()); + partitionMap.map(op->getResult(0), newOp.getResult()); + partitionMap.map(op, newOp.getOperation()); } else { // `clone` will populate the mapping of old to new results. - (void)builder.clone(*op, spmdizationMap); + (void)builder.clone(*op, partitionMap); } return success(); } diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt index 3cc52ebc0a8d9..053ee95e92053 100644 --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -19,7 +19,7 @@ add_subdirectory(Linalg) add_subdirectory(LLVMIR) add_subdirectory(Math) add_subdirectory(MemRef) -add_subdirectory(Mesh) +add_subdirectory(Shard) add_subdirectory(MLProgram) add_subdirectory(MPI) add_subdirectory(NVGPU) diff --git a/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp index eb6b59bb00f1b..1b18ef2dd04a7 100644 --- a/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp +++ b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp @@ -8,7 +8,7 @@ #include "mlir/Dialect/Func/Extensions/AllExtensions.h" #include "mlir/Dialect/Func/Extensions/InlinerExtension.h" -#include "mlir/Dialect/Func/Extensions/MeshShardingExtensions.h" +#include "mlir/Dialect/Func/Extensions/ShardingExtensions.h" using namespace mlir; diff --git 
a/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt index 47363f48f95cc..87ef51e63f1da 100644 --- a/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt +++ b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt @@ -1,7 +1,7 @@ set(LLVM_OPTIONAL_SOURCES AllExtensions.cpp InlinerExtension.cpp - MeshShardingExtensions.cpp + ShardingExtensions.cpp ) add_mlir_extension_library(MLIRFuncInlinerExtension @@ -17,8 +17,8 @@ add_mlir_extension_library(MLIRFuncInlinerExtension MLIRFuncDialect ) -add_mlir_extension_library(MLIRFuncMeshShardingExtensions - MeshShardingExtensions.cpp +add_mlir_extension_library(MLIRFuncShardingExtensions + ShardingExtensions.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Extensions @@ -38,5 +38,5 @@ add_mlir_extension_library(MLIRFuncAllExtensions LINK_LIBS PUBLIC MLIRFuncInlinerExtension - MLIRFuncMeshShardingExtensions + MLIRFuncShardingExtensions ) diff --git a/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Func/Extensions/ShardingExtensions.cpp similarity index 68% rename from mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp rename to mlir/lib/Dialect/Func/Extensions/ShardingExtensions.cpp index da508cc95bfe1..dfd1348c24441 100644 --- a/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp +++ b/mlir/lib/Dialect/Func/Extensions/ShardingExtensions.cpp @@ -1,4 +1,4 @@ -//===- MeshShardingExtensions.cpp - ---------------------------------------===// +//===- ShardingExtensions.cpp - ---------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,9 +6,9 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Func/Extensions/MeshShardingExtensions.h" +#include "mlir/Dialect/Func/Extensions/ShardingExtensions.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h" #include "mlir/IR/MLIRContext.h" namespace mlir::func { @@ -16,7 +16,7 @@ namespace mlir::func { void registerShardingInterfaceExternalModels(DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, FuncDialect *dialect) { ReturnOp::attachInterface< - mesh::IndependentParallelIteratorDomainShardingInterface>( + shard::IndependentParallelIteratorDomainShardingInterface>( *ctx); }); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp index b6e168e95ee86..7f6ecab2d90f5 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp @@ -15,7 +15,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/Interfaces/SubsetOpInterface.h" @@ -119,8 +119,8 @@ void mlir::linalg::LinalgDialect::initialize() { addInterfaces(); - declarePromisedInterface(); - declarePromisedInterfaces(); + declarePromisedInterfaces(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp index 281d9f2204486..ba94ad7906ab7 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp @@ -10,14 +10,14 @@ #include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h" #include 
"mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h" #include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h" #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" void mlir::linalg::registerAllDialectInterfaceImplementations( DialectRegistry ®istry) { registerBufferizableOpInterfaceExternalModels(registry); - registerMeshShardingInterfaceExternalModels(registry); + registerShardingInterfaceExternalModels(registry); registerSubsetOpInterfaceExternalModels(registry); registerTilingInterfaceExternalModels(registry); registerValueBoundsOpInterfaceExternalModels(registry); diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index 69e6fdabf9a58..70f846e5bbd20 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -24,7 +24,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms Interchange.cpp Loops.cpp TransposeMatmul.cpp - MeshShardingInterfaceImpl.cpp + ShardingInterfaceImpl.cpp NamedOpConversions.cpp BlockPackMatmul.cpp PackAndUnpackPatterns.cpp @@ -68,7 +68,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms MLIRIR MLIRMemRefDialect MLIRMemRefTransforms - MLIRMeshTransforms + MLIRShardTransforms MLIRLinalgDialect MLIRLinalgUtils MLIRSCFDialect diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp similarity index 65% rename from mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp rename to mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp index 83d12e314a36f..f277c5f5be5fc 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp @@ -1,4 +1,4 @@ -//===- 
MeshShardingInterfaceImpl.cpp --------------------------------------===// +//===- ShardingInterfaceImpl.cpp --------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,18 +6,18 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" -#include "mlir/Dialect/Mesh/IR/MeshOps.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" -#include "mlir/Dialect/Mesh/Transforms/Transforms.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Shard/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/AffineExpr.h" @@ -36,13 +36,13 @@ namespace mlir::linalg { -using MeshAxis = mesh::MeshAxis; -using ReductionKind = mesh::ReductionKind; -using MeshSharding = mesh::MeshSharding; -using ShardingArray = mesh::ShardingArray; -using MeshOp = mesh::MeshOp; +using GridAxis = shard::GridAxis; +using ReductionKind = shard::ReductionKind; +using Sharding = shard::Sharding; +using ShardingArray = shard::ShardingArray; +using GridOp = shard::GridOp; -// Returns the corresponding mesh reduction kind for the given arith op. +// Returns the corresponding grid reduction kind for the given arith op. 
static ReductionKind getReductionKind(Operation *op) { return llvm::TypeSwitch(op) // Floating-point operations. @@ -97,18 +97,18 @@ static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) { return getReductionKind(reductionOp.value()); } -static MeshOp getMesh(Operation *op, ArrayRef operandShardings, - ArrayRef resultShardings, +static GridOp getGrid(Operation *op, ArrayRef operandShardings, + ArrayRef resultShardings, SymbolTableCollection &symbolTable) { - for (const MeshSharding &sharding : operandShardings) { + for (const Sharding &sharding : operandShardings) { if (sharding) { - return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable); + return shard::getGrid(op, sharding.getGridAttr(), symbolTable); } } - for (const MeshSharding &sharding : resultShardings) { + for (const Sharding &sharding : resultShardings) { if (sharding) { - return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable); + return shard::getGrid(op, sharding.getGridAttr(), symbolTable); } } @@ -117,29 +117,29 @@ static MeshOp getMesh(Operation *op, ArrayRef operandShardings, } // Choose the operand based on the current process index along the reduction -// mesh axes. +// grid axes. // We need to use the initial value only once to avoid including it in the // reduction multiple times. // In each process group only the leading process with linear index 0 would use // the original operand. // The other processes would use the reduction operation neutral tensor. 
static Value createDestinationPassingStyleInitOperand( - LinalgOp op, int operandNumber, Value spmdizedOperand, - ArrayRef reductionMeshAxes, MeshOp meshOp, + LinalgOp op, int operandNumber, Value partitionedOperand, + ArrayRef reductionGridAxes, GridOp gridOp, ImplicitLocOpBuilder &builder) { - Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex( - meshOp.getSymName(), reductionMeshAxes, builder); + Value processLinearIndexInReductionGroup = shard::createProcessLinearIndex( + gridOp.getSymName(), reductionGridAxes, builder); Value zero = arith::ConstantIndexOp::create(builder, 0); Value isLeadProcess = arith::CmpIOp::create( builder, builder.getI1Type(), arith::CmpIPredicate::eq, processLinearIndexInReductionGroup, zero); - scf::IfOp ifOp = scf::IfOp::create(builder, spmdizedOperand.getType(), + scf::IfOp ifOp = scf::IfOp::create(builder, partitionedOperand.getType(), isLeadProcess, true, true); // Then block. { OpBuilder::InsertionGuard insertionGuard(builder); builder.setInsertionPointToEnd(&ifOp.getThenRegion().front()); - scf::YieldOp::create(builder, spmdizedOperand); + scf::YieldOp::create(builder, partitionedOperand); } // Else block. @@ -147,7 +147,7 @@ static Value createDestinationPassingStyleInitOperand( OpBuilder::InsertionGuard insertionGuard(builder); builder.setInsertionPointToEnd(&ifOp.getElseRegion().front()); SmallVector shape = - tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand); + tensor::getMixedSizes(builder, builder.getLoc(), partitionedOperand); SmallVector combinerOps; matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps); @@ -167,73 +167,72 @@ static Value createDestinationPassingStyleInitOperand( return ifOp.getResult(0); } -// Create the DPS init operands for the spmdized Linalg op. -// Return all the new spmdized operands. +// Create the DPS init operands for the partitioned Linalg op. +// Return all the new partitioned operands. 
static SmallVector createDestinationPassingStyleInitOperands( - LinalgOp op, MeshOp meshOp, ArrayRef spmdizedOperands, - ArrayRef reductionMeshAxes, IRMapping &spmdizationMap, + LinalgOp op, GridOp gridOp, ArrayRef partitionedOperands, + ArrayRef reductionGridAxes, IRMapping &partitionMap, ImplicitLocOpBuilder &builder) { // TODO: add support for multiple destination passing style initial value // operands. assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported."); - SmallVector newOperands = llvm::to_vector(spmdizedOperands); + SmallVector newOperands = llvm::to_vector(partitionedOperands); auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber(); - Value spmdizedInitOperand = - spmdizationMap.lookup(op->getOperands()[operandIdx]); + Value partitionedInitOperand = + partitionMap.lookup(op->getOperands()[operandIdx]); newOperands[operandIdx] = createDestinationPassingStyleInitOperand( - op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder); + op, 0, partitionedInitOperand, reductionGridAxes, gridOp, builder); return newOperands; } static void createAllReduceForResultsWithoutPartialShardings( - LinalgOp unshardedOp, ArrayRef opReductionMeshAxes, - ArrayRef resultShardings, IRMapping &spmdizationMap, + LinalgOp unshardedOp, ArrayRef opReductionGridAxes, + ArrayRef resultShardings, IRMapping &partitionMap, ImplicitLocOpBuilder &builder) { ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp); for (auto [unshardedLinalgOpResult, resultSharding] : llvm::zip_equal(unshardedOp->getResults(), resultShardings)) { - Value spmdizedLinalgOpResult = - spmdizationMap.lookup(unshardedLinalgOpResult); - Value reducedValue = mesh::AllReduceOp::create( - builder, spmdizedLinalgOpResult, resultSharding.getMesh(), - opReductionMeshAxes, reductionKind); - spmdizationMap.map(unshardedLinalgOpResult, reducedValue); + Value partitionedLinalgOpResult = + partitionMap.lookup(unshardedLinalgOpResult); + Value reducedValue = 
shard::AllReduceOp::create( + builder, partitionedLinalgOpResult, resultSharding.getGrid(), + opReductionGridAxes, reductionKind); + partitionMap.map(unshardedLinalgOpResult, reducedValue); } } -static void spmdizeLinalgOpWithShardedReduction( - LinalgOp op, ArrayRef spmdizedOperands, - ArrayRef operandShardings, - ArrayRef resultShardings, +static void partitionLinalgOpWithShardedReduction( + LinalgOp op, ArrayRef partitionedOperands, + ArrayRef operandShardings, ArrayRef resultShardings, ArrayRef loopIteratorTypes, - ArrayRef> meshAxisAssignmentForLoopIterators, - IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, + ArrayRef> gridAxisAssignmentForLoopIterators, + IRMapping &partitionMap, SymbolTableCollection &symbolTable, ImplicitLocOpBuilder &builder) { - MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable); - SmallVector reductionMeshAxes = mesh::getReductionMeshAxes( - loopIteratorTypes, meshAxisAssignmentForLoopIterators); - SmallVector spmdizedLinalgOpOperands = - createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands, - reductionMeshAxes, - spmdizationMap, builder); - // We must not change the operand mappings of the original spmdizationMap as - // they are the mappings for the whole spmdization blob and may be used by + GridOp grid = getGrid(op, operandShardings, resultShardings, symbolTable); + SmallVector reductionGridAxes = shard::getReductionGridAxes( + loopIteratorTypes, gridAxisAssignmentForLoopIterators); + SmallVector partitionedLinalgOpOperands = + createDestinationPassingStyleInitOperands(op, grid, partitionedOperands, + reductionGridAxes, partitionMap, + builder); + // We must not change the operand mappings of the original partitionMap as + // they are the mappings for the whole partition blob and may be used by // others. 
- IRMapping internalSpmdizationMap; - for (auto [unshardedOperand, spmdizedOperand] : - llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) { - internalSpmdizationMap.map(unshardedOperand, spmdizedOperand); + IRMapping internalPartitionMap; + for (auto [unshardedOperand, partitionedOperand] : + llvm::zip_equal(op->getOperands(), partitionedLinalgOpOperands)) { + internalPartitionMap.map(unshardedOperand, partitionedOperand); } - spmdizeTriviallyShardableOperation( - *op, spmdizedLinalgOpOperands, operandShardings, resultShardings, - internalSpmdizationMap, symbolTable, builder); + partitionTriviallyShardableOperation( + *op, partitionedLinalgOpOperands, operandShardings, resultShardings, + internalPartitionMap, symbolTable, builder); for (Value result : op->getResults()) { - spmdizationMap.map(result, internalSpmdizationMap.lookup(result)); + partitionMap.map(result, internalPartitionMap.lookup(result)); } // Handle partial shardings. createAllReduceForResultsWithoutPartialShardings( - op, reductionMeshAxes, resultShardings, spmdizationMap, builder); + op, reductionGridAxes, resultShardings, partitionMap, builder); } namespace { @@ -243,7 +242,7 @@ namespace { // permutations. 
template struct StructuredOpShardingInterface - : public mesh::ShardingInterface::ExternalModel< + : public shard::ShardingInterface::ExternalModel< StructuredOpShardingInterface, Op> { SmallVector getLoopIteratorTypes(Operation *op) const { return llvm::cast(op).getIteratorTypesArray(); @@ -272,16 +271,16 @@ struct StructuredOpShardingInterface [](unsigned count, utils::IteratorType iter) { return count + (iter == utils::IteratorType::reduction); }); - mesh::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp); + shard::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp); return SmallVector(reductionItersCount, reductionKind); } - LogicalResult spmdize(Operation *op, ArrayRef spmdizedOperands, - ArrayRef operandShardings, - ArrayRef resultShardings, - IRMapping &spmdizationMap, - SymbolTableCollection &symbolTable, - OpBuilder &builder) const { + LogicalResult partition(Operation *op, ArrayRef partitionedOperands, + ArrayRef operandShardings, + ArrayRef resultShardings, + IRMapping &partitionMap, + SymbolTableCollection &symbolTable, + OpBuilder &builder) const { LinalgOp linalgOp = llvm::cast(op); SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); @@ -297,20 +296,20 @@ struct StructuredOpShardingInterface SmallVector loopIteratorTypes = linalgOp.getIteratorTypesArray(); - ShardingArray meshAxisAssignmentForLoopIterators = - getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings, + ShardingArray gridAxisAssignmentForLoopIterators = + getGridAxisAssignmentForLoopIterators(operandShardings, resultShardings, loopIteratorTypes, indexingMaps); - if (mesh::isAtLeastOneReductionIteratorSharded( - loopIteratorTypes, meshAxisAssignmentForLoopIterators)) { + if (shard::isAtLeastOneReductionIteratorSharded( + loopIteratorTypes, gridAxisAssignmentForLoopIterators)) { ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder); - spmdizeLinalgOpWithShardedReduction( - linalgOp, spmdizedOperands, operandShardings, 
resultShardings, - loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap, + partitionLinalgOpWithShardedReduction( + linalgOp, partitionedOperands, operandShardings, resultShardings, + loopIteratorTypes, gridAxisAssignmentForLoopIterators, partitionMap, symbolTable, implicitLocBuilder); } else { - spmdizeTriviallyShardableOperation(*op, spmdizedOperands, - operandShardings, resultShardings, - spmdizationMap, symbolTable, builder); + partitionTriviallyShardableOperation(*op, partitionedOperands, + operandShardings, resultShardings, + partitionMap, symbolTable, builder); } return success(); @@ -330,7 +329,7 @@ static void registerAll(MLIRContext *ctx) { (registerOne(ctx), ...); } -void registerMeshShardingInterfaceExternalModels(DialectRegistry ®istry) { +void registerShardingInterfaceExternalModels(DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) { DialectRegistry registry; registry.insert #include -#define DEBUG_TYPE "mesh-ops" +#define DEBUG_TYPE "shard-ops" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") using namespace mlir; -using namespace mlir::mesh; +using namespace mlir::shard; -#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc" +#include "mlir/Dialect/Shard/IR/ShardDialect.cpp.inc" namespace { @@ -74,11 +74,10 @@ static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) { return lhs.value() * rhs.value(); } -SmallVector mlir::mesh::getMixedAsValues(OpBuilder b, - const Location &loc, - llvm::ArrayRef statics, - ValueRange dynamics, - Type type) { +SmallVector +mlir::shard::getMixedAsValues(OpBuilder b, const Location &loc, + llvm::ArrayRef statics, + ValueRange dynamics, Type type) { SmallVector values; auto dyn = dynamics.begin(); Type i64 = b.getI64Type(); @@ -102,7 +101,7 @@ SmallVector mlir::mesh::getMixedAsValues(OpBuilder b, //===----------------------------------------------------------------------===// namespace { -struct MeshInlinerInterface : public 
DialectInlinerInterface { +struct ShardInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; // Currently no restrictions are encoded for inlining. bool isLegalToInline(Operation *, Operation *, bool) const final { @@ -118,44 +117,45 @@ struct MeshInlinerInterface : public DialectInlinerInterface { } // namespace //===----------------------------------------------------------------------===// -// Mesh dialect +// Shard dialect //===----------------------------------------------------------------------===// -void MeshDialect::initialize() { +void ShardDialect::initialize() { addOperations< #define GET_OP_LIST -#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc" +#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc" >(); addAttributes< #define GET_ATTRDEF_LIST -#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc" +#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc" >(); addTypes< #define GET_TYPEDEF_LIST -#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc" +#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc" >(); - addInterface(); + addInterface(); } -Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value, - Type type, Location loc) { +Operation *ShardDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { return arith::ConstantOp::materialize(builder, value, type, loc); } //===----------------------------------------------------------------------===// -// Mesh utilities +// Shard utilities //===----------------------------------------------------------------------===// -static FailureOr getMeshAndVerify(Operation *op, - FlatSymbolRefAttr meshSymbol, +static FailureOr getGridAndVerify(Operation *op, + FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTable) { - mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable); - if (!mesh) { - return op->emitError() << "Undefined required mesh symbol \"" - << meshSymbol.getValue() << "\"."; +
shard::GridOp grid = getGridOrNull(op, gridSymbol, symbolTable); + if (!grid) { + return op->emitError() << "Undefined required grid symbol \"" + << gridSymbol.getValue() << "\"."; } - return mesh; + return grid; } template @@ -175,20 +175,20 @@ bool isUnique(It begin, It end) { return true; } -static LogicalResult verifyMeshAxes(Location loc, ArrayRef axes, - MeshOp mesh) { - SmallVector sorted = llvm::to_vector(axes); +static LogicalResult verifyGridAxes(Location loc, ArrayRef axes, + GridOp grid) { + SmallVector sorted = llvm::to_vector(axes); llvm::sort(sorted); if (!isUnique(sorted.begin(), sorted.end())) { - return emitError(loc) << "Mesh axes contains duplicate elements."; + return emitError(loc) << "Grid axes contains duplicate elements."; } - MeshAxis rank = mesh.getRank(); + GridAxis rank = grid.getRank(); for (auto axis : axes) { if (axis >= rank || axis < 0) { return emitError(loc) - << "0-based mesh axis index " << axis - << " is out of bounds. The referenced mesh \"" << mesh.getSymName() + << "0-based grid axis index " << axis + << " is out of bounds. 
The referenced grid \"" << grid.getSymName() << "\" is of rank " << rank << "."; } } @@ -197,22 +197,22 @@ static LogicalResult verifyMeshAxes(Location loc, ArrayRef axes, } template -static FailureOr -getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) { - auto mesh = - ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable); - if (failed(mesh)) { +static FailureOr +getGridAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) { + auto grid = + ::getGridAndVerify(op.getOperation(), op.getGridAttr(), symbolTable); + if (failed(grid)) { return failure(); } - if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) { + if (failed(verifyGridAxes(op.getLoc(), op.getGridAxes(), grid.value()))) { return failure(); } - return mesh; + return grid; } -template -static void shardShape(const InShape &inShape, const MeshShape &meshShape, +static void shardShape(const InShape &inShape, const GridShape &gridShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef shardedDimsOffsets = {}, ArrayRef haloSizes = {}) { @@ -226,7 +226,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape, llvm::adl_begin(outShape)); if (!shardedDimsOffsets.empty()) { - auto isDynShape = ShapedType::isDynamicShape(meshShape); + auto isDynShape = ShapedType::isDynamicShape(gridShape); uint64_t pos = 1; for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) { if (!innerSplitAxes.empty()) { @@ -238,7 +238,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape, // non-uniform offs in shardedDimsOffsets. 
uint64_t numShards = 0; for (auto i : innerSplitAxes.asArrayRef()) { - numShards += meshShape[i]; + numShards += gridShape[i]; } for (size_t i = 1; i < numShards; ++i) { if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] != @@ -256,7 +256,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape, for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) { outShape[tensorAxis] = shardDimension( inShape[tensorAxis], - collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape)); + collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), gridShape)); } if (!haloSizes.empty()) { @@ -279,25 +279,25 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape, } } -ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh, - MeshSharding sharding) { +ShapedType shard::shardShapedType(ShapedType shape, GridOp grid, + Sharding sharding) { using Dim = std::decay_t; SmallVector resShapeArr(shape.getShape().size()); - shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(), + shardShape(shape.getShape(), grid.getShape(), sharding.getSplitAxes(), resShapeArr, sharding.getStaticShardedDimsOffsets(), sharding.getStaticHaloSizes()); return shape.clone(resShapeArr); } -Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) { +Type shard::shardType(Type type, GridOp grid, Sharding sharding) { RankedTensorType rankedTensorType = dyn_cast(type); if (rankedTensorType && !rankedTensorType.getShape().empty()) { - return shardShapedType(rankedTensorType, mesh, sharding); + return shardShapedType(rankedTensorType, grid, sharding); } return type; } -static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding, +static void maybeInsertTargetShardingAnnotationImpl(Sharding sharding, Value &operandValue, Operation *operandOp, OpBuilder &builder, @@ -336,9 +336,9 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding, 
newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2); } -void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, - OpResult result, - OpBuilder &builder) { +void mlir::shard::maybeInsertTargetShardingAnnotation(Sharding sharding, + OpResult result, + OpBuilder &builder) { ShardOp newShardOp; SmallVector> uses; for (auto &use : result.getUses()) { @@ -350,9 +350,9 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, } } -void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding, - OpOperand &operand, - OpBuilder &builder) { +void mlir::shard::maybeInsertSourceShardingAnnotation(Sharding sharding, + OpOperand &operand, + OpBuilder &builder) { OpBuilder::InsertionGuard insertionGuard(builder); Value operandValue = operand.get(); Operation *operandSrcOp = operandValue.getDefiningOp(); @@ -404,18 +404,18 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding, } //===----------------------------------------------------------------------===// -// mesh.mesh op +// shard.grid op //===----------------------------------------------------------------------===// -LogicalResult MeshOp::verify() { +LogicalResult GridOp::verify() { int64_t rank = getRank(); if (rank <= 0) - return emitOpError("rank of mesh is expected to be a positive integer"); + return emitOpError("rank of grid is expected to be a positive integer"); for (int64_t dimSize : getShape()) { if (dimSize < 0 && ShapedType::isStatic(dimSize)) - return emitOpError("dimension size of a mesh is expected to be " + return emitOpError("dimension size of a grid is expected to be " "non-negative or dynamic"); } @@ -423,21 +423,21 @@ LogicalResult MeshOp::verify() { } //===----------------------------------------------------------------------===// -// mesh.mesh_shape op +// shard.grid_shape op //===----------------------------------------------------------------------===// LogicalResult 
-MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
-  if (failed(mesh)) {
+GridShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+  if (failed(grid)) {
     return failure();
   }
-  if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
+  if (failed(verifyGridAxes(getLoc(), getAxes(), grid.value()))) {
     return failure();
   }
   size_t expectedResultsCount =
-      getAxes().empty() ? mesh->getRank() : getAxes().size();
+      getAxes().empty() ? grid->getRank() : getAxes().size();
   if (getResult().size() != expectedResultsCount) {
     return emitError() << "Unexpected number of results " << getResult().size()
                        << ". Expected " << expectedResultsCount << ".";
@@ -446,53 +446,53 @@ MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return success();
 }

-void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                        MeshOp mesh) {
-  build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
+void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                        GridOp grid) {
+  build(odsBuilder, odsState, grid, SmallVector<GridAxis>());
 }

-void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                        MeshOp mesh, ArrayRef<MeshAxis> axes) {
+void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                        GridOp grid, ArrayRef<GridAxis> axes) {
   build(odsBuilder, odsState,
-        SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
+        SmallVector<Type>(axes.empty() ? grid.getRank() : axes.size(),
                           odsBuilder.getIndexType()),
-        mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
+        grid.getSymName(), GridAxesAttr::get(odsBuilder.getContext(), axes));
 }

-void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                        StringRef mesh, ArrayRef<MeshAxis> axes) {
+void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                        StringRef grid, ArrayRef<GridAxis> axes) {
   assert(!axes.empty());
   build(odsBuilder, odsState,
-        SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
-        MeshAxesAttr::get(odsBuilder.getContext(), axes));
+        SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid,
+        GridAxesAttr::get(odsBuilder.getContext(), axes));
 }

-void MeshShapeOp::getAsmResultNames(
+void GridShapeOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
-  setNameFn(getResults()[0], "mesh_shape");
+  setNameFn(getResults()[0], "grid_shape");
 }

 //===----------------------------------------------------------------------===//
-// mesh.sharding
+// shard.sharding
 //===----------------------------------------------------------------------===//

 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
-                       FlatSymbolRefAttr mesh,
-                       ArrayRef<MeshAxesAttr> split_axes,
+                       FlatSymbolRefAttr grid,
+                       ArrayRef<GridAxesAttr> split_axes,
                        ArrayRef<int64_t> static_halos,
                        ArrayRef<int64_t> static_offsets) {
   return build(
-      b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
+      b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
       ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
       ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
 }

 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
-                       llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
+                       llvm::StringRef grid, ArrayRef<GridAxesAttr> split_axes,
                        ArrayRef<int64_t> static_halos,
                        ArrayRef<int64_t> static_offsets) {
-  return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
-               MeshAxesArrayAttr::get(b.getContext(), split_axes),
+  return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), grid),
+               GridAxesArrayAttr::get(b.getContext(), split_axes),
               ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos),
               {},
               ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets),
               {});
@@ -500,7 +500,7 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,

 void ShardingOp::build(
     ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
-    FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
+    FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> split_axes,
     ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
     ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
   mlir::SmallVector<int64_t> staticHalos, staticDims;
@@ -508,16 +508,16 @@ void ShardingOp::build(
   dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
   dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
   return build(
-      b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
+      b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
       ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
       ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
 }

 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
-                       mlir::mesh::MeshSharding from) {
+                       mlir::shard::Sharding from) {

-  build(b, odsState, ShardingType::get(b.getContext()), from.getMeshAttr(),
-        MeshAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
+  build(b, odsState, ShardingType::get(b.getContext()), from.getGridAttr(),
+        GridAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
         from.getStaticShardedDimsOffsets().empty()
             ? DenseI64ArrayAttr()
             : b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
@@ -529,21 +529,21 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
 }

 LogicalResult ShardingOp::verify() {
-  llvm::SmallSet<MeshAxis, 4> visitedAxes;
+  llvm::SmallSet<GridAxis, 4> visitedAxes;

-  auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
-    for (MeshAxis axis : axesArray) {
+  auto checkGridAxis = [&](ArrayRef<GridAxis> axesArray) -> LogicalResult {
+    for (GridAxis axis : axesArray) {
       if (axis < 0)
-        return emitError() << "mesh axis is expected to be non-negative";
+        return emitError() << "grid axis is expected to be non-negative";
       if (!visitedAxes.insert(axis).second)
-        return emitError() << "mesh axis duplicated";
+        return emitError() << "grid axis duplicated";
     }
     return success();
   };

   for (auto subAxes : getSplitAxes().getAxes()) {
-    ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
-    if (failed(checkMeshAxis(subAxesArray)))
+    ArrayRef<GridAxis> subAxesArray = subAxes.asArrayRef();
+    if (failed(checkGridAxis(subAxesArray)))
       return failure();
   }

@@ -572,26 +572,26 @@ void ShardingOp::getAsmResultNames(
 }

 LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
-  if (failed(mesh)) {
+  auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+  if (failed(grid)) {
     return failure();
   }
-  if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
+  if (mlir::ShapedType::isDynamicShape(grid->getShape()) &&
       getStaticShardedDimsOffsets().size() > 0) {
     return emitError() << "sharded dims offsets are not allowed for "
-                          "devices meshes with dynamic shape.";
+                          "device grids with dynamic shape.";
   }

   auto shardedDimsOffsets = getStaticShardedDimsOffsets();
   if (!shardedDimsOffsets.empty()) {
-    auto meshShape = mesh.value().getShape();
-    assert(ShapedType::isStaticShape(meshShape));
+    auto gridShape = grid.value().getShape();
+    assert(ShapedType::isStaticShape(gridShape));
     uint64_t pos = 0;
     for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
       if (!innerSplitAxes.empty()) {
         int64_t numShards = 0, off = 0;
         for (auto i : innerSplitAxes.asArrayRef()) {
-          numShards += meshShape[i];
+          numShards += gridShape[i];
         }
         for (int64_t i = 0; i <= numShards; ++i) {
           if (shardedDimsOffsets.size() <= pos + i) {
@@ -684,11 +684,11 @@ void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
 }

 //===----------------------------------------------------------------------===//
-// MeshSharding
+// Sharding
 //===----------------------------------------------------------------------===//

-bool MeshSharding::equalSplitAxes(const MeshSharding &rhs) const {
-  if (getMesh() != rhs.getMesh()) {
+bool Sharding::equalSplitAxes(const Sharding &rhs) const {
+  if (getGrid() != rhs.getGrid()) {
     return false;
   }

@@ -701,16 +701,16 @@ bool MeshSharding::equalSplitAxes(const MeshSharding &rhs) const {
   }

   return llvm::all_of(llvm::drop_begin(getSplitAxes(), minSize),
-                      std::mem_fn(&MeshAxesAttr::empty)) &&
+                      std::mem_fn(&GridAxesAttr::empty)) &&
          llvm::all_of(llvm::drop_begin(rhs.getSplitAxes(), minSize),
-                      std::mem_fn(&MeshAxesAttr::empty));
+                      std::mem_fn(&GridAxesAttr::empty));
 }

-bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
+bool Sharding::equalHaloAndShardSizes(const Sharding &rhs) const {
   return equalShardSizes(rhs) && equalHaloSizes(rhs);
 }

-bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const {
+bool Sharding::equalShardSizes(const Sharding &rhs) const {
   if (rhs.getStaticShardedDimsOffsets().size() !=
           getStaticShardedDimsOffsets().size() ||
       !llvm::equal(getStaticShardedDimsOffsets(),
@@ -726,7 +726,7 @@ bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const {
   return true;
 }

-bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
+bool Sharding::equalHaloSizes(const Sharding &rhs) const {
   if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
       !llvm::equal(getStaticHaloSizes(), rhs.getStaticHaloSizes())) {
     return false;
@@ -738,45 +738,43 @@ bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
   return true;
 }

-bool MeshSharding::operator==(Value rhs) const {
+bool Sharding::operator==(Value rhs) const {
   return equalSplitAxes(rhs) && equalHaloAndShardSizes(rhs);
 }

-bool MeshSharding::operator!=(Value rhs) const { return !(*this == rhs); }
+bool Sharding::operator!=(Value rhs) const { return !(*this == rhs); }

-bool MeshSharding::operator==(const MeshSharding &rhs) const {
+bool Sharding::operator==(const Sharding &rhs) const {
   return equalSplitAxes(rhs) && equalHaloAndShardSizes(rhs);
 }

-bool MeshSharding::operator!=(const MeshSharding &rhs) const {
-  return !(*this == rhs);
-}
+bool Sharding::operator!=(const Sharding &rhs) const { return !(*this == rhs); }

-MeshSharding::MeshSharding(::mlir::FlatSymbolRefAttr mesh_) : mesh(mesh_) {}
+Sharding::Sharding(::mlir::FlatSymbolRefAttr grid_) : grid(grid_) {}

-MeshSharding::MeshSharding(Value rhs) {
+Sharding::Sharding(Value rhs) {
   auto shardingOp = rhs.getDefiningOp<ShardingOp>();
   assert(shardingOp && "expected sharding op");
   auto splitAxes = shardingOp.getSplitAxes().getAxes();
   // If splitAxes are empty, use "empty" constructor.
   if (splitAxes.empty()) {
-    *this = MeshSharding(shardingOp.getMeshAttr());
+    *this = Sharding(shardingOp.getGridAttr());
     return;
   }
   *this =
-      get(shardingOp.getMeshAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
+      get(shardingOp.getGridAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
           shardingOp.getStaticShardedDimsOffsets(),
           SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
           SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
 }

-MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
-                               ArrayRef<MeshAxesAttr> split_axes_,
-                               ArrayRef<int64_t> static_halo_sizes_,
-                               ArrayRef<int64_t> static_sharded_dims_offsets_,
-                               ArrayRef<Value> dynamic_halo_sizes_,
-                               ArrayRef<Value> dynamic_sharded_dims_offsets_) {
-  MeshSharding res(mesh_);
+Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_,
+                       ArrayRef<GridAxesAttr> split_axes_,
+                       ArrayRef<int64_t> static_halo_sizes_,
+                       ArrayRef<int64_t> static_sharded_dims_offsets_,
+                       ArrayRef<Value> dynamic_halo_sizes_,
+                       ArrayRef<Value> dynamic_sharded_dims_offsets_) {
+  Sharding res(grid_);
   if (split_axes_.empty()) {
     return res;
   }
@@ -784,7 +782,7 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
   res.split_axes.resize(split_axes_.size());
   for (auto [i, axis] : llvm::enumerate(split_axes_)) {
     res.split_axes[i] =
-        MeshAxesAttr::get(mesh_.getContext(), axis.asArrayRef());
+        GridAxesAttr::get(grid_.getContext(), axis.asArrayRef());
   }

   auto clone = [](const auto src, auto &dst) {
@@ -801,7 +799,7 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
 }

 //===----------------------------------------------------------------------===//
-// mesh.shard_shape
+// shard.shard_shape
 //===----------------------------------------------------------------------===//

 void ShardShapeOp::getAsmResultNames(
@@ -820,7 +818,7 @@ void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
 }

 //===----------------------------------------------------------------------===//
-// mesh.shard op
+// shard.shard op
 //===----------------------------------------------------------------------===//

 void ShardOp::getAsmResultNames(
@@ -850,10 +848,10 @@ class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
     if (!otherOp || !otherOp->isBeforeInBlock(op)) {
       return failure();
     }
-    // Create a MeshSharding object for the current and the other ShardOp
+    // Create a Sharding object for the current and the other ShardOp
     // If the two are equal replace current op with the other op.
-    MeshSharding currentSharding(op.getSharding());
-    MeshSharding otherSharding(otherOp.getSharding());
+    Sharding currentSharding(op.getSharding());
+    Sharding otherSharding(otherOp.getSharding());
     if (currentSharding == otherSharding) {
       b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
       b.eraseOp(op.getOperation());
@@ -876,21 +874,21 @@ void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
 }

 //===----------------------------------------------------------------------===//
-// mesh.process_multi_index op
+// shard.process_multi_index op
 //===----------------------------------------------------------------------===//

 LogicalResult
 ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
-  if (failed(mesh)) {
+  auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+  if (failed(grid)) {
     return failure();
   }
-  if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
+  if (failed(verifyGridAxes(getLoc(), getAxes(), grid.value()))) {
     return failure();
   }
   size_t expectedResultsCount =
-      getAxes().empty() ? mesh->getRank() : getAxes().size();
+      getAxes().empty() ? grid->getRank() : getAxes().size();
   if (getResult().size() != expectedResultsCount) {
     return emitError() << "Unexpected number of results " << getResult().size()
                        << ". Expected " << expectedResultsCount << ".";
@@ -900,17 +898,17 @@ ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 }

 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                                MeshOp mesh) {
+                                GridOp grid) {
   build(odsBuilder, odsState,
-        SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
-        mesh.getSymName(), ArrayRef<MeshAxis>());
+        SmallVector<Type>(grid.getRank(), odsBuilder.getIndexType()),
+        grid.getSymName(), ArrayRef<GridAxis>());
 }

 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                                StringRef mesh, ArrayRef<MeshAxis> axes) {
+                                StringRef grid, ArrayRef<GridAxis> axes) {
   build(odsBuilder, odsState,
-        SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
-        MeshAxesAttr::get(odsBuilder.getContext(), axes));
+        SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid,
+        GridAxesAttr::get(odsBuilder.getContext(), axes));
 }

 void ProcessMultiIndexOp::getAsmResultNames(
@@ -919,21 +917,21 @@ void ProcessMultiIndexOp::getAsmResultNames(
 }

 //===----------------------------------------------------------------------===//
-// mesh.process_linear_index op
+// shard.process_linear_index op
 //===----------------------------------------------------------------------===//

 LogicalResult
 ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
-  if (failed(mesh)) {
+  auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+  if (failed(grid)) {
     return failure();
   }
   return success();
 }

 void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
-                                 OperationState &odsState, MeshOp mesh) {
-  build(odsBuilder, odsState, mesh.getSymName());
+                                 OperationState &odsState, GridOp grid) {
+  build(odsBuilder, odsState, grid.getSymName());
 }

 void ProcessLinearIndexOp::getAsmResultNames(
@@ -942,13 +940,13 @@ void ProcessLinearIndexOp::getAsmResultNames(
 }

 //===----------------------------------------------------------------------===//
-// mesh.neighbors_linear_indices op
+// shard.neighbors_linear_indices op
 //===----------------------------------------------------------------------===//

 LogicalResult
 NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
-  if (failed(mesh)) {
+  auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+  if (failed(grid)) {
     return failure();
   }
   return success();
@@ -967,12 +965,12 @@ void NeighborsLinearIndicesOp::getAsmResultNames(
 namespace {

 template <typename Op>
-struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
+struct EmptyGridAxesCanonicalizationPattern : OpRewritePattern<Op> {
   using OpRewritePattern<Op>::OpRewritePattern;
   LogicalResult matchAndRewrite(Op op,
                                 PatternRewriter &rewriter) const override {
-    auto meshAxes = op.getMeshAxes();
-    if (!meshAxes.empty()) {
+    auto gridAxes = op.getGridAxes();
+    if (!gridAxes.empty()) {
       return failure();
     }
     if (op.getInput().getType() != op.getResult().getType()) {
@@ -990,24 +988,24 @@ struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {

 static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
                                          ArrayRef<int64_t> device,
                                          Operation::operand_range deviceDynamic,
-                                         ArrayRef<MeshAxis> meshAxes,
-                                         ArrayRef<int64_t> meshShape) {
-  if (device.size() != meshAxes.size()) {
+                                         ArrayRef<GridAxis> gridAxes,
+                                         ArrayRef<int64_t> gridShape) {
+  if (device.size() != gridAxes.size()) {
     return emitError(loc) << "In-group device \"" << deviceName
                           << "\" has unexpected multi-index size "
-                          << device.size() << ". Expected " << meshAxes.size()
+                          << device.size() << ". Expected " << gridAxes.size()
                           << ".";
   }

   for (size_t i = 0; i < device.size(); ++i) {
     if (ShapedType::isStatic(device[i]) &&
-        ShapedType::isStatic(meshShape[meshAxes[i]]) &&
-        meshShape[meshAxes[i]] <= device[i]) {
+        ShapedType::isStatic(gridShape[gridAxes[i]]) &&
+        gridShape[gridAxes[i]] <= device[i]) {
       return emitError(loc)
              << "Out of bounds coordinate " << i << " for in-group device \""
             << deviceName << "\"."
             << " Got " << device[i] << ", but expected value in the range [0, "
-            << (meshShape[meshAxes[i]] - 1) << "].";
+            << (gridShape[gridAxes[i]] - 1) << "].";
     }
   }
   return success();
@@ -1043,7 +1041,7 @@ static LogicalResult verifyDimensionCompatibility(Location loc,

 static LogicalResult verifyGatherOperandAndResultShape(
     Value operand, Value result, int64_t gatherAxis,
-    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+    ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
   auto resultRank = cast<ShapedType>(result.getType()).getRank();
   if (gatherAxis < 0 || gatherAxis >= resultRank) {
     return emitError(result.getLoc())
@@ -1054,7 +1052,7 @@ static LogicalResult verifyGatherOperandAndResultShape(
   ShapedType operandType = cast<ShapedType>(operand.getType());
   ShapedType resultType = cast<ShapedType>(result.getType());
   auto deviceGroupSize =
-      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
+      DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
   for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
     auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
     auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
@@ -1070,7 +1068,7 @@ static LogicalResult verifyGatherOperandAndResultShape(

 static LogicalResult verifyAllToAllOperandAndResultShape(
     Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
-    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+    ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
   ShapedType operandType = cast<ShapedType>(operand.getType());
   ShapedType resultType = cast<ShapedType>(result.getType());
   for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
@@ -1088,7 +1086,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
   }

   auto deviceGroupSize =
-      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
+      DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
   auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
   auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
   DimensionSize expectedResultConcatDimSize =
@@ -1115,7 +1113,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(

 static LogicalResult verifyScatterOrSliceOperandAndResultShape(
     Value operand, Value result, int64_t tensorAxis,
-    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+    ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
   ShapedType operandType = cast<ShapedType>(operand.getType());
   ShapedType resultType = cast<ShapedType>(result.getType());
   for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
@@ -1129,7 +1127,7 @@ static LogicalResult verifyScatterOrSliceOperandAndResultShape(
   }

   auto deviceGroupSize =
-      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
+      DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
   auto operandScatterDimSize =
       DimensionSize(operandType.getDimSize(tensorAxis));
   if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
@@ -1151,8 +1149,8 @@ static LogicalResult verifyScatterOrSliceOperandAndResultShape(
   return success();
 }

-static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
-                                        ArrayRef<MeshAxis> meshAxes,
+static RankedTensorType sliceResultType(Type operandType, GridOp grid,
+                                        ArrayRef<GridAxis> gridAxes,
                                         int64_t sliceAxis) {
   RankedTensorType operandRankedTensorType =
       cast<RankedTensorType>(operandType);
@@ -1163,29 +1161,29 @@ static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
   resultShape[sliceAxis] =
       operandSliceAxisSize /
-      DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
+      DimensionSize(collectiveProcessGroupSize(gridAxes, grid));
   return operandRankedTensorType.clone(resultShape);
 }
 //===----------------------------------------------------------------------===//
-// mesh.all_gather op
+// shard.all_gather op
 //===----------------------------------------------------------------------===//

 LogicalResult
 AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerifyAxes(*this, symbolTable);
+  if (failed(grid)) {
     return failure();
   }
   auto gatherAxis = getGatherAxis().getSExtValue();
   return verifyGatherOperandAndResultShape(getOperand(), getResult(),
-                                           gatherAxis, getMeshAxes(),
-                                           mesh.value().getShape());
+                                           gatherAxis, getGridAxes(),
+                                           grid.value().getShape());
 }

 void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
+  patterns.add<EmptyGridAxesCanonicalizationPattern<AllGatherOp>>(context);
 }

 void AllGatherOp::getAsmResultNames(
@@ -1194,23 +1192,23 @@ void AllGatherOp::getAsmResultNames(
 }

 //===----------------------------------------------------------------------===//
-// mesh.all_reduce op
+// shard.all_reduce op
 //===----------------------------------------------------------------------===//

 LogicalResult
 AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  return getMeshAndVerifyAxes(*this, symbolTable);
+  return getGridAndVerifyAxes(*this, symbolTable);
 }

 void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
+  patterns.add<EmptyGridAxesCanonicalizationPattern<AllReduceOp>>(context);
 }

 void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                        Value input, StringRef mesh,
-                        ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
-  build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
+                        Value input, StringRef grid,
+                        ArrayRef<GridAxis> gridAxes, ReductionKind reduction) {
+  build(odsBuilder, odsState, input.getType(), grid, gridAxes, input,
         reduction);
 }

@@ -1220,36 +1218,36 @@ void AllReduceOp::getAsmResultNames(
 }
 //===----------------------------------------------------------------------===//
-// mesh.all_slice op
+// shard.all_slice op
 //===----------------------------------------------------------------------===//

 LogicalResult
 AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerifyAxes(*this, symbolTable);
+  if (failed(grid)) {
     return failure();
   }
   return verifyScatterOrSliceOperandAndResultShape(
-      getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
-      mesh.value().getShape());
+      getOperand(), getResult(), getSliceAxis().getSExtValue(), getGridAxes(),
+      grid.value().getShape());
 }

 void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
+  patterns.add<EmptyGridAxesCanonicalizationPattern<AllSliceOp>>(context);
 }

 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                       Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
+                       Value input, GridOp grid, ArrayRef<GridAxis> gridAxes,
                        int64_t sliceAxis) {
-  Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
-  build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
+  Type resultType = sliceResultType(input.getType(), grid, gridAxes, sliceAxis);
+  build(odsBuilder, odsState, resultType, input, grid.getSymName(), gridAxes,
         sliceAxis);
 }

 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                       Type resultType, Value input, StringRef mesh,
-                       ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
-  build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
+                       Type resultType, Value input, StringRef grid,
+                       ArrayRef<GridAxis> gridAxes, int64_t sliceAxis) {
+  build(odsBuilder, odsState, resultType, grid, gridAxes, input,
         APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
 }

@@ -1259,23 +1257,23 @@ void AllSliceOp::getAsmResultNames(
 }

 //===----------------------------------------------------------------------===//
-// mesh.all_to_all op
+// shard.all_to_all op
 //===----------------------------------------------------------------------===//

 LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerifyAxes(*this, symbolTable);
+  if (failed(grid)) {
     return failure();
   }
   return verifyAllToAllOperandAndResultShape(
       getOperand(), getResult(), getSplitAxis().getSExtValue(),
-      getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
+      getConcatAxis().getSExtValue(), getGridAxes(), grid.value().getShape());
 }

 void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
+  patterns.add<EmptyGridAxesCanonicalizationPattern<AllToAllOp>>(context);
 }

 void AllToAllOp::getAsmResultNames(
@@ -1284,18 +1282,18 @@ void AllToAllOp::getAsmResultNames(
 }

 //===----------------------------------------------------------------------===//
-// mesh.broadcast op
+// shard.broadcast op
 //===----------------------------------------------------------------------===//

 LogicalResult
 BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerifyAxes(*this, symbolTable);
+  if (failed(grid)) {
    return failure();
   }
   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
-                                 getRootDynamic(), getMeshAxes(),
-                                 mesh.value().getShape()))) {
+                                 getRootDynamic(), getGridAxes(),
+                                 grid.value().getShape()))) {
     return failure();
   }

@@ -1304,7 +1302,7 @@ BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
+  patterns.add<EmptyGridAxesCanonicalizationPattern<BroadcastOp>>(context);
 }

 void BroadcastOp::getAsmResultNames(
@@ -1313,29 +1311,29 @@ void BroadcastOp::getAsmResultNames(
 }

 //===----------------------------------------------------------------------===//
-// mesh.gather op
+// shard.gather op
 //===----------------------------------------------------------------------===//

 LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerifyAxes(*this, symbolTable);
+  if (failed(grid)) {
     return failure();
   }
   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
-                                 getRootDynamic(), getMeshAxes(),
-                                 mesh.value().getShape()))) {
+                                 getRootDynamic(), getGridAxes(),
+                                 grid.value().getShape()))) {
     return failure();
   }
   auto gatherAxis = getGatherAxis().getSExtValue();
   return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
-                                           getMeshAxes(),
-                                           mesh.value().getShape());
+                                           getGridAxes(),
+                                           grid.value().getShape());
 }

 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
+  patterns.add<EmptyGridAxesCanonicalizationPattern<GatherOp>>(context);
 }

 void GatherOp::getAsmResultNames(
@@ -1344,18 +1342,18 @@ void GatherOp::getAsmResultNames(
 }

 //===----------------------------------------------------------------------===//
-// mesh.recv op
+// shard.recv op
 //===----------------------------------------------------------------------===//

 LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerifyAxes(*this, symbolTable);
+  if (failed(grid)) {
     return failure();
   }
   if (getSource() &&
       failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
                                  getSource().value(), getSourceDynamic(),
-                                 getMeshAxes(), mesh.value().getShape()))) {
+                                 getGridAxes(), grid.value().getShape()))) {
     return failure();
   }
   return success();
@@ -1363,7 +1361,7 @@ LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

 void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                          MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
+  patterns.add<EmptyGridAxesCanonicalizationPattern<RecvOp>>(context);
 }

 void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
@@ -1371,17 +1369,17 @@ void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
 }

 //===----------------------------------------------------------------------===//
-// mesh.reduce op
+// shard.reduce op
 //===----------------------------------------------------------------------===//

 LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerifyAxes(*this, symbolTable);
+  if (failed(grid)) {
     return failure();
   }
   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
-                                 getRootDynamic(), getMeshAxes(),
-                                 mesh.value().getShape()))) {
+                                 getRootDynamic(), getGridAxes(),
+                                 grid.value().getShape()))) {
     return failure();
   }

@@ -1390,7 +1388,7 @@ LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

 void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
+  patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceOp>>(context);
 }

 void ReduceOp::getAsmResultNames(
@@ -1399,24 +1397,24 @@ void ReduceOp::getAsmResultNames(
 }

 //===----------------------------------------------------------------------===//
-// mesh.reduce_scatter op
+// shard.reduce_scatter op
 //===----------------------------------------------------------------------===//

 LogicalResult
 ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerifyAxes(*this, symbolTable);
+  if (failed(grid)) {
     return failure();
   }
   return verifyScatterOrSliceOperandAndResultShape(
-      getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
-      mesh.value().getShape());
+      getOperand(), getResult(), getScatterAxis().getSExtValue(), getGridAxes(),
+      grid.value().getShape());
 }

 void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
+  patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceScatterOp>>(context);
 }

 void ReduceScatterOp::getAsmResultNames(
@@ -1425,29 +1423,29 @@ void ReduceScatterOp::getAsmResultNames(
 }

 //===----------------------------------------------------------------------===//
-// mesh.scatter op
+// shard.scatter op
 //===----------------------------------------------------------------------===//

 LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerifyAxes(*this, symbolTable);
+  if (failed(grid)) {
     return failure();
   }
   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
-                                 getRootDynamic(), getMeshAxes(),
-                                 mesh.value().getShape()))) {
+                                 getRootDynamic(), getGridAxes(),
+                                 grid.value().getShape()))) {
     return failure();
   }
   auto scatterAxis = getScatterAxis().getSExtValue();
   return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
-                                                   scatterAxis, getMeshAxes(),
-                                                   mesh.value().getShape());
+                                                   scatterAxis, getGridAxes(),
+                                                   grid.value().getShape());
 }

 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
+  patterns.add<EmptyGridAxesCanonicalizationPattern<ScatterOp>>(context);
 }

 void ScatterOp::getAsmResultNames(
@@ -1456,17 +1454,17 @@ void ScatterOp::getAsmResultNames(
 }

 //===----------------------------------------------------------------------===//
-// mesh.send op
+// shard.send op
 //===----------------------------------------------------------------------===//

 LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerifyAxes(*this, symbolTable);
+  if (failed(grid)) {
     return failure();
   }
   if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
                                  getDestination(), getDestinationDynamic(),
-                                 getMeshAxes(), mesh.value().getShape()))) {
+                                 getGridAxes(), grid.value().getShape()))) {
return failure(); } return success(); @@ -1474,7 +1472,7 @@ LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) { void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.add>(context); + patterns.add>(context); } void SendOp::getAsmResultNames(function_ref setNameFn) { @@ -1482,20 +1480,20 @@ void SendOp::getAsmResultNames(function_ref setNameFn) { } //===----------------------------------------------------------------------===// -// mesh.shift op +// shard.shift op //===----------------------------------------------------------------------===// LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto mesh = getMeshAndVerifyAxes(*this, symbolTable); - if (failed(mesh)) { + auto grid = getGridAndVerifyAxes(*this, symbolTable); + if (failed(grid)) { return failure(); } - auto meshAxes = getMeshAxes(); + auto gridAxes = getGridAxes(); auto shiftAxis = getShiftAxis().getZExtValue(); - if (!llvm::is_contained(meshAxes, shiftAxis)) { + if (!llvm::is_contained(gridAxes, shiftAxis)) { return emitError() << "Invalid shift axis " << shiftAxis - << ". It must be one of the grouping mesh axes."; + << ". It must be one of the grouping grid axes."; } return success(); @@ -1504,7 +1502,7 @@ LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) { void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { // TODO: remove op when offset is 0 or if it is a rotate with and - // offset % shift_axis_mesh_dim_size == 0. + // offset % shift_axis_grid_dim_size == 0. 
} void ShiftOp::getAsmResultNames( @@ -1513,13 +1511,13 @@ void ShiftOp::getAsmResultNames( } //===----------------------------------------------------------------------===// -// mesh.update_halo op +// shard.update_halo op //===----------------------------------------------------------------------===// LogicalResult UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto mesh = getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); - if (failed(mesh)) { + auto grid = getGridAndVerify(getOperation(), getGridAttr(), symbolTable); + if (failed(grid)) { return failure(); } @@ -1531,12 +1529,12 @@ UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) { //===----------------------------------------------------------------------===// #define GET_OP_CLASSES -#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc" +#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc" #define GET_ATTRDEF_CLASSES -#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc" +#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc" #define GET_TYPEDEF_CLASSES -#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc" +#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc" -#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc" +#include "mlir/Dialect/Shard/IR/ShardEnums.cpp.inc" diff --git a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/lib/Dialect/Shard/Interfaces/CMakeLists.txt similarity index 76% rename from mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt rename to mlir/lib/Dialect/Shard/Interfaces/CMakeLists.txt index afe76b539846a..01e8e56dd391d 100644 --- a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt +++ b/mlir/lib/Dialect/Shard/Interfaces/CMakeLists.txt @@ -2,7 +2,7 @@ add_mlir_library(MLIRShardingInterface ShardingInterface.cpp ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shard DEPENDS MLIRShardingInterfaceIncGen @@ -10,7 +10,7 @@ add_mlir_library(MLIRShardingInterface LINK_LIBS PUBLIC MLIRDialectUtils 
MLIRIR - MLIRMeshDialect + MLIRShardDialect MLIRTensorDialect MLIRSupport ) diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp similarity index 70% rename from mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp rename to mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp index 6b3d49e08b549..d4e76189f7b8a 100644 --- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp +++ b/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp @@ -6,10 +6,10 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h" -#include "mlir/Dialect/Mesh/IR/MeshOps.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/IRMapping.h" #include "mlir/Support/LLVM.h" @@ -24,9 +24,9 @@ #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") using namespace mlir; -using namespace mlir::mesh; +using namespace mlir::shard; -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.cpp.inc" //===----------------------------------------------------------------------===// // common util functions @@ -93,40 +93,39 @@ checkOperandAffineExpr(AffineExpr expr, unsigned numDims) { } template -SmallVector +SmallVector fromArrayOfVector(MLIRContext *ctxt, const SmallVector> &vec) { - SmallVector res; + SmallVector res; for (const auto &v : vec) { - res.emplace_back(MeshAxesAttr::get(ctxt, v)); + res.emplace_back(GridAxesAttr::get(ctxt, v)); } return res; } //===----------------------------------------------------------------------===// -// mesh::getMeshSharding +// shard::getSharding 
//===----------------------------------------------------------------------===// -FailureOr> -mesh::getMeshSharding(OpResult result) { +FailureOr> shard::getSharding(OpResult result) { Value val = cast(result); bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) { - auto shardOp = llvm::dyn_cast(user); + auto shardOp = llvm::dyn_cast(user); if (!shardOp) return false; return !shardOp.getAnnotateForUsers(); }); if (anyShardedForDef) { - // expected to have exact one use if it has a use of `mesh.shard` without + // expected to have exactly one use if it has a use of `shard.shard` without // unit attr annotate_for_users if (!val.hasOneUse()) return failure(); - auto shardOp = llvm::cast(*val.getUsers().begin()); - return std::make_pair(false, MeshSharding(shardOp.getSharding())); + auto shardOp = llvm::cast(*val.getUsers().begin()); + return std::make_pair(false, Sharding(shardOp.getSharding())); } bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) { - auto shardOp = llvm::dyn_cast(user); + auto shardOp = llvm::dyn_cast(user); if (!shardOp) return false; return shardOp.getAnnotateForUsers(); @@ -138,24 +137,23 @@ mesh::getMeshSharding(OpResult result) { if (shardOp) shardOps.push_back(shardOp); } - MeshSharding shardForDef = shardOps[0].getSharding(); + Sharding shardForDef = shardOps[0].getSharding(); for (size_t i = 1; i < shardOps.size(); ++i) { - // TODO: Deduce a reasonable mesh sharding attr for def when they are + // TODO: Deduce a reasonable grid sharding attr for def when they are // different assert(shardForDef == shardOps[i].getSharding() && - "only support all shard ops have the same mesh sharding attr"); + "all shard ops must have the same grid sharding attr"); } return std::make_pair(true, shardForDef); } return failure(); } -FailureOr> -mesh::getMeshSharding(OpOperand &opOperand) { +FailureOr> shard::getSharding(OpOperand &opOperand) { Value val = opOperand.get(); if (ShardOp shardOp =
val.getDefiningOp()) return std::make_pair(shardOp.getAnnotateForUsers(), - MeshSharding(shardOp.getSharding())); + Sharding(shardOp.getSharding())); return failure(); } @@ -164,7 +162,7 @@ mesh::getMeshSharding(OpOperand &opOperand) { // ShardingInterface::verifyShardingInterfaceImpl //===----------------------------------------------------------------------===// -LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() { +LogicalResult shard::ShardingInterface::verifyShardingInterfaceImpl() { Operation *op = getOperation(); // check operands and results type @@ -201,7 +199,7 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() { // ShardingInterface::printLoopTypesAndIndexingMaps //===----------------------------------------------------------------------===// -void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) { +void shard::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) { os << "print loop types and indexing maps for: \n"; getOperation()->print(os); os << "\n"; @@ -222,15 +220,15 @@ void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) { namespace { -// Update the given `shardingOption` according to `meshAxes` and `loopIdx` +// Update the given `shardingOption` according to `gridAxes` and `loopIdx` static LogicalResult fillShardingOption(Operation *op, ShardingOption &shardingOption, - FlatSymbolRefAttr mesh, - ArrayRef meshAxes, + FlatSymbolRefAttr grid, + ArrayRef gridAxes, unsigned loopIdx) { - if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) || + if ((shardingOption.grid && grid && shardingOption.grid != grid) || (!shardingOption.shardingArray[loopIdx].empty() && - shardingOption.shardingArray[loopIdx] != meshAxes)) { + shardingOption.shardingArray[loopIdx] != gridAxes)) { LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator " << loopIdx << "\n"); return failure(); @@ -239,28 +237,28 @@ static LogicalResult 
fillShardingOption(Operation *op, if (i == loopIdx) continue; - for (MeshAxis axis : meshAxes) { + for (GridAxis axis : gridAxes) { if (llvm::is_contained(shardingOption.shardingArray[i], axis)) { - LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes " + LLVM_DEBUG(DBGS() << "sharding option conflicts because grid axes " << axis << " duplicate"); return failure(); } } } - if (mesh) - shardingOption.mesh = mesh; + if (grid) + shardingOption.grid = grid; if (shardingOption.shardingArray[loopIdx].empty()) - shardingOption.shardingArray[loopIdx].append(meshAxes.begin(), - meshAxes.end()); + shardingOption.shardingArray[loopIdx].append(gridAxes.begin(), + gridAxes.end()); return success(); } } // namespace FailureOr -mesh::detail::defaultGetShardingOption(Operation *op, - ArrayRef operandShardings, - ArrayRef resultShardings) { +shard::detail::defaultGetShardingOption(Operation *op, + ArrayRef operandShardings, + ArrayRef resultShardings) { ShardingInterface shardingOp = llvm::cast(op); ShardingOption shardingOption; @@ -276,25 +274,25 @@ mesh::detail::defaultGetShardingOption(Operation *op, // 1. 
Fill sharding option based on op results for (auto shardingIt : llvm::enumerate(resultShardings)) { - MeshSharding shardAttr = shardingIt.value(); + Sharding shardAttr = shardingIt.value(); if (!shardAttr) continue; AffineMap map = maps[numOperands + shardingIt.index()]; anyShardingInResultsOrOperands = true; if (shardAttr.getSplitAxes().empty() || map.getResults().empty()) { - shardingOption.mesh = shardAttr.getMeshAttr(); + shardingOption.grid = shardAttr.getGridAttr(); } else { // Handle the split axes: calculate the corresponding loop index for each // split axes sub-array, and then store the sub-array to // shardingOption[index] for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) { AffineExpr expr = std::get<0>(it); - ArrayRef axes = std::get<1>(it).asArrayRef(); + ArrayRef axes = std::get<1>(it).asArrayRef(); auto dim = cast(expr); unsigned index = dim.getPosition(); visitedLoopIndices.insert(index); if (failed(fillShardingOption(op, shardingOption, - shardAttr.getMeshAttr(), axes, index))) + shardAttr.getGridAttr(), axes, index))) return failure(); } } @@ -302,7 +300,7 @@ mesh::detail::defaultGetShardingOption(Operation *op, // 2. Fill sharding option based on operands for (auto shardingIt : llvm::enumerate(operandShardings)) { - MeshSharding shardAttr = shardingIt.value(); + Sharding shardAttr = shardingIt.value(); if (!shardAttr) continue; @@ -316,7 +314,7 @@ mesh::detail::defaultGetShardingOption(Operation *op, // then the operands with multiple loop indices. 
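Reviewer note: the conflict rules that `fillShardingOption` enforces in the hunk above can be modeled in a few lines. This is an illustrative Python sketch, not the MLIR C++ API; the function and variable names are hypothetical:

```python
# Hypothetical, simplified model of fillShardingOption's conflict checks.
# sharding_array[i] holds the grid axes assigned to loop iterator i.
def fill_sharding_option(sharding_array, loop_idx, grid_axes):
    """Assign grid_axes to loop iterator loop_idx, rejecting conflicts.

    Fails when the iterator already carries a different assignment, or when
    any of grid_axes is already used by another iterator.
    """
    if sharding_array[loop_idx] and sharding_array[loop_idx] != list(grid_axes):
        return False  # conflicting assignment on the same loop iterator
    for i, axes in enumerate(sharding_array):
        if i == loop_idx:
            continue
        if any(a in axes for a in grid_axes):
            return False  # a grid axis may shard at most one loop iterator
    if not sharding_array[loop_idx]:
        sharding_array[loop_idx].extend(grid_axes)
    return True
```

A grid axis may shard at most one loop iterator, and re-assigning an iterator is accepted only when the axis list is identical; these mirror the two `failure()` paths in the C++ above.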
for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) { AffineExpr expr = std::get<0>(it); - ArrayRef axes = std::get<1>(it).asArrayRef(); + ArrayRef axes = std::get<1>(it).asArrayRef(); FailureOr> loopIndices = checkOperandAffineExpr(expr, numDims); if (failed(loopIndices)) @@ -329,7 +327,7 @@ mesh::detail::defaultGetShardingOption(Operation *op, unsigned loopIdx = *loopIndices->begin(); visitedLoopIndices.insert(loopIdx); if (failed(fillShardingOption(op, shardingOption, - shardAttr.getMeshAttr(), axes, loopIdx))) + shardAttr.getGridAttr(), axes, loopIdx))) return failure(); } // If multiple loop indices correspond to a dimension of an operand, it is @@ -361,11 +359,11 @@ mesh::detail::defaultGetShardingOption(Operation *op, } // Get the sharding attributed for the given result and sharding option. -MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption, - AffineMap map, - ArrayRef loopTypes) { +static Sharding getSharding(OpResult result, + const ShardingOption &shardingOption, AffineMap map, + ArrayRef loopTypes) { auto resultType = cast(result.getType()); - SmallVector> splitAxes(resultType.getRank()); + SmallVector> splitAxes(resultType.getRank()); // process the split axes for (auto it : llvm::enumerate(map.getResults())) { @@ -379,25 +377,25 @@ MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption, } removeTrailingEmptySubArray(splitAxes); - return MeshSharding::get(shardingOption.mesh, - fromArrayOfVector(result.getContext(), splitAxes)); + return Sharding::get(shardingOption.grid, + fromArrayOfVector(result.getContext(), splitAxes)); } -static FailureOr getSharding(OpOperand &opOperand, - const ShardingOption &shardingOption, - AffineMap map) { +static FailureOr getSharding(OpOperand &opOperand, + const ShardingOption &shardingOption, + AffineMap map) { Value operandValue = opOperand.get(); auto operandType = dyn_cast(operandValue.getType()); if (!operandType) { if 
(operandValue.getType().isIntOrIndexOrFloat()) - return MeshSharding(); + return Sharding(); return failure(); } // 0d tensors cannot be sharded and must get replicated if (operandType.getRank() == 0) { - return MeshSharding(shardingOption.mesh); + return Sharding(shardingOption.grid); } - SmallVector> splitAxes(operandType.getRank()); + SmallVector> splitAxes(operandType.getRank()); unsigned numDims = map.getNumDims(); for (auto it : llvm::enumerate(map.getResults())) { int64_t idx = it.index(); @@ -422,15 +420,14 @@ static FailureOr getSharding(OpOperand &opOperand, } removeTrailingEmptySubArray(splitAxes); - return MeshSharding::get( - shardingOption.mesh, + return Sharding::get( + shardingOption.grid, fromArrayOfVector(opOperand.get().getContext(), splitAxes)); } -FailureOr> -mesh::detail::defaultGetShardingAnnotations( +FailureOr> shard::detail::defaultGetShardingAnnotations( Operation *op, const ShardingOption &shardingOption) { - std::vector res; + std::vector res; ShardingInterface shardingOp = llvm::cast(op); SmallVector loopTypes = @@ -439,7 +436,7 @@ mesh::detail::defaultGetShardingAnnotations( unsigned numOperands = op->getNumOperands(); for (OpOperand &opOperand : op->getOpOperands()) { - FailureOr shardingAttr = getSharding( + FailureOr shardingAttr = ::getSharding( opOperand, shardingOption, maps[opOperand.getOperandNumber()]); if (failed(shardingAttr)) return failure(); @@ -447,9 +444,9 @@ mesh::detail::defaultGetShardingAnnotations( } for (OpResult result : op->getResults()) { - res.push_back(getSharding(result, shardingOption, - maps[numOperands + result.getResultNumber()], - loopTypes)); + res.push_back(::getSharding(result, shardingOption, + maps[numOperands + result.getResultNumber()], + loopTypes)); } return res; @@ -459,26 +456,25 @@ mesh::detail::defaultGetShardingAnnotations( // detail::defaultAddShardingAnnotations //===----------------------------------------------------------------------===// -// To add a `mesh.shard` op for the given 
result, based on the details provided +// To add a `shard.shard` op for the given result, based on the details provided // in `shardingOption`, `map`, and `loopTypes`. static LogicalResult addShardOp(OpBuilder &b, OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef loopTypes) { - MeshSharding sharding = getSharding(result, shardingOption, map, loopTypes); + Sharding sharding = getSharding(result, shardingOption, map, loopTypes); maybeInsertTargetShardingAnnotation(sharding, result, b); return success(); } -// To add a `mesh.shard` op for the given operand, based on the details provided -// in `shardingOption`, `map`, and `loopTypes`. +// To add a `shard.shard` op for the given operand, based on the details +// provided in `shardingOption`, `map`, and `loopTypes`. static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand, const ShardingOption &shardingOption, AffineMap map) { - FailureOr sharding = - getSharding(opOperand, shardingOption, map); + FailureOr sharding = getSharding(opOperand, shardingOption, map); if (failed(sharding)) { return failure(); } @@ -488,9 +484,9 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand, return success(); } -LogicalResult mesh::detail::defaultAddShardingAnnotations( +LogicalResult shard::detail::defaultAddShardingAnnotations( Operation *op, OpBuilder &b, const ShardingOption &shardingOption) { - assert(!shardingOption.empty && shardingOption.mesh); + assert(!shardingOption.empty && shardingOption.grid); ShardingInterface shardingOp = llvm::cast(op); SmallVector loopTypes = @@ -498,7 +494,7 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations( SmallVector maps = shardingOp.getIndexingMaps(); unsigned numOperands = op->getNumOperands(); - // 1. add mesh.shard ops for all op results + // 1. 
add shard.shard ops for all op results for (OpResult result : op->getResults()) { if (failed(addShardOp(b, result, shardingOption, maps[numOperands + result.getResultNumber()], @@ -506,7 +502,7 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations( return failure(); } - // 2. add mesh.shard ops for all operands + // 2. add shard.shard ops for all operands for (OpOperand &opOperand : op->getOpOperands()) { if (failed(addShardOp(b, opOperand, shardingOption, maps[opOperand.getOperandNumber()]))) @@ -517,9 +513,8 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations( } #ifndef NDEBUG -static bool -isValueCompatibleWithFullReplicationSharding(Value value, - MeshSharding sharding) { +static bool isValueCompatibleWithFullReplicationSharding(Value value, + Sharding sharding) { if (isa(value.getType())) { return isFullReplication(sharding); } @@ -527,60 +522,59 @@ isValueCompatibleWithFullReplicationSharding(Value value, return !sharding; } -template +template static bool areValuesCompatibleWithFullReplicationShardings(ValueRange &&values, - MeshShardingRage &&shardings) { + ShardingRage &&shardings) { if (std::size(values) != std::size(shardings)) { return false; } - return llvm::all_of( - llvm::zip_equal(std::forward(values), - std::forward(shardings)), - [](auto valueAndSharding) { - return isValueCompatibleWithFullReplicationSharding( - std::get<0>(valueAndSharding), std::get<1>(valueAndSharding)); - }); + return llvm::all_of(llvm::zip_equal(std::forward(values), + std::forward(shardings)), + [](auto valueAndSharding) { + return isValueCompatibleWithFullReplicationSharding( + std::get<0>(valueAndSharding), + std::get<1>(valueAndSharding)); + }); } #endif // NDEBUG -void mesh::spmdizeFullyReplicatedOperation( - Operation &op, ArrayRef spmdizedOperands, - ArrayRef operandShardings, - ArrayRef resultShardings, IRMapping &spmdizationMap, - SymbolTableCollection &symbolTable, OpBuilder &builder) { - assert(spmdizedOperands.size() == operandShardings.size()); 
+void shard::partitionFullyReplicatedOperation( + Operation &op, ArrayRef partitionedOperands, + ArrayRef operandShardings, ArrayRef resultShardings, + IRMapping &partitionMap, SymbolTableCollection &symbolTable, + OpBuilder &builder) { + assert(partitionedOperands.size() == operandShardings.size()); assert(areValuesCompatibleWithFullReplicationShardings(op.getOperands(), operandShardings)); assert(areValuesCompatibleWithFullReplicationShardings(op.getResults(), resultShardings)); // `clone` will populate the mapping of old to new results. - builder.clone(op, spmdizationMap); + builder.clone(op, partitionMap); } -static void updateMeshAxisAssignmentForLoopIterators( - ArrayRef meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr, - SmallVector>> - &meshAxesAssignmentForLoopIterators) { +static void updateGridAxisAssignmentForLoopIterators( + ArrayRef gridAxesAssignmentForTensorAxis, AffineExpr indexingExpr, + SmallVector>> + &gridAxesAssignmentForLoopIterators) { AffineDimExpr affineDimExpr = cast(indexingExpr); unsigned loopIteratorIdx = affineDimExpr.getPosition(); - if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) { - assert(llvm::equal(meshAxesAssignmentForTensorAxis, - *meshAxesAssignmentForLoopIterators[loopIteratorIdx])); + if (gridAxesAssignmentForLoopIterators[loopIteratorIdx]) { + assert(llvm::equal(gridAxesAssignmentForTensorAxis, + *gridAxesAssignmentForLoopIterators[loopIteratorIdx])); } else { - meshAxesAssignmentForLoopIterators[loopIteratorIdx] = - llvm::to_vector(meshAxesAssignmentForTensorAxis); + gridAxesAssignmentForLoopIterators[loopIteratorIdx] = + llvm::to_vector(gridAxesAssignmentForTensorAxis); } } -ShardingArray mesh::getMeshAxisAssignmentForLoopIterators( - ArrayRef operandShardings, - ArrayRef resultShardings, +ShardingArray shard::getGridAxisAssignmentForLoopIterators( + ArrayRef operandShardings, ArrayRef resultShardings, ArrayRef loopIteratorTypes, ArrayRef indexingMaps) { - SmallVector>> - 
meshAxisAssignmentForLoopIterators(loopIteratorTypes.size()); - std::vector operatorAndResultShardings; + SmallVector>> + gridAxisAssignmentForLoopIterators(loopIteratorTypes.size()); + std::vector operatorAndResultShardings; operatorAndResultShardings.reserve(operandShardings.size() + resultShardings.size()); llvm::append_range(operatorAndResultShardings, operandShardings); @@ -589,69 +583,69 @@ ShardingArray mesh::getMeshAxisAssignmentForLoopIterators( if (!sharding) { continue; } - for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] : + for (auto [gridAxesAssignmentForTensorAxis, indexingExpr] : llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) { - updateMeshAxisAssignmentForLoopIterators( - meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr, - meshAxisAssignmentForLoopIterators); + updateGridAxisAssignmentForLoopIterators( + gridAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr, + gridAxisAssignmentForLoopIterators); } // Missing trailing split axes means replication on those tensor dimensions. 
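Reviewer note: the replication rule stated in the comment above (missing trailing split axes mean replication) is what drives the per-iterator grid-axis assignment. A simplified Python model, with hypothetical names, where each indexing map is reduced to a list of loop-iterator positions, one per tensor dimension:

```python
# Hypothetical sketch of getGridAxisAssignmentForLoopIterators: tensor
# dimension d of a value maps to a loop iterator via its indexing map, and
# the grid axes splitting d become that iterator's assignment.
def grid_axes_for_loop_iterators(num_loops, shardings_and_maps):
    assignment = [None] * num_loops
    for split_axes, dim_positions in shardings_and_maps:
        # Missing trailing entries in split_axes mean replication: pad with [].
        padded = list(split_axes) + [[]] * (len(dim_positions) - len(split_axes))
        for axes, loop_idx in zip(padded, dim_positions):
            if assignment[loop_idx] is None:
                assignment[loop_idx] = list(axes)
            else:
                # All values must agree on an iterator's axes (the C++ asserts).
                assert assignment[loop_idx] == list(axes), "inconsistent sharding"
    return [a if a is not None else [] for a in assignment]
```

For a matmul-like op with loops (i, j, k), sharding operand A on grid axis 0 along i and replicating B yields the assignment `[[0], [], []]`.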
for (unsigned i = sharding.getSplitAxes().size(); i < affineMap.getNumResults(); ++i) { - updateMeshAxisAssignmentForLoopIterators( - {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators); + updateGridAxisAssignmentForLoopIterators( + {}, affineMap.getResults()[i], gridAxisAssignmentForLoopIterators); } } ShardingArray res; - llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res), - [](std::optional> &axes) { + llvm::transform(gridAxisAssignmentForLoopIterators, std::back_inserter(res), + [](std::optional> &axes) { if (!axes) { - return SmallVector(); + return SmallVector(); }; return std::move(*axes); }); return res; } -bool mesh::isAtLeastOneReductionIteratorSharded( +bool shard::isAtLeastOneReductionIteratorSharded( ArrayRef loopIteratorTypes, - ArrayRef> meshAxisAssignmentForLoopIterators) { - for (auto [loopIteratorType, meshAxisAssignment] : - llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) { + ArrayRef> gridAxisAssignmentForLoopIterators) { + for (auto [loopIteratorType, gridAxisAssignment] : + llvm::zip_equal(loopIteratorTypes, gridAxisAssignmentForLoopIterators)) { if (loopIteratorType == utils::IteratorType::reduction && - !meshAxisAssignment.empty()) { + !gridAxisAssignment.empty()) { return true; } } return false; } -SmallVector mesh::getReductionMeshAxes( +SmallVector shard::getReductionGridAxes( ArrayRef loopIteratorTypes, - ArrayRef> meshAxisAssignmentForLoopIterators) { - SmallVector meshAxes; - for (auto [loopIteratorType, meshAxisAssignment] : - llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) { + ArrayRef> gridAxisAssignmentForLoopIterators) { + SmallVector gridAxes; + for (auto [loopIteratorType, gridAxisAssignment] : + llvm::zip_equal(loopIteratorTypes, gridAxisAssignmentForLoopIterators)) { if (loopIteratorType == utils::IteratorType::reduction) { - llvm::append_range(meshAxes, meshAxisAssignment); + llvm::append_range(gridAxes, gridAxisAssignment); } } - return 
meshAxes; + return gridAxes; } -void mesh::spmdizeTriviallyShardableOperation( - Operation &op, ArrayRef spmdizedOperands, - ArrayRef operandShardings, - ArrayRef resultShardings, IRMapping &spmdizationMap, - SymbolTableCollection &symbolTable, OpBuilder &builder) { +void shard::partitionTriviallyShardableOperation( + Operation &op, ArrayRef partitionedOperands, + ArrayRef operandShardings, ArrayRef resultShardings, + IRMapping &partitionMap, SymbolTableCollection &symbolTable, + OpBuilder &builder) { // `clone` will populate the mapping of old to new results. - Operation *newOp = builder.clone(op, spmdizationMap); + Operation *newOp = builder.clone(op, partitionMap); // Set the result types to the sharded counterparts. for (auto [oldResult, newResult, sharding] : llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) { newResult.setType(shardType( newResult.getType(), - getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding)); + getGridOrNull(&op, sharding.getGridAttr(), symbolTable), sharding)); } } diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt similarity index 73% rename from mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt rename to mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt index 381bc9afede07..a884764e70e92 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt @@ -1,14 +1,14 @@ -add_mlir_dialect_library(MLIRMeshTransforms +add_mlir_dialect_library(MLIRShardTransforms Simplifications.cpp ShardingPropagation.cpp - Spmdization.cpp + Partition.cpp Transforms.cpp ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shard DEPENDS - MLIRMeshPassIncGen + MLIRShardPassIncGen MLIRShardingInterface LINK_LIBS PUBLIC @@ -21,7 +21,7 @@ add_mlir_dialect_library(MLIRMeshTransforms MLIRFuncDialect MLIRFunctionInterfaces MLIRIR - MLIRMeshDialect + MLIRShardDialect 
MLIRPass MLIRSupport MLIRTensorDialect diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp similarity index 66% rename from mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp rename to mlir/lib/Dialect/Shard/Transforms/Partition.cpp index 5dd744d0da5c7..5fe55669c90db 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp @@ -1,4 +1,4 @@ -//===- Spmdization.cpp --------------------------------------------- C++ --===// +//===- Partition.cpp --------------------------------------------- C++ --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,11 +6,11 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Mesh/Transforms/Spmdization.h" +#include "mlir/Dialect/Shard/Transforms/Partition.h" -#include "mlir/Dialect/Mesh/IR/MeshDialect.h" -#include "mlir/Dialect/Mesh/IR/MeshOps.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" +#include "mlir/Dialect/Shard/IR/ShardDialect.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" @@ -33,7 +33,7 @@ #include #include -namespace mlir::mesh { +namespace mlir::shard { template static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, @@ -43,52 +43,51 @@ static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, }); } -static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx, - MeshSharding sourceSharding, - int64_t splitTensorAxis, - MeshAxis splitMeshAxis) { - SmallVector targetShardingSplitAxes = +static Sharding targetShardingInSplitLastAxis(MLIRContext *ctx, + Sharding sourceSharding, + int64_t splitTensorAxis, + GridAxis splitGridAxis) { + 
SmallVector targetShardingSplitAxes = llvm::to_vector(sourceSharding.getSplitAxes()); while (static_cast(targetShardingSplitAxes.size()) <= splitTensorAxis) { - targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {})); + targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {})); } auto targetSplitAxes = llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef()); - targetSplitAxes.push_back(splitMeshAxis); + targetSplitAxes.push_back(splitGridAxis); targetShardingSplitAxes[splitTensorAxis] = - MeshAxesAttr::get(ctx, targetSplitAxes); - return MeshSharding::get(sourceSharding.getMeshAttr(), - targetShardingSplitAxes); + GridAxesAttr::get(ctx, targetSplitAxes); + return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes); } -// Split a replicated tensor along a mesh axis. +// Split a replicated tensor along a grid axis. // E.g. [[0, 1]] -> [[0, 1, 2]]. -// Returns the spmdized target value with its sharding. -static std::tuple, MeshSharding> +// Returns the partitioned target value with its sharding. +static std::tuple, Sharding> splitLastAxisInResharding(ImplicitLocOpBuilder &builder, - MeshSharding sourceSharding, - TypedValue sourceShard, MeshOp mesh, - int64_t splitTensorAxis, MeshAxis splitMeshAxis) { + Sharding sourceSharding, + TypedValue sourceShard, GridOp grid, + int64_t splitTensorAxis, GridAxis splitGridAxis) { TypedValue targetShard = cast>( builder - .create(sourceShard, mesh, - ArrayRef(splitMeshAxis), + .create(sourceShard, grid, + ArrayRef(splitGridAxis), splitTensorAxis) .getResult()); - MeshSharding targetSharding = targetShardingInSplitLastAxis( - builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis); + Sharding targetSharding = targetShardingInSplitLastAxis( + builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis); return {targetShard, targetSharding}; } // Detect if the resharding is of type e.g. // [[0, 1]] -> [[0, 1, 2]]. 
-// If detected, returns the corresponding tensor axis mesh axis pair. +// If detected, returns the corresponding tensor axis grid axis pair. // Does not detect insertions like // [[0, 1]] -> [[0, 2, 1]]. -static std::optional> -detectSplitLastAxisInResharding(MeshSharding sourceSharding, - MeshSharding targetSharding) { +static std::optional> +detectSplitLastAxisInResharding(Sharding sourceSharding, + Sharding targetSharding) { for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size(); ++tensorAxis) { if (sourceSharding.getSplitAxes().size() > tensorAxis) { @@ -118,16 +117,15 @@ detectSplitLastAxisInResharding(MeshSharding sourceSharding, return std::nullopt; } -static std::optional, MeshSharding>> -trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, - MeshSharding sourceSharding, - MeshSharding targetSharding, +static std::optional, Sharding>> +trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, + Sharding sourceSharding, Sharding targetSharding, TypedValue sourceShard) { if (auto detectRes = detectSplitLastAxisInResharding(sourceSharding, targetSharding)) { - auto [tensorAxis, meshAxis] = detectRes.value(); - return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh, - tensorAxis, meshAxis); + auto [tensorAxis, gridAxis] = detectRes.value(); + return splitLastAxisInResharding(builder, sourceSharding, sourceShard, grid, + tensorAxis, gridAxis); } return std::nullopt; @@ -135,10 +133,10 @@ trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, // Detect if the resharding is of type e.g. // [[0, 1, 2]] -> [[0, 1]]. -// If detected, returns the corresponding tensor axis mesh axis pair. -static std::optional> -detectUnsplitLastAxisInResharding(MeshSharding sourceSharding, - MeshSharding targetSharding) { +// If detected, returns the corresponding tensor axis grid axis pair. 
+static std::optional> +detectUnsplitLastAxisInResharding(Sharding sourceSharding, + Sharding targetSharding) { for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size(); ++tensorAxis) { if (targetSharding.getSplitAxes().size() > tensorAxis) { @@ -165,10 +163,10 @@ detectUnsplitLastAxisInResharding(MeshSharding sourceSharding, return std::nullopt; } -static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx, - MeshSharding sourceSharding, - int64_t splitTensorAxis) { - SmallVector targetShardingSplitAxes = +static Sharding targetShardingInUnsplitLastAxis(MLIRContext *ctx, + Sharding sourceSharding, + int64_t splitTensorAxis) { + SmallVector targetShardingSplitAxes = llvm::to_vector(sourceSharding.getSplitAxes()); assert(static_cast(targetShardingSplitAxes.size()) > splitTensorAxis); @@ -177,9 +175,8 @@ static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx, targetSplitAxes.pop_back(); targetShardingSplitAxes[splitTensorAxis] = - MeshAxesAttr::get(ctx, targetSplitAxes); - return MeshSharding::get(sourceSharding.getMeshAttr(), - targetShardingSplitAxes); + GridAxesAttr::get(ctx, targetSplitAxes); + return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes); } static ShapedType allGatherResultShapeInUnsplitLastAxis( @@ -190,45 +187,42 @@ static ShapedType allGatherResultShapeInUnsplitLastAxis( return sourceShape.cloneWith(targetShape, sourceShape.getElementType()); } -static std::tuple, MeshSharding> -unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, - MeshSharding sourceSharding, - ShapedType sourceUnshardedShape, - TypedValue sourceShard, MeshOp mesh, - int64_t splitTensorAxis, MeshAxis splitMeshAxis) { +static std::tuple, Sharding> unsplitLastAxisInResharding( + ImplicitLocOpBuilder &builder, Sharding sourceSharding, + ShapedType sourceUnshardedShape, TypedValue sourceShard, + GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) { MLIRContext *ctx = builder.getContext(); 
builder.setInsertionPointAfterValue(sourceShard); - MeshSharding targetSharding = + Sharding targetSharding = targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis); ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis( - sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis); + sourceShard.getType(), grid.getShape()[splitGridAxis], splitTensorAxis); Value allGatherResult = AllGatherOp::create( builder, RankedTensorType::get(allGatherResultShape.getShape(), allGatherResultShape.getElementType()), - mesh.getSymName(), SmallVector({splitMeshAxis}), sourceShard, + grid.getSymName(), SmallVector({splitGridAxis}), sourceShard, APInt(64, splitTensorAxis)); ShapedType targetShape = - shardShapedType(sourceUnshardedShape, mesh, targetSharding); + shardShapedType(sourceUnshardedShape, grid, targetSharding); TypedValue targetShard = cast>( tensor::CastOp::create(builder, targetShape, allGatherResult) .getResult()); return {targetShard, targetSharding}; } -static std::optional, MeshSharding>> -tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, - MeshSharding sourceSharding, - MeshSharding targetSharding, +static std::optional, Sharding>> +tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, + Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue sourceShard) { if (auto detectRes = detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) { - auto [tensorAxis, meshAxis] = detectRes.value(); + auto [tensorAxis, gridAxis] = detectRes.value(); return unsplitLastAxisInResharding(builder, sourceSharding, - sourceUnshardedShape, sourceShard, mesh, - tensorAxis, meshAxis); + sourceUnshardedShape, sourceShard, grid, + tensorAxis, gridAxis); } return std::nullopt; @@ -238,10 +232,10 @@ tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, // [[0, 1], [2]] -> [[0], [1, 2]]. // Only moving the last axis counts. 
// If detected, returns the corresponding (source_tensor_axis, -// target_tensor_axis, mesh_axis) tuple. -static std::optional> -detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding, - MeshSharding targetSharding) { +// target_tensor_axis, grid_axis) tuple. +static std::optional> +detectMoveLastSplitAxisInResharding(Sharding sourceSharding, + Sharding targetSharding) { for (size_t sourceTensorAxis = 0; sourceTensorAxis < sourceSharding.getSplitAxes().size(); ++sourceTensorAxis) { @@ -281,33 +275,32 @@ detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding, return std::nullopt; } -static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx, - MeshSharding sourceSharding, - int64_t sourceTensorAxis, - int64_t targetTensorAxis) { - SmallVector targetShardingSplitAxes = +static Sharding targetShardingInMoveLastAxis(MLIRContext *ctx, + Sharding sourceSharding, + int64_t sourceTensorAxis, + int64_t targetTensorAxis) { + SmallVector targetShardingSplitAxes = llvm::to_vector(sourceSharding.getSplitAxes()); while (static_cast(targetShardingSplitAxes.size()) <= targetTensorAxis) { - targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {})); + targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {})); } auto sourceSplitAxes = llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef()); assert(!sourceSplitAxes.empty()); - auto meshAxis = sourceSplitAxes.back(); + auto gridAxis = sourceSplitAxes.back(); sourceSplitAxes.pop_back(); targetShardingSplitAxes[sourceTensorAxis] = - MeshAxesAttr::get(ctx, sourceSplitAxes); + GridAxesAttr::get(ctx, sourceSplitAxes); auto targetSplitAxes = llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef()); - targetSplitAxes.push_back(meshAxis); + targetSplitAxes.push_back(gridAxis); targetShardingSplitAxes[targetTensorAxis] = - MeshAxesAttr::get(ctx, targetSplitAxes); + GridAxesAttr::get(ctx, targetSplitAxes); - return MeshSharding::get(sourceSharding.getMeshAttr(), - 
targetShardingSplitAxes); + return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes); } static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, @@ -322,46 +315,46 @@ static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, return sourceShape.cloneWith(targetShape, sourceShape.getElementType()); } -static std::tuple, MeshSharding> -moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, - MeshSharding sourceSharding, +static std::tuple, Sharding> +moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, + Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue sourceShard, int64_t sourceTensorAxis, - int64_t targetTensorAxis, MeshAxis meshAxis) { + int64_t targetTensorAxis, GridAxis gridAxis) { MLIRContext *ctx = builder.getContext(); builder.setInsertionPointAfterValue(sourceShard); - MeshSharding targetSharding = targetShardingInMoveLastAxis( + Sharding targetSharding = targetShardingInMoveLastAxis( ctx, sourceSharding, sourceTensorAxis, targetTensorAxis); ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis( - sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis, + sourceShard.getType(), grid.getShape()[gridAxis], sourceTensorAxis, targetTensorAxis); Value allToAllResult = AllToAllOp::create( builder, RankedTensorType::get(allToAllResultShape.getShape(), allToAllResultShape.getElementType()), - mesh.getSymName(), SmallVector({meshAxis}), sourceShard, + grid.getSymName(), SmallVector({gridAxis}), sourceShard, APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis)); ShapedType targetShape = - shardShapedType(sourceUnshardedShape, mesh, targetSharding); + shardShapedType(sourceUnshardedShape, grid, targetSharding); TypedValue targetShard = cast>( tensor::CastOp::create(builder, targetShape, allToAllResult).getResult()); return {targetShard, targetSharding}; } -static std::optional, MeshSharding>> 
-tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, - MeshSharding sourceSharding, - MeshSharding targetSharding, +tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, + Sharding sourceSharding, + Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue sourceShard) { if (auto detectRes = detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) { - auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value(); + auto [sourceTensorAxis, targetTensorAxis, gridAxis] = detectRes.value(); return moveLastSplitAxisInResharding( - builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard, - sourceTensorAxis, targetTensorAxis, meshAxis); + builder, grid, sourceSharding, sourceUnshardedShape, sourceShard, + sourceTensorAxis, targetTensorAxis, gridAxis); } return std::nullopt; @@ -371,10 +364,9 @@ tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, // needed. A change in halo sizes requires copying the "core" of the source tensor // into the "core" of the destination tensor followed by an update halo // operation.
-static std::optional, MeshSharding>> -tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, - MeshSharding sourceSharding, - MeshSharding targetSharding, +static std::optional, Sharding>> +tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, GridOp grid, + Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue sourceShard) { // Currently handles only cases where halo sizes differ but everything else @@ -392,7 +384,7 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) && ShapedType::isStaticShape(tgtHaloSizes) && sourceShard.getType().hasStaticShape()) && - "dynamic shapes/halos are not supported yet for mesh-spmdization"); + "dynamic shapes/halos are not supported yet for shard-partition"); auto rank = sourceShard.getType().getRank(); auto splitAxes = sourceSharding.getSplitAxes(); SmallVector srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0), @@ -433,8 +425,8 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, sourceShard.getLoc(), RankedTensorType::get(outShape, sourceShard.getType().getElementType()), - initOprnd, mesh.getSymName(), - MeshAxesArrayAttr::get(builder.getContext(), + initOprnd, grid.getSymName(), + GridAxesArrayAttr::get(builder.getContext(), sourceSharding.getSplitAxes()), targetSharding.getDynamicHaloSizes(), targetSharding.getStaticHaloSizes()) @@ -443,41 +435,41 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, targetSharding); } -// Handles only resharding on a 1D mesh. +// Handles only resharding on a 1D grid. // Currently the sharded tensor axes must be exactly divisible by the single -// mesh axis size. +// grid axis size.
static TypedValue -reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh, - MeshSharding sourceSharding, MeshSharding targetSharding, +reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid, + Sharding sourceSharding, Sharding targetSharding, TypedValue sourceUnshardedValue, TypedValue sourceShard) { assert(sourceShard.getType() == - shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding)); + shardShapedType(sourceUnshardedValue.getType(), grid, sourceSharding)); [[maybe_unused]] ShapedType targetShardType = - shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding); + shardShapedType(sourceUnshardedValue.getType(), grid, targetSharding); assert(sourceShard.getType().getRank() == targetShardType.getRank()); - assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported."); + assert(grid.getRank() == 1 && "Only 1D grids are currently supported."); if (sourceSharding == targetSharding) { return sourceShard; } TypedValue targetShard; - MeshSharding actualTargetSharding; + Sharding actualTargetSharding; if (sourceSharding.getStaticShardedDimsOffsets().empty() && targetSharding.getStaticShardedDimsOffsets().empty() && sourceSharding.getStaticHaloSizes().empty() && targetSharding.getStaticHaloSizes().empty()) { if (auto tryRes = tryMoveLastSplitAxisInResharding( - builder, mesh, sourceSharding, targetSharding, + builder, grid, sourceSharding, targetSharding, sourceUnshardedValue.getType(), sourceShard)) { std::tie(targetShard, actualTargetSharding) = tryRes.value(); } else if (auto tryRes = - trySplitLastAxisInResharding(builder, mesh, sourceSharding, + trySplitLastAxisInResharding(builder, grid, sourceSharding, targetSharding, sourceShard)) { std::tie(targetShard, actualTargetSharding) = tryRes.value(); } else if (auto tryRes = tryUnsplitLastAxisInResharding( - builder, mesh, sourceSharding, targetSharding, + builder, grid, sourceSharding, targetSharding, sourceUnshardedValue.getType(), sourceShard)) { std::tie(targetShard,
actualTargetSharding) = tryRes.value(); } @@ -488,9 +480,8 @@ reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh, return targetShard; } -TypedValue reshard(ImplicitLocOpBuilder &builder, MeshOp mesh, - MeshSharding sourceSharding, - MeshSharding targetSharding, +TypedValue reshard(ImplicitLocOpBuilder &builder, GridOp grid, + Sharding sourceSharding, Sharding targetSharding, TypedValue sourceUnshardedValue, TypedValue sourceShard) { // If source and destination sharding are the same, no need to do anything. @@ -500,28 +491,28 @@ TypedValue reshard(ImplicitLocOpBuilder &builder, MeshOp mesh, } // Tries to handle the case where the resharding is needed because the halo - // sizes are different. Supports arbitrary mesh dimensionality. + // sizes are different. Supports arbitrary grid dimensionality. if (auto tryRes = tryUpdateHaloInResharding( - builder, mesh, sourceSharding, targetSharding, + builder, grid, sourceSharding, targetSharding, sourceUnshardedValue.getType(), sourceShard)) { return std::get<0>(tryRes.value()); // targetShard } - // Resort to handling only 1D meshes since the general case is complicated if + // Resort to handling only 1D grids since the general case is complicated if // it needs to be communication efficient in terms of minimizing the data // transferred between devices.
- return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding, + return reshardOn1DGrid(builder, grid, sourceSharding, targetSharding, sourceUnshardedValue, sourceShard); } -TypedValue reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, +TypedValue reshard(OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue sourceShardValue) { assert(source.getResult() == target.getSrc()); auto sourceSharding = source.getSharding(); auto targetSharding = target.getSharding(); ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder); - return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding, + return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding, cast>(source.getSrc()), sourceShardValue); } @@ -530,21 +521,21 @@ TypedValue reshard(OpBuilder &builder, ShardOp source, ShardOp target, TypedValue sourceShardValue, SymbolTableCollection &symbolTableCollection) { - MeshOp srcMesh = getMesh(source, symbolTableCollection); - assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection)); - return reshard(builder, srcMesh, source, target, sourceShardValue); + GridOp srcGrid = getGrid(source, symbolTableCollection); + assert(srcGrid && srcGrid == getGrid(target, symbolTableCollection)); + return reshard(builder, srcGrid, source, target, sourceShardValue); } void reshardingRegisterDependentDialects(DialectRegistry &registry) { - registry.insert(); + registry.insert(); } -#define GEN_PASS_DEF_SPMDIZATION -#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc" +#define GEN_PASS_DEF_PARTITION +#include "mlir/Dialect/Shard/Transforms/Passes.h.inc" using UnshardedToShardedValueMap = DenseMap; -// Get the types of block arguments for an spmdized block. +// Get the types of block arguments for a partitioned block. // Reads the sharding annotations of the arguments to deduce the sharded types. // Types that are not ranked tensors are left unchanged.
SmallVector @@ -563,35 +554,36 @@ shardedBlockArgumentTypes(Block &block, Operation *useOp = *rankedTensorArg.getUsers().begin(); ShardOp shardOp = llvm::dyn_cast(useOp); assert(shardOp); - MeshOp mesh = getMesh(shardOp, symbolTableCollection); - return cast(shardShapedType(rankedTensorArg.getType(), mesh, + GridOp grid = getGrid(shardOp, symbolTableCollection); + return cast(shardShapedType(rankedTensorArg.getType(), grid, shardOp.getSharding())); }); return res; } -static LogicalResult spmdizeOperation( - Operation &op, ArrayRef spmdizedOperands, - ArrayRef operandShardings, - ArrayRef resultShardings, IRMapping &spmdizationMap, - SymbolTableCollection &symbolTableCollection, OpBuilder &builder) { +static LogicalResult +partitionOperation(Operation &op, ArrayRef partitionedOperands, + ArrayRef operandShardings, + ArrayRef resultShardings, IRMapping &partitionMap, + SymbolTableCollection &symbolTableCollection, + OpBuilder &builder) { ShardingInterface shardingInterface = llvm::dyn_cast(op); if (!shardingInterface) { // If there is no sharding interface we are conservative and assume that // the op should be fully replicated on all devices.
- spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings, - resultShardings, spmdizationMap, - symbolTableCollection, builder); + partitionFullyReplicatedOperation(op, partitionedOperands, operandShardings, + resultShardings, partitionMap, + symbolTableCollection, builder); } else { - if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings, - resultShardings, spmdizationMap, - symbolTableCollection, builder))) { + if (failed(shardingInterface.partition( + partitionedOperands, operandShardings, resultShardings, + partitionMap, symbolTableCollection, builder))) { return failure(); } } - assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) { - return spmdizationMap.contains(result); + assert(llvm::all_of(op.getResults(), [&partitionMap](OpResult result) { + return partitionMap.contains(result); })); return success(); @@ -599,87 +591,87 @@ static LogicalResult spmdizeOperation( // Retrieve the sharding annotations for the operands of the given operation. // If the type is not a ranked tensor it is not required to have an annotation. -static std::vector getOperandShardings(Operation &op) { - std::vector res; +static std::vector getOperandShardings(Operation &op) { + std::vector res; res.reserve(op.getNumOperands()); llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) { TypedValue rankedTensor = dyn_cast>(operand); if (!rankedTensor || rankedTensor.getType().getRank() == 0) { - return MeshSharding(); + return Sharding(); } Operation *definingOp = operand.getDefiningOp(); assert(definingOp); ShardOp shardOp = llvm::cast(definingOp); - return MeshSharding(shardOp.getSharding()); + return Sharding(shardOp.getSharding()); }); return res; } // Retrieve the sharding annotations for the results of the given operation. // If the type is not a ranked tensor it is not required to have an annotation.
-static std::vector getResultShardings(Operation &op) { - std::vector res; +static std::vector getResultShardings(Operation &op) { + std::vector res; res.reserve(op.getNumResults()); llvm::transform( op.getResults(), std::back_inserter(res), [&op](OpResult result) { if (!result.hasOneUse() || result.use_empty()) { - return MeshSharding(); + return Sharding(); } TypedValue rankedTensor = dyn_cast>(result); if (!rankedTensor) { - return MeshSharding(); + return Sharding(); } Operation *userOp = *result.getUsers().begin(); ShardOp shardOp = llvm::dyn_cast(userOp); if (shardOp) { - return MeshSharding(shardOp.getSharding()); + return Sharding(shardOp.getSharding()); } if (rankedTensor.getType().getRank() == 0) { // This is a 0d tensor result without explicit sharding. - // Find mesh symbol from operands, if any. - // Shardings without mesh are not always fully supported yet. + // Find grid symbol from operands, if any. + // Shardings without grid are not always fully supported yet. for (auto operand : op.getOperands()) { if (auto sharding = operand.getDefiningOp()) { - return MeshSharding(sharding.getMeshAttr()); + return Sharding(sharding.getGridAttr()); } } } - return MeshSharding(); + return Sharding(); }); return res; } static LogicalResult -spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap, - SymbolTableCollection &symbolTableCollection, - OpBuilder &builder) { - Value targetSpmdValue; +partitionOperation(ShardOp shardOp, IRMapping &partitionMap, + SymbolTableCollection &symbolTableCollection, + OpBuilder &builder) { + Value targetPartitionValue; // Check if 2 shard ops are chained. If not there is no need for resharding // as the source and target share the same sharding. ShardOp srcShardOp = shardOp.getSrc().getDefiningOp(); if (!srcShardOp) { - targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc()); + targetPartitionValue = partitionMap.lookup(shardOp.getSrc()); } else { // Insert resharding.
- TypedValue srcSpmdValue = - cast>(spmdizationMap.lookup(srcShardOp)); - targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue, - symbolTableCollection); + TypedValue srcPartitionValue = + cast>(partitionMap.lookup(srcShardOp)); + targetPartitionValue = reshard(builder, srcShardOp, shardOp, + srcPartitionValue, symbolTableCollection); } - assert(!spmdizationMap.contains(shardOp.getResult())); - spmdizationMap.map(shardOp.getResult(), targetSpmdValue); + assert(!partitionMap.contains(shardOp.getResult())); + partitionMap.map(shardOp.getResult(), targetPartitionValue); return success(); } static LogicalResult -spmdizeOperation(Operation &op, IRMapping &spmdizationMap, - SymbolTableCollection &symbolTableCollection, - OpBuilder &builder) { +partitionOperation(Operation &op, IRMapping &partitionMap, + SymbolTableCollection &symbolTableCollection, + OpBuilder &builder) { if (isa(op)) { return success(); } @@ -689,30 +681,31 @@ spmdizeOperation(Operation &op, IRMapping &spmdizationMap, return op.emitError("expected a shard op as source of get_sharding"); } auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp()); - spmdizationMap.map(op.getResult(0), newSharding->getResult(0)); + partitionMap.map(op.getResult(0), newSharding->getResult(0)); return success(); } ShardOp shardOp = llvm::dyn_cast(op); if (shardOp) { - return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection, - builder); + return partitionOperation(shardOp, partitionMap, symbolTableCollection, + builder); } - SmallVector spmdizedOperands; - llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands), - [&spmdizationMap](Value operand) { - assert(spmdizationMap.contains(operand)); - return spmdizationMap.lookup(operand); + SmallVector partitionedOperands; + llvm::transform(op.getOperands(), std::back_inserter(partitionedOperands), + [&partitionMap](Value operand) { + assert(partitionMap.contains(operand)); + return partitionMap.lookup(operand); }); - 
return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op), - getResultShardings(op), spmdizationMap, - symbolTableCollection, builder); + return partitionOperation(op, partitionedOperands, getOperandShardings(op), + getResultShardings(op), partitionMap, + symbolTableCollection, builder); } -static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap, - SymbolTableCollection &symbolTableCollection, - OpBuilder &builder) { +static LogicalResult +partitionBlock(Block &block, IRMapping &partitionMap, + SymbolTableCollection &symbolTableCollection, + OpBuilder &builder) { SmallVector argLocations; llvm::transform(block.getArguments(), std::back_inserter(argLocations), @@ -720,16 +713,16 @@ static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap, Block *newBlock = builder.createBlock( block.getParent(), {}, shardedBlockArgumentTypes(block, symbolTableCollection), argLocations); - for (auto [unshardedBlockArg, spmdizedBlockArg] : + for (auto [unshardedBlockArg, partitionedBlockArg] : llvm::zip(block.getArguments(), newBlock->getArguments())) { - spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg); + partitionMap.map(unshardedBlockArg, partitionedBlockArg); } OpBuilder::InsertionGuard insertionGuard(builder); builder.setInsertionPointToEnd(newBlock); for (Operation &op : block.getOperations()) { - if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection, - builder))) { + if (failed(partitionOperation(op, partitionMap, symbolTableCollection, + builder))) { return failure(); } } @@ -738,8 +731,8 @@ static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap, } static LogicalResult -spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, - SymbolTableCollection &symbolTableCollection) { +partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap, + SymbolTableCollection &symbolTableCollection) { OpBuilder builder(op.getFunctionBody()); // Snapshot the original blocks to not mess up the 
iteration when adding new @@ -753,8 +746,8 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, } for (Block *block : originalBlocks) { - if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection, - builder))) { + if (failed(partitionBlock(*block, partitionMap, symbolTableCollection, + builder))) { return failure(); } } @@ -787,22 +780,22 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, namespace { -struct Spmdization : public impl::SpmdizationBase { +struct Partition : public impl::PartitionBase { void runOnOperation() override { - IRMapping spmdizationMap; + IRMapping partitionMap; SymbolTableCollection symbolTableCollection; - if (failed(spmdizeFuncOp(getOperation(), spmdizationMap, - symbolTableCollection))) { + if (failed(partitionFuncOp(getOperation(), partitionMap, + symbolTableCollection))) { return signalPassFailure(); } } void getDependentDialects(DialectRegistry &registry) const override { reshardingRegisterDependentDialects(registry); - registry.insert(); + registry.insert(); } }; } // namespace -} // namespace mlir::mesh +} // namespace mlir::shard diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp similarity index 85% rename from mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp rename to mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp index 0a683768e078d..a647128cf0500 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp @@ -6,11 +6,11 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Mesh/Transforms/Passes.h" +#include "mlir/Dialect/Shard/Transforms/Passes.h" -#include "mlir/Dialect/Mesh/IR/MeshDialect.h" -#include "mlir/Dialect/Mesh/IR/MeshOps.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" +#include "mlir/Dialect/Shard/IR/ShardDialect.h" +#include
"mlir/Dialect/Shard/IR/ShardOps.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "llvm/ADT/STLExtras.h" @@ -21,17 +21,17 @@ #include namespace mlir { -namespace mesh { +namespace shard { #define GEN_PASS_DEF_SHARDINGPROPAGATION -#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc" -} // namespace mesh +#include "mlir/Dialect/Shard/Transforms/Passes.h.inc" +} // namespace shard } // namespace mlir #define DEBUG_TYPE "sharding-propagation" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") using namespace mlir; -using namespace mlir::mesh; +using namespace mlir::shard; enum class ReshardingRquirementKind { NO_RESHARDING = 0, @@ -68,7 +68,7 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream, [[maybe_unused]] static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream, const ShardingOption &v) { - return stream << "{empty = " << v.empty << ", mesh" << v.mesh + return stream << "{empty = " << v.empty << ", grid" << v.grid << ", shardingArray = " << v.shardingArray << "}"; } @@ -105,15 +105,15 @@ operator<<(llvm::raw_ostream &stream, ReshardingRquirementKind v) { // specific shardings. 
For example, mustShardings = [shard0, None] and // optionalShardings = [None, shard1], the result will be [[shard0, shard1], // [shard0, None]] -static SmallVector> -getOrderedPossibleShardingAttrs(ArrayRef mustShardings, - ArrayRef optionalShardings) { - SmallVector> allShardingAttrs; - std::vector curShardingAttrs; +static SmallVector> +getOrderedPossibleShardingAttrs(ArrayRef mustShardings, + ArrayRef optionalShardings) { + SmallVector> allShardingAttrs; + std::vector curShardingAttrs; std::function dfsCreateShardingAttrs = [&](size_t i) { if (i == mustShardings.size()) { - allShardingAttrs.push_back(std::vector(curShardingAttrs)); + allShardingAttrs.push_back(std::vector(curShardingAttrs)); return; } @@ -147,14 +147,14 @@ getOrderedPossibleShardingAttrs(ArrayRef mustShardings, // 1. No resharding is required (all existing annotations are compatible). // 2. No resharding for operands/results that have annotation specifically // targeting this operation. This means -// * operands that are the result of `mesh.shard` ops marked with +// * operands that are the result of `shard.shard` ops marked with // `annotate_for_users`. -// * results that are annotated with `mesh.shard` ops without +// * results that are annotated with `shard.shard` ops without // `annotate_for_users`. // 3. All other cases. Resharding is required for operands/results with // annotation targeting explicitly this operation. ReshardingRquirementKind getReshardingRquirementKind( - Operation *op, const std::vector &operandAndResultShardings) { + Operation *op, const std::vector &operandAndResultShardings) { ReshardingRquirementKind res = ReshardingRquirementKind::NO_RESHARDING; size_t operandsCount = op->getOperands().size(); @@ -213,14 +213,13 @@ ReshardingRquirementKind getReshardingRquirementKind( // 3. Resharding of existing explicit sharding annotations for this op. 
static FailureOr selectShardingOption( ShardingInterface shardingOp, - ArrayRef> possibleOperandShardingAttrs, - ArrayRef> possibleResultShardingAttrs) { + ArrayRef> possibleOperandShardingAttrs, + ArrayRef> possibleResultShardingAttrs) { SmallVector> shardingOptionsAndReshardingRequirements; - for (ArrayRef resultShardings : possibleResultShardingAttrs) { - for (ArrayRef operandShardings : - possibleOperandShardingAttrs) { + for (ArrayRef resultShardings : possibleResultShardingAttrs) { + for (ArrayRef operandShardings : possibleOperandShardingAttrs) { FailureOr shardingOption = shardingOp.getShardingOption(operandShardings, resultShardings); if (failed(shardingOption) || shardingOption->empty) { @@ -231,7 +230,7 @@ static FailureOr selectShardingOption( // They may be missing some annotations. // Whatever is returned by getShardingAnnotations is exactly what the op // needs. - FailureOr> operandAndResultShardings = + FailureOr> operandAndResultShardings = shardingOp.getShardingAnnotations(*shardingOption); if (failed(operandAndResultShardings)) { return failure(); @@ -276,13 +275,13 @@ static FailureOr selectShardingOption( // For each operation that implements the ShardingInterface, infer the sharding // option of the operation from its operands and/or results using the // `getShardingOption` method. If the inferred sharding option is not empty, add -// a `mesh.shard` operation for all remaining operands and results that do not +// a `shard.shard` operation for all remaining operands and results that do not // have sharding annotations. 
static LogicalResult visitOp(Operation *op, OpBuilder &builder) { ShardingInterface shardingOp = llvm::dyn_cast(op); if (op->hasTrait() || (op->hasTrait() && !shardingOp) || - llvm::isa(op)) + llvm::isa(op)) return success(); if (!shardingOp) { @@ -290,14 +289,13 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) { return failure(); } - // collect MeshSharding from results - std::vector allowConflictsResultShardings; + // collect Sharding from results + std::vector allowConflictsResultShardings; allowConflictsResultShardings.resize(op->getNumResults()); - std::vector resultMustShardings; + std::vector resultMustShardings; resultMustShardings.resize(op->getNumResults()); for (OpResult result : op->getResults()) { - FailureOr> maybeShardAttr = - getMeshSharding(result); + FailureOr> maybeShardAttr = getSharding(result); if (failed(maybeShardAttr)) continue; if (!maybeShardAttr->first) @@ -307,14 +305,14 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) { maybeShardAttr->second; } - // collect MeshSharding from operands - std::vector allowConflictsOperandShardings; + // collect Sharding from operands + std::vector allowConflictsOperandShardings; allowConflictsOperandShardings.resize(op->getNumOperands()); - std::vector operandMustShardings; + std::vector operandMustShardings; operandMustShardings.resize(op->getNumOperands()); for (OpOperand &opOperand : op->getOpOperands()) { - FailureOr> maybeShardAttr = - getMeshSharding(opOperand); + FailureOr> maybeShardAttr = + getSharding(opOperand); if (failed(maybeShardAttr)) continue; @@ -327,10 +325,10 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) { } // try to get the sharding option - SmallVector> possibleOperandShardingAttrs = + SmallVector> possibleOperandShardingAttrs = getOrderedPossibleShardingAttrs(operandMustShardings, allowConflictsOperandShardings); - SmallVector> possibleResultShardingAttrs = + SmallVector> possibleResultShardingAttrs = 
       getOrderedPossibleShardingAttrs(resultMustShardings,
                                       allowConflictsResultShardings);
 
   FailureOr<ShardingOption> shardingOption = selectShardingOption(
@@ -358,7 +356,7 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
 // ShardingPropagation
 //===----------------------------------------------------------------------===//
 struct ShardingPropagation
-    : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
+    : public shard::impl::ShardingPropagationBase<ShardingPropagation> {
   using ShardingPropagationBase::ShardingPropagationBase;
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp
similarity index 66%
rename from mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
rename to mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp
index 1315502801d72..a17671e5408c4 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp
@@ -1,4 +1,4 @@
-//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
+//===- Simplifications.cpp - Shard Simplifications --------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,10 +6,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
 #include "TransformsDetail.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
@@ -18,7 +18,7 @@
 #include <numeric>
 
 namespace mlir {
-namespace mesh {
+namespace shard {
 
 void populateSimplificationPatterns(
     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
@@ -52,53 +52,53 @@ namespace {
 // DialectFoldInterface, because it needs a SymbolTableCollection to cache the
 // symbol tables.
 // We can't use DialectFoldInterface since the cache may be invalidated by some
-// pass changing the referenced MeshOp ops.
-struct MeshShapeFolder
-    : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
+// pass changing the referenced GridOp ops.
+struct GridShapeFolder
+    : OpRewritePatternWithSymbolTableCollection<GridShapeOp> {
   using OpRewritePatternWithSymbolTableCollection::
       OpRewritePatternWithSymbolTableCollection;
-  LogicalResult matchAndRewrite(MeshShapeOp op,
+  LogicalResult matchAndRewrite(GridShapeOp op,
                                 PatternRewriter &rewriter) const override {
     ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
-    MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<MeshOp>(
-        op.getOperation(), op.getMeshAttr());
-    if (!mesh) {
+    GridOp grid = symbolTableCollection.lookupNearestSymbolFrom<GridOp>(
+        op.getOperation(), op.getGridAttr());
+    if (!grid) {
       return failure();
     }
-    ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
-    SmallVector<MeshAxis> opAxesIota;
-    if (opMeshAxes.empty()) {
-      opAxesIota.resize(mesh.getRank());
+    ArrayRef<GridAxis> opGridAxes = op.getAxes();
+    SmallVector<GridAxis> opAxesIota;
+    if (opGridAxes.empty()) {
+      opAxesIota.resize(grid.getRank());
       std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
-      opMeshAxes = opAxesIota;
+      opGridAxes = opAxesIota;
     }
-    if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
-          return ShapedType::isDynamic(mesh.getShape()[axis]);
+    if (llvm::all_of(opGridAxes, [&grid](GridAxis axis) {
+          return ShapedType::isDynamic(grid.getShape()[axis]);
         })) {
-      // All mesh dimensions are dynamic. Nothing to fold.
+      // All grid dimensions are dynamic. Nothing to fold.
       return failure();
     }
 
     SmallVector<Value> newResults(op->getResults().size());
-    SmallVector<MeshAxis> newShapeOpMeshAxes;
+    SmallVector<GridAxis> newShapeOpGridAxes;
     SmallVector<size_t> newToOldResultsIndexMap;
 
-    for (size_t i = 0; i < opMeshAxes.size(); ++i) {
-      auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
-      if (ShapedType::isDynamic(meshAxisSize)) {
+    for (size_t i = 0; i < opGridAxes.size(); ++i) {
+      auto gridAxisSize = grid.getShape()[opGridAxes[i]];
+      if (ShapedType::isDynamic(gridAxisSize)) {
         newToOldResultsIndexMap.push_back(i);
-        newShapeOpMeshAxes.push_back(opMeshAxes[i]);
+        newShapeOpGridAxes.push_back(opGridAxes[i]);
       } else {
-        // Fold static mesh axes.
+        // Fold static grid axes.
         newResults[i] = arith::ConstantOp::create(
-            builder, builder.getIndexAttr(meshAxisSize));
+            builder, builder.getIndexAttr(gridAxisSize));
       }
     }
 
-    // Leave only the dynamic mesh axes to be queried.
-    if (!newShapeOpMeshAxes.empty()) {
-      MeshShapeOp newShapeOp =
-          MeshShapeOp::create(builder, mesh.getSymName(), newShapeOpMeshAxes);
+    // Leave only the dynamic grid axes to be queried.
+    if (!newShapeOpGridAxes.empty()) {
+      GridShapeOp newShapeOp =
+          GridShapeOp::create(builder, grid.getSymName(), newShapeOpGridAxes);
       for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
         newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
       }
@@ -113,8 +113,8 @@ struct MeshShapeFolder
 
 void populateFoldingPatterns(RewritePatternSet &patterns,
                              SymbolTableCollection &symbolTableCollection) {
-  patterns.add<MeshShapeFolder>(symbolTableCollection, patterns.getContext());
+  patterns.add<GridShapeFolder>(symbolTableCollection, patterns.getContext());
 }
 
-} // namespace mesh
+} // namespace shard
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
similarity index 78%
rename from mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
rename to mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
index 1bde1af28d8c3..772e66fee5c56 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/Shard/Transforms/Transforms.h"
 #include "TransformsDetail.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Utils.h"
@@ -14,8 +14,8 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include
"mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -28,12 +28,12 @@
 #include <iterator>
 #include <numeric>
 
-namespace mlir::mesh {
+namespace mlir::shard {
 
 namespace {
 
-/// Lower `mesh.process_multi_index` into expression using
-/// `mesh.process_linear_index` and `mesh.mesh_shape`.
+/// Lower `shard.process_multi_index` into expression using
+/// `shard.process_linear_index` and `shard.grid_shape`.
 struct ProcessMultiIndexOpLowering
     : OpRewritePatternWithSymbolTableCollection<ProcessMultiIndexOp> {
   using OpRewritePatternWithSymbolTableCollection::
@@ -41,30 +41,30 @@ struct ProcessMultiIndexOpLowering
   LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    MeshOp mesh = getMesh(op, symbolTableCollection);
-    if (!mesh) {
+    GridOp grid = getGrid(op, symbolTableCollection);
+    if (!grid) {
       return failure();
     }
 
     ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
     builder.setInsertionPointAfter(op.getOperation());
-    Value linearIndex = ProcessLinearIndexOp::create(builder, mesh);
-    ValueRange meshShape = MeshShapeOp::create(builder, mesh).getResults();
+    Value linearIndex = ProcessLinearIndexOp::create(builder, grid);
+    ValueRange gridShape = GridShapeOp::create(builder, grid).getResults();
     SmallVector<Value> completeMultiIndex =
         affine::AffineDelinearizeIndexOp::create(builder, linearIndex,
-                                                 meshShape)
+                                                 gridShape)
             .getMultiIndex();
     SmallVector<Value> multiIndex;
-    ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
-    SmallVector<MeshAxis> opAxesIota;
-    if (opMeshAxes.empty()) {
-      opAxesIota.resize(mesh.getRank());
+    ArrayRef<GridAxis> opGridAxes = op.getAxes();
+    SmallVector<GridAxis> opAxesIota;
+    if (opGridAxes.empty()) {
+      opAxesIota.resize(grid.getRank());
       std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
-      opMeshAxes = opAxesIota;
+      opGridAxes = opAxesIota;
     }
-    llvm::transform(opMeshAxes, std::back_inserter(multiIndex),
-                    [&completeMultiIndex](MeshAxis meshAxis) {
-                      return completeMultiIndex[meshAxis];
+    llvm::transform(opGridAxes, std::back_inserter(multiIndex),
+                    [&completeMultiIndex](GridAxis gridAxis) {
+                      return completeMultiIndex[gridAxis];
                     });
     rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
     return success();
@@ -86,15 +86,15 @@ struct AllSliceOpLowering
     // axis.
     // The slice axis is split into equisized parts with count
     // the number of processes in the collective process group induced by
-    // the mesh axes.
+    // the grid axes.
     // The part for each process is determined by the corresponding
     // linear-index in the process group.
    //
     // There are no collectives that require communication.
     // Each process operates on its local tensor.
-    MeshOp mesh = getMesh(op, symbolTableCollection);
-    if (!mesh) {
+    GridOp grid = getGrid(op, symbolTableCollection);
+    if (!grid) {
       return failure();
     }
 
@@ -104,15 +104,15 @@ struct AllSliceOpLowering
     Value zero = arith::ConstantOp::create(builder, builder.getIndexAttr(0));
 
     Operation::result_range processInGroupMultiIndex =
-        ProcessMultiIndexOp::create(builder, mesh.getSymName(),
-                                    op.getMeshAxes())
+        ProcessMultiIndexOp::create(builder, grid.getSymName(),
+                                    op.getGridAxes())
             .getResults();
 
     Operation::result_range processGroupShape =
-        MeshShapeOp::create(builder, mesh.getSymName(), op.getMeshAxes())
+        GridShapeOp::create(builder, grid.getSymName(), op.getGridAxes())
             .getResult();
     Value processGroupSize =
-        createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder);
+        createCollectiveProcessGroupSize(grid, op.getGridAxes(), builder);
 
     int64_t sliceAxis = op.getSliceAxis().getSExtValue();
     Value operandSliceAxisSize =
@@ -125,7 +125,7 @@ struct AllSliceOpLowering
     cf::AssertOp::create(builder, isTargetShapeExactlyDivisible,
                          "Slicing a tensor with axis size that is "
                          "not exactly divisible by the "
-                         "mesh process group size is not supported.");
+                         "grid process group size is not supported.");
     Value resultSliceAxisSize =
         arith::DivUIOp::create(builder,
                                operandSliceAxisSize, processGroupSize);
 
     OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
@@ -172,7 +172,7 @@ void populateProcessMultiIndexOpLoweringPatterns(
 }
 
 void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry) {
-  registry.insert();
+  registry.insert();
 }
 
 void populateAllSliceOpLoweringPatterns(
@@ -183,7 +183,7 @@ void populateAllSliceOpLoweringPatterns(
 
 void registerAllSliceOpLoweringDialects(DialectRegistry &registry) {
   registry.insert();
 }
@@ -199,21 +199,21 @@ void registerAllOpLoweringDialects(DialectRegistry &registry) {
 }
 
 TypedValue<IndexType>
-createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
+createCollectiveProcessGroupSize(GridOp grid, ArrayRef<GridAxis> axes,
                                  ImplicitLocOpBuilder &builder) {
-  Operation::result_range meshShape =
-      mesh::MeshShapeOp::create(builder, mesh, axes).getResults();
+  Operation::result_range gridShape =
+      GridShapeOp::create(builder, grid, axes).getResults();
   return cast<TypedValue<IndexType>>(arith::createProduct(
-      builder, builder.getLoc(), llvm::to_vector_of<Value>(meshShape),
+      builder, builder.getLoc(), llvm::to_vector_of<Value>(gridShape),
       builder.getIndexType()));
 }
 
 TypedValue<IndexType>
-createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
-                         ArrayRef<MeshAxis> meshAxes,
+createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
+                         ArrayRef<GridAxis> gridAxes,
                          ImplicitLocOpBuilder &builder) {
   Operation::result_range processGroupShape =
-      MeshShapeOp::create(builder, mesh, meshAxes).getResult();
+      GridShapeOp::create(builder, grid, gridAxes).getResult();
   OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
       llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
       llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
@@ -225,11 +225,11 @@ createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
   return cast<TypedValue<IndexType>>(res);
 }
 
-TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
-                                               ArrayRef<MeshAxis> meshAxes,
+TypedValue<IndexType> createProcessLinearIndex(StringRef grid,
+                                               ArrayRef<GridAxis> gridAxes,
                                                ImplicitLocOpBuilder &builder) {
   return
createProcessLinearIndex( - mesh, ProcessMultiIndexOp::create(builder, mesh, meshAxes).getResults(), - meshAxes, builder); + grid, ProcessMultiIndexOp::create(builder, grid, gridAxes).getResults(), + gridAxes, builder); } -} // namespace mlir::mesh +} // namespace mlir::shard diff --git a/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h b/mlir/lib/Dialect/Shard/Transforms/TransformsDetail.h similarity index 82% rename from mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h rename to mlir/lib/Dialect/Shard/Transforms/TransformsDetail.h index 3e3f584caca24..60c9828ba736d 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h +++ b/mlir/lib/Dialect/Shard/Transforms/TransformsDetail.h @@ -6,14 +6,14 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H -#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H +#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMSDETAIL_H +#define MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMSDETAIL_H #include "mlir/IR/PatternMatch.h" #include "mlir/IR/SymbolTable.h" namespace mlir { -namespace mesh { +namespace shard { template struct OpRewritePatternWithSymbolTableCollection : OpRewritePattern { @@ -29,7 +29,7 @@ struct OpRewritePatternWithSymbolTableCollection : OpRewritePattern { SymbolTableCollection &symbolTableCollection; }; -} // namespace mesh +} // namespace shard } // namespace mlir -#endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H +#endif // MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMSDETAIL_H diff --git a/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp index 0421a6c0ff806..0784615b8edb8 100644 --- a/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp +++ b/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tensor/Extensions/AllExtensions.h" 
-#include "mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h" +#include "mlir/Dialect/Tensor/Extensions/ShardingExtensions.h" using namespace mlir; diff --git a/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt index dba59333666f6..8f0b7da1fd7b5 100644 --- a/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt @@ -1,10 +1,10 @@ set(LLVM_OPTIONAL_SOURCES AllExtensions.cpp - MeshShardingExtensions.cpp + ShardingExtensions.cpp ) -add_mlir_extension_library(MLIRTensorMeshShardingExtensions - MeshShardingExtensions.cpp +add_mlir_extension_library(MLIRTensorShardingExtensions + ShardingExtensions.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions @@ -22,5 +22,5 @@ add_mlir_extension_library(MLIRTensorAllExtensions ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions LINK_LIBS PUBLIC - MLIRTensorMeshShardingExtensions + MLIRTensorShardingExtensions ) \ No newline at end of file diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp similarity index 74% rename from mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp rename to mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp index 7e4a5acb9867d..ca7287cec55ce 100644 --- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp +++ b/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp @@ -6,15 +6,15 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h" #include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" 
using namespace mlir; using namespace mlir::tensor; -using namespace mlir::mesh; +using namespace mlir::shard; namespace { @@ -40,20 +40,20 @@ struct CreatorOpShardingInterface {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)}); } - LogicalResult spmdize(Operation *op, ArrayRef spmdizedOperands, - ArrayRef operandShardings, - ArrayRef resultShardings, - IRMapping &spmdizationMap, - SymbolTableCollection &symbolTable, - OpBuilder &builder) const { + LogicalResult partition(Operation *op, ArrayRef partitionedOperands, + ArrayRef operandShardings, + ArrayRef resultShardings, + IRMapping &partitionMap, + SymbolTableCollection &symbolTable, + OpBuilder &builder) const { assert(resultShardings.size() == 1); auto resType = cast(op->getResult(0).getType()); - mlir::mesh::MeshOp mesh; + mlir::shard::GridOp grid; ShapedType shardType; if (resType.getRank() > 0) { - mesh = mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable); + grid = shard::getGrid(op, resultShardings[0].getGridAttr(), symbolTable); shardType = - cast(mesh::shardType(resType, mesh, resultShardings[0])); + cast(shard::shardType(resType, grid, resultShardings[0])); } else { shardType = resType; } @@ -67,7 +67,7 @@ struct CreatorOpShardingInterface auto oldType = cast(resType); assert(oldType.getRank() == shardType.getRank()); int currOldOprndNum = -1; - mesh::ShardShapeOp shapeForDevice; + shard::ShardShapeOp shapeForDevice; ValueRange device; Operation *newSharding = nullptr; for (auto i = 0; i < oldType.getRank(); ++i) { @@ -76,23 +76,23 @@ struct CreatorOpShardingInterface newSharding = ShardingOp::create(builder, op->getLoc(), resultShardings[0]); device = - mesh::ProcessMultiIndexOp::create(builder, op->getLoc(), mesh) + shard::ProcessMultiIndexOp::create(builder, op->getLoc(), grid) .getResults(); - shapeForDevice = mesh::ShardShapeOp::create( - builder, op->getLoc(), oldType.getShape(), spmdizedOperands, + shapeForDevice = shard::ShardShapeOp::create( + builder, op->getLoc(), 
oldType.getShape(), partitionedOperands, newSharding->getResult(0), device); } newOperands.emplace_back(shapeForDevice.getResult()[i]); } else if (oldType.isDynamicDim(i)) { assert(shardType.isDynamicDim(i)); - newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]); + newOperands.emplace_back(partitionedOperands[++currOldOprndNum]); } } newOp = OpTy::create(builder, op->getLoc(), shardType, newOperands); - spmdizationMap.map(op->getResult(0), newOp->getResult(0)); + partitionMap.map(op->getResult(0), newOp->getResult(0)); } else { // `clone` will populate the mapping of old to new results. - newOp = builder.clone(*op, spmdizationMap); + newOp = builder.clone(*op, partitionMap); } newOp->getResult(0).setType(shardType); diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt index b1fac8c85a204..c6a438d348946 100644 --- a/mlir/lib/Dialect/Tosa/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt @@ -36,7 +36,7 @@ add_mlir_dialect_library(MLIRTosaShardingInterfaceImpl LINK_LIBS PUBLIC MLIRIR - MLIRMeshDialect + MLIRShardDialect MLIRShardingInterface MLIRSupport MLIRTosaDialect diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp index d3a5f44798106..45994a7ec679f 100644 --- a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp @@ -7,9 +7,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h" -#include "mlir/Dialect/Mesh/IR/MeshOps.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/AffineMap.h" #include 
"mlir/IR/DialectRegistry.h" @@ -19,7 +19,7 @@ using namespace mlir; using namespace mlir::tosa; -using namespace mlir::mesh; +using namespace mlir::shard; namespace { @@ -87,15 +87,15 @@ struct NegateOpSharding return maps; } - LogicalResult spmdize(Operation *op, ArrayRef spmdizedOperands, - ArrayRef operandShardings, - ArrayRef resultShardings, - IRMapping &spmdizationMap, - SymbolTableCollection &symbolTable, - OpBuilder &builder) const { - spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings, - resultShardings, spmdizationMap, - symbolTable, builder); + LogicalResult partition(Operation *op, ArrayRef partitiondOperands, + ArrayRef operandShardings, + ArrayRef resultShardings, + IRMapping &partitionMap, + SymbolTableCollection &symbolTable, + OpBuilder &builder) const { + partitionTriviallyShardableOperation(*op, partitiondOperands, + operandShardings, resultShardings, + partitionMap, symbolTable, builder); return success(); } }; diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 648e508a9788f..ecd93ff4c6e7b 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -13,8 +13,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" #include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" @@ -166,7 +166,7 @@ void TosaDialect::initialize() { >(); addInterfaces(); declarePromisedInterfaces< - mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp, + shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp, ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp, 
LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp, diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir similarity index 90% rename from mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir rename to mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir index d54d0034da5be..5e20b5a59d927 100644 --- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir +++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir @@ -1,14 +1,14 @@ -// RUN: mlir-opt %s -convert-mesh-to-mpi -canonicalize -split-input-file | FileCheck %s +// RUN: mlir-opt %s -convert-shard-to-mpi -canonicalize -split-input-file | FileCheck %s // ----- -// CHECK: mesh.mesh @mesh0 -mesh.mesh @mesh0(shape = 3x4x5) +// CHECK: shard.grid @grid0 +shard.grid @grid0(shape = 3x4x5) func.func @process_multi_index() -> (index, index, index) { // CHECK: mpi.comm_rank // CHECK-DAG: %[[v4:.*]] = arith.remsi // CHECK-DAG: %[[v0:.*]] = arith.remsi // CHECK-DAG: %[[v1:.*]] = arith.remsi - %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index + %0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index // CHECK: return %[[v1]], %[[v0]], %[[v4]] : index, index, index return %0#0, %0#1, %0#2 : index, index, index } @@ -17,7 +17,7 @@ func.func @process_multi_index() -> (index, index, index) { func.func @process_linear_index() -> index { // CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank // CHECK: %[[cast:.*]] = arith.index_cast %[[rank]] : i32 to index - %0 = mesh.process_linear_index on @mesh0 : index + %0 = shard.process_linear_index on @grid0 : index // CHECK: return %[[cast]] : index return %0 : index } @@ -29,7 +29,7 @@ func.func @neighbors_dim0(%arg0 : tensor<120x120x120xi8>) -> (index, index) { %c4 = arith.constant 4 : index // CHECK-DAG: [[up:%.*]] = arith.constant 44 : index // CHECK-DAG: [[down:%.*]] = arith.constant 4 : index - %idx:2 = mesh.neighbors_linear_indices on 
@mesh0[%c1, %c0, %c4] split_axes = [0] : index, index + %idx:2 = shard.neighbors_linear_indices on @grid0[%c1, %c0, %c4] split_axes = [0] : index, index // CHECK: return [[down]], [[up]] : index, index return %idx#0, %idx#1 : index, index } @@ -41,7 +41,7 @@ func.func @neighbors_dim1(%arg0 : tensor<120x120x120xi8>) -> (index, index) { %c4 = arith.constant 4 : index // CHECK-DAG: [[up:%.*]] = arith.constant 29 : index // CHECK-DAG: [[down:%.*]] = arith.constant -1 : index - %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [1] : index, index + %idx:2 = shard.neighbors_linear_indices on @grid0[%c1, %c0, %c4] split_axes = [1] : index, index // CHECK: return [[down]], [[up]] : index, index return %idx#0, %idx#1 : index, index } @@ -53,20 +53,20 @@ func.func @neighbors_dim2(%arg0 : tensor<120x120x120xi8>) -> (index, index) { %c4 = arith.constant 4 : index // CHECK-DAG: [[up:%.*]] = arith.constant -1 : index // CHECK-DAG: [[down:%.*]] = arith.constant 23 : index - %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [2] : index, index + %idx:2 = shard.neighbors_linear_indices on @grid0[%c1, %c0, %c4] split_axes = [2] : index, index // CHECK: return [[down]], [[up]] : index, index return %idx#0, %idx#1 : index, index } // ----- -// CHECK: mesh.mesh @mesh0 +// CHECK: shard.grid @grid0 module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { - mesh.mesh @mesh0(shape = 3x4x5) + shard.grid @grid0(shape = 3x4x5) func.func @process_multi_index() -> (index, index, index) { // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index - %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index + %0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index // CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index return %0#0, %0#1, %0#2 : index, index, index } @@ -74,7 +74,7 @@ module 
attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { // CHECK-LABEL: func @process_linear_index func.func @process_linear_index() -> index { // CHECK: %[[c24:.*]] = arith.constant 24 : index - %0 = mesh.process_linear_index on @mesh0 : index + %0 = shard.process_linear_index on @grid0 : index // CHECK: return %[[c24]] : index return %0 : index } @@ -82,7 +82,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { // ----- module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } { - mesh.mesh @mesh0(shape = 3x4x5) + shard.grid @grid0(shape = 3x4x5) // CHECK-LABEL: func.func @allreduce_tensor( func.func @allreduce_tensor( // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32> @@ -97,7 +97,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } { // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32> // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32> // CHECK: [[v2:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<3x4xf32> to tensor<3x4xf32> - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : tensor<3x4xf32> -> tensor<3x4xf32> + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : tensor<3x4xf32> -> tensor<3x4xf32> // CHECK: return [[v2]] : tensor<3x4xf32> return %0 : tensor<3x4xf32> } @@ -114,7 +114,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } { // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32> // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32> - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf32> + %0 = 
shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf32> // CHECK: return [[valloc]] : memref<3x4xf32> return %0 : memref<3x4xf32> } @@ -131,14 +131,14 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } { // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf64> into memref<12xf64> // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf64>, memref<12xf64> - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf64> + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf64> // CHECK: return [[valloc]] : memref<3x4xf64> return %0 : memref<3x4xf64> } } // ----- -mesh.mesh @mesh0(shape = 3x4x5) +shard.grid @grid0(shape = 3x4x5) // CHECK-LABEL: func @update_halo_1d_first func.func @update_halo_1d_first( // CHECK-SAME: [[arg0:%.*]]: memref<120x120x120xi8> @@ -155,14 +155,14 @@ func.func @update_halo_1d_first( // CHECK: mpi.recv( // CHECK-SAME: : memref<3x120x120xi8>, i32, i32 // CHECK: memref.subview [[arg0]][117, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8 - %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 3] : memref<120x120x120xi8> + %res = shard.update_halo %arg0 on @grid0 split_axes = [[0]] halo_sizes = [2, 3] : memref<120x120x120xi8> // CHECK: return [[res:%.*]] : memref<120x120x120xi8> return %res : memref<120x120x120xi8> } // ----- module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } { - mesh.mesh @mesh0(shape = 4) + shard.grid @grid0(shape = 4) // CHECK-LABEL: func @update_halo_1d_with_zero func.func @update_halo_1d_with_zero ( // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8> @@ -179,7 +179,7 @@ module attributes { mpi.dlti = 
#dlti.map<"MPI:comm_world_rank" = 1> } { // CHECK: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>> // CHECK: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>> // CHECK: memref.dealloc [[valloc]] : memref<2x120x120xi8> - %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8> + %res = shard.update_halo %arg0 on @grid0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8> // CHECK: return [[varg0]] : memref<120x120x120xi8> return %res : memref<120x120x120xi8> } @@ -187,7 +187,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } { // ----- module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { - mesh.mesh @mesh0(shape = 3x4x5) + shard.grid @grid0(shape = 3x4x5) // CHECK-LABEL: func @update_halo_3d func.func @update_halo_3d( // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8> @@ -236,7 +236,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { // CHECK: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8> // CHECK: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<2x120x120xi8>, i32, i32 // CHECK: memref.dealloc [[valloc_10]] : memref<2x120x120xi8> - %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8> + %res = shard.update_halo %arg0 on @grid0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8> // CHECK: return [[varg0]] : memref<120x120x120xi8> return %res : memref<120x120x120xi8> } @@ -291,18 +291,18 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { // CHECK: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<2x120x120xi8>, i32, i32 // 
CHECK: memref.dealloc [[valloc_10]] : memref<2x120x120xi8> // CHECK: [[v4:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8> - %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8> + %res = shard.update_halo %arg0 on @grid0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8> // CHECK: return [[v4]] : tensor<120x120x120xi8> return %res : tensor<120x120x120xi8> } } // ----- -mesh.mesh @mesh0(shape = 2x2x4) +shard.grid @grid0(shape = 2x2x4) // CHECK-LABEL: func.func @return_sharding( // CHECK-SAME: [[varg0:%.*]]: tensor<2x4xf32>) -> (tensor<2x4xf32>, tensor, tensor, tensor) { -func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sharding) { - %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] : !mesh.sharding +func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !shard.sharding) { + %sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] : !shard.sharding // CHECK: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16> // CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16> // CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16 @@ -316,13 +316,13 @@ func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sh // CHECK: [[vcast_2:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor // CHECK: [[vcast_3:%.*]] = tensor.cast [[v3]] : tensor<0x0xi64> to tensor // CHECK: return [[varg0]], [[vcast]], [[vcast_2]], [[vcast_3]] : tensor<2x4xf32>, tensor, tensor, tensor - return %arg0, %sharding : tensor<2x4xf32>, !mesh.sharding + return %arg0, %sharding : tensor<2x4xf32>, !shard.sharding } // CHECK-LABEL: func.func @return_sharding_halos( // CHECK-SAME: [[varg0:%.*]]: tensor<6x8xf32>) -> (tensor<6x8xf32>, tensor, tensor, tensor) { -func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !mesh.sharding) { 
- %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !mesh.sharding +func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !shard.sharding) { + %sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !shard.sharding // CHECK: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64> // CHECK: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16> // CHECK: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16> @@ -336,13 +336,13 @@ func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !m // CHECK: [[vcast_3:%.*]] = tensor.cast [[vcst]] : tensor<2x2xi64> to tensor // CHECK: [[vcast_4:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor // CHECK: return [[varg0]], [[vcast]], [[vcast_3]], [[vcast_4]] : tensor<6x8xf32>, tensor, tensor, tensor - return %arg0, %sharding : tensor<6x8xf32>, !mesh.sharding + return %arg0, %sharding : tensor<6x8xf32>, !shard.sharding } // CHECK-LABEL: func.func @return_sharding_offs( // CHECK-SAME: [[varg0:%.*]]: tensor) -> (tensor, tensor, tensor, tensor) { -func.func @return_sharding_offs(%arg0: tensor) -> (tensor, !mesh.sharding) { - %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !mesh.sharding +func.func @return_sharding_offs(%arg0: tensor) -> (tensor, !shard.sharding) { + %sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !shard.sharding // CHECK: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64> // CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64> // CHECK: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64 @@ -362,5 +362,5 @@ func.func @return_sharding_offs(%arg0: tensor) -> (tensor, !me // CHECK: [[vcast_6:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor // CHECK: 
[[vcast_7:%.*]] = tensor.cast [[vinserted_slice_5]] : tensor<2x5xi64> to tensor // CHECK: return [[varg0]], [[vcast]], [[vcast_6]], [[vcast_7]] : tensor, tensor, tensor, tensor - return %arg0, %sharding : tensor, !mesh.sharding + return %arg0, %sharding : tensor, !shard.sharding } diff --git a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir similarity index 62% rename from mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir rename to mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir index 156bbfb54845b..9729d2bfb384e 100644 --- a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir +++ b/mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir @@ -1,21 +1,21 @@ -// RUN: mlir-opt %s --convert-mesh-to-mpi -canonicalize | FileCheck %s +// RUN: mlir-opt %s --convert-shard-to-mpi -canonicalize | FileCheck %s module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { - // CHECK: mesh.mesh @mesh0 - mesh.mesh @mesh0(shape = 3x4x5) + // CHECK: shard.grid @grid0 + shard.grid @grid0(shape = 3x4x5) - // Notice: comm_world_rank/linear index 24 is multiindex [1, 0, 4] in @mesh0 + // Notice: comm_world_rank/linear index 24 is multiindex [1, 0, 4] in @grid0 // all shards are equal // CHECK-LABEL: func.func @shard_shape_equal() -> (index, index, index) { func.func @shard_shape_equal() -> (index, index, index) { - %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding - %0:3 = mesh.process_multi_index on @mesh0 : index, index, index + %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding + %0:3 = shard.process_multi_index on @grid0 : index, index, index %c9 = arith.constant 9 : index %c12 = arith.constant 12 : index // CHECK: [[vc3:%.*]] = arith.constant 3 : index - %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index + %1:3 = 
shard.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index return %1#0, %1#1, %1#2 : index, index, index } @@ -23,13 +23,13 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { // last shard in last dim gets an extra element // CHECK-LABEL: func.func @shard_shape_odd_1() -> (index, index, index) { func.func @shard_shape_odd_1() -> (index, index, index) { - %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding - %0:3 = mesh.process_multi_index on @mesh0 : index, index, index + %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding + %0:3 = shard.process_multi_index on @grid0 : index, index, index %c9 = arith.constant 9 : index %c12 = arith.constant 12 : index // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index - %1:3 = mesh.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index + %1:3 = shard.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index // CHECK: return [[vc3]], [[vc3]], [[vc4]] : index, index, index return %1#0, %1#1, %1#2 : index, index, index } @@ -37,11 +37,11 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { // In the second dimension the shard sizes are now [3 4 4 4] // CHECK-LABEL: func.func @shard_shape_odd_2() -> (index, index, index) { func.func @shard_shape_odd_2() -> (index, index, index) { - %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding - %0:3 = mesh.process_multi_index on @mesh0 : index, index, index + %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding + %0:3 = shard.process_multi_index on @grid0 : index, index, index %c9 = arith.constant 9 : index // CHECK: [[vc3:%.*]] = arith.constant 3 : index - %1:3 = 
mesh.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index + %1:3 = shard.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index return %1#0, %1#1, %1#2 : index, index, index } @@ -49,11 +49,11 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { // In the first dimension the shard sizes are now [3 4 4] // CHECK-LABEL: func.func @shard_shape_odd_3() -> (index, index, index) { func.func @shard_shape_odd_3() -> (index, index, index) { - %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding - %0:3 = mesh.process_multi_index on @mesh0 : index, index, index + %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding + %0:3 = shard.process_multi_index on @grid0 : index, index, index // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index - %1:3 = mesh.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index + %1:3 = shard.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index // CHECK: return [[vc4]], [[vc3]], [[vc3]] : index, index, index return %1#0, %1#1, %1#2 : index, index, index } @@ -61,14 +61,14 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { // extract from sharded_dims_offsets // CHECK-LABEL: func.func @shard_shape_sharded_dims_offs() -> (index, index, index) { func.func @shard_shape_sharded_dims_offs() -> (index, index, index) { - %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] - sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 12, 0, 3, 6, 9, 12, 15]: !mesh.sharding - %0:3 = mesh.process_multi_index on @mesh0 : index, index, index + %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] + sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 
12, 0, 3, 6, 9, 12, 15]: !shard.sharding + %0:3 = shard.process_multi_index on @grid0 : index, index, index %c9 = arith.constant 9 : index %c12 = arith.constant 12 : index // CHECK: [[vc3:%.*]] = arith.constant 3 : index // CHECK: [[vc2:%.*]] = arith.constant 2 : index - %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index + %1:3 = shard.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index // CHECK: return [[vc3]], [[vc2]], [[vc3]] : index, index, index return %1#0, %1#1, %1#2 : index, index, index } diff --git a/mlir/test/Dialect/Arith/mesh-spmdize.mlir b/mlir/test/Dialect/Arith/mesh-spmdize.mlir deleted file mode 100644 index 6b55dd533a92c..0000000000000 --- a/mlir/test/Dialect/Arith/mesh-spmdize.mlir +++ /dev/null @@ -1,17 +0,0 @@ -// RUN: mlir-opt \ -// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization))" \ -// RUN: %s | FileCheck %s - -mesh.mesh @mesh4x4(shape = 4x4) - -// CHECK-LABEL: func @test_spmdize_constant -// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : -// tensor<256x1024xf32> CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : -// i32 CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32> -func.func @test_spmdize_constant() ->(tensor<1024x1024xf32>)attributes{llvm.emit_c_interface} { - %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding - %sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32> - %ci = arith.constant 434 : i32 - return %sharding_annotated_1 : tensor<1024x1024xf32> -} diff --git a/mlir/test/Dialect/Arith/shard-partition.mlir b/mlir/test/Dialect/Arith/shard-partition.mlir new file mode 100644 index 0000000000000..be894278e5e95 --- /dev/null +++ b/mlir/test/Dialect/Arith/shard-partition.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-opt \ +// RUN: 
--pass-pipeline="builtin.module(func.func(shard-partition))" \ +// RUN: %s | FileCheck %s + +shard.grid @grid4x4(shape = 4x4) + +// CHECK-LABEL: func @test_partition_constant +// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<256x1024xf32> +// CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : i32 +// CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32> +func.func @test_partition_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { + %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding + %sharded_1 = shard.shard %cst to %sharding_1 : tensor<1024x1024xf32> + %ci = arith.constant 434 : i32 + return %sharded_1 : tensor<1024x1024xf32> +} diff --git a/mlir/test/Dialect/Arith/sharding-propagation.mlir b/mlir/test/Dialect/Arith/sharding-propagation.mlir index 19eb340549b0b..762620d9dae0c 100644 --- a/mlir/test/Dialect/Arith/sharding-propagation.mlir +++ b/mlir/test/Dialect/Arith/sharding-propagation.mlir @@ -1,54 +1,54 @@ // RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s -mesh.mesh @mesh4x4(shape = 4x4) +shard.grid @grid4x4(shape = 4x4) // CHECK-LABEL: func.func @test_shard_constant() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} { // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> -// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding -// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding +// CHECK-NEXT: [[vsharded:%.*]] = shard.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32> // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32 // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32> -// CHECK-NEXT: 
[[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding -// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32> -// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding -// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32> -// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding -// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32> -// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32> -// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding -// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32> -// CHECK-NEXT: return [[vsharding_annotated_8]] : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_1:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding +// CHECK-NEXT: [[vsharded_2:%.*]] = shard.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_3:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding +// CHECK-NEXT: [[vsharded_4:%.*]] = shard.shard [[vsharded]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_5:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding +// CHECK-NEXT: [[vsharded_6:%.*]] = shard.shard [[vsharded_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32> +// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharded_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharded_6]] : tensor<1024x1024xf32>) -> 
tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_7:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding +// CHECK-NEXT: [[vsharded_8:%.*]] = shard.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32> +// CHECK-NEXT: return [[vsharded_8]] : tensor<1024x1024xf32> func.func @test_shard_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding - %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32> + %sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding + %sharded_1 = shard.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32> %ci = arith.constant 43.4e+00 : f32 %o1 = tensor.empty() : tensor<1024x1024xf32> - %res = linalg.add ins(%sharding_annotated_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32> + %res = linalg.add ins(%sharded_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32> return %res : tensor<1024x1024xf32> } // CHECK-LABEL: func.func @test_shard_constant_back() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} { // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> -// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding -// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding +// CHECK-NEXT: [[vsharded:%.*]] = shard.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32> // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32 // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32> -// CHECK-NEXT: [[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding 
-// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32> -// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding -// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32> -// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding -// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32> -// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32> -// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding -// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_1:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding +// CHECK-NEXT: [[vsharded_2:%.*]] = shard.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_3:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding +// CHECK-NEXT: [[vsharded_4:%.*]] = shard.shard [[vsharded]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_5:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding +// CHECK-NEXT: [[vsharded_6:%.*]] = shard.shard [[vsharded_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32> +// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharded_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharded_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_7:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding +// CHECK-NEXT: [[vsharded_8:%.*]] = 
shard.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32> func.func @test_shard_constant_back() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> %ci = arith.constant 43.4e+00 : f32 %o1 = tensor.empty() : tensor<1024x1024xf32> %res = linalg.add ins(%cst_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32> - %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding - %sharding_annotated_1 = mesh.shard %res to %sharding_1 : tensor<1024x1024xf32> - return %sharding_annotated_1 : tensor<1024x1024xf32> + %sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding + %sharded_1 = shard.shard %res to %sharding_1 : tensor<1024x1024xf32> + return %sharded_1 : tensor<1024x1024xf32> } diff --git a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir deleted file mode 100644 index 5297eeb666c1e..0000000000000 --- a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir +++ /dev/null @@ -1,42 +0,0 @@ -// RUN: mlir-opt \ -// RUN: --verify-each \ -// RUN: --pass-pipeline="builtin.module(func.func(sharding-propagation))" \ -// RUN: %s | FileCheck %s - -mesh.mesh @mesh_2(shape = 2) - -// CHECK-LABEL: func @matmul_shard_prallel_axis -func.func @matmul_shard_prallel_axis( - // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<2x3xf32>, - %arg0 : tensor<2x3xf32>, - // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x2xf32>, - %arg1 : tensor<3x2xf32>, - // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<2x2xf32> - %out_dps: tensor<2x2xf32> -) -> tensor<2x2xf32> { - // CHECK: %[[SIN1_ANNOTATED_0:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding - // CHECK-NEXT: %[[IN1_ANNOTATED_0:.*]] = mesh.shard %[[IN1]] to %[[SIN1_ANNOTATED_0]] : tensor<2x3xf32> - // CHECK: %[[SIN1_ANNOTATED_1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding - // 
CHECK-NEXT: %[[IN1_ANNOTATED_1:.*]] = mesh.shard %[[IN1_ANNOTATED_0]] to %[[SIN1_ANNOTATED_1]] annotate_for_users : tensor<2x3xf32> - // CHECK: %[[SIN2_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[]] : !mesh.sharding - // CHECK-NEXT: %[[IN2_ANNOTATED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_ANNOTATED]] annotate_for_users : tensor<3x2xf32> - // CHECK: %[[SDPS_OUT_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding - // CHECK-NEXT: %[[DPS_OUT_ANNOTATED:.*]] = mesh.shard %[[DPS_OUT]] to %[[SDPS_OUT_ANNOTATED]] annotate_for_users : tensor<2x2xf32> - %sarg0_sharded = mesh.sharding @mesh_2 split_axes = [[0]] : !mesh.sharding - %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded : tensor<2x3xf32> - - // CHECK: %[[RES:.*]] = linalg.matmul ins(%[[IN1_ANNOTATED_1]], %[[IN2_ANNOTATED]] : tensor<2x3xf32>, tensor<3x2xf32>) - // CHECK-SAME: outs(%[[DPS_OUT_ANNOTATED]] : tensor<2x2xf32>) -> tensor<2x2xf32> - %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>) - outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32> - - // CHECK: %[[SRES_ANNOTATED_0:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding - // CHECK-NEXT: %[[RES_ANNOTATED_0:.*]] = mesh.shard %[[RES]] to %[[SRES_ANNOTATED_0]] : tensor<2x2xf32> - // CHECK: %[[SRES_ANNOTATED_1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[]] : !mesh.sharding - // CHECK-NEXT: %[[RES_ANNOTATED_1:.*]] = mesh.shard %[[RES_ANNOTATED_0]] to %[[SRES_ANNOTATED_1]] annotate_for_users : tensor<2x2xf32> - %sres_sharded = mesh.sharding @mesh_2 split_axes = [[]] : !mesh.sharding - %res_sharded = mesh.shard %res to %sres_sharded annotate_for_users : tensor<2x2xf32> - - // CHECK: return %[[RES_ANNOTATED_1]] : tensor<2x2xf32> - return %res_sharded : tensor<2x2xf32> -} diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/shard-partition.mlir similarity index 50% rename from mlir/test/Dialect/Linalg/mesh-spmdization.mlir rename to 
mlir/test/Dialect/Linalg/shard-partition.mlir index ce12b296df1fa..aee97079fb197 100644 --- a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir +++ b/mlir/test/Dialect/Linalg/shard-partition.mlir @@ -1,15 +1,15 @@ // RUN: mlir-opt \ -// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \ +// RUN: --pass-pipeline="builtin.module(func.func(shard-partition,test-single-fold))" \ // RUN: --split-input-file \ // RUN: %s | FileCheck %s // CHECK: #[[$MAP_IDENTITY_1D:.*]] = affine_map<(d0) -> (d0)> #map_identity_1d = affine_map<(d0) -> (d0)> -mesh.mesh @mesh_1d(shape = 2) +shard.grid @grid_1d(shape = 2) -// CHECK-LABEL: func @elementwise_static_1d_mesh_static_1d_tensor -func.func @elementwise_static_1d_mesh_static_1d_tensor( +// CHECK-LABEL: func @elementwise_static_1d_grid_static_1d_tensor +func.func @elementwise_static_1d_grid_static_1d_tensor( // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1xi8>, %in1: tensor<2xi8>, // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xi8>, @@ -18,13 +18,13 @@ func.func @elementwise_static_1d_mesh_static_1d_tensor( %dps_out: tensor<2xi8> // CHECK-SAME: -> tensor<1xi8> { ) -> tensor<2xi8> { - %sharding = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %in1_sharded1 = mesh.shard %in1 to %sharding : tensor<2xi8> - %in1_sharded2 = mesh.shard %in1_sharded1 to %sharding annotate_for_users : tensor<2xi8> - %in2_sharded1 = mesh.shard %in2 to %sharding : tensor<2xi8> - %in2_sharded2 = mesh.shard %in2_sharded1 to %sharding annotate_for_users : tensor<2xi8> - %dps_out_sharded1 = mesh.shard %dps_out to %sharding : tensor<2xi8> - %dps_out_shared2 = mesh.shard %dps_out_sharded1 to %sharding annotate_for_users : tensor<2xi8> + %sharding = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %in1_sharded1 = shard.shard %in1 to %sharding : tensor<2xi8> + %in1_sharded2 = shard.shard %in1_sharded1 to %sharding annotate_for_users : tensor<2xi8> + %in2_sharded1 = shard.shard %in2 to %sharding : tensor<2xi8> + 
%in2_sharded2 = shard.shard %in2_sharded1 to %sharding annotate_for_users : tensor<2xi8> + %dps_out_sharded1 = shard.shard %dps_out to %sharding : tensor<2xi8> + %dps_out_shared2 = shard.shard %dps_out_sharded1 to %sharding annotate_for_users : tensor<2xi8> // CHECK: %[[RES:.*]] = linalg.generic { // CHECK-SAME: indexing_maps = [#[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]]], // CHECK-SAME: iterator_types = ["parallel"]} @@ -39,18 +39,18 @@ func.func @elementwise_static_1d_mesh_static_1d_tensor( %res_scalar = arith.muli %in1_scalar, %in2_scalar : i8 linalg.yield %res_scalar : i8 } -> tensor<2xi8> - %res_sharded1 = mesh.shard %res to %sharding : tensor<2xi8> - %res_shared2 = mesh.shard %res_sharded1 to %sharding annotate_for_users : tensor<2xi8> + %res_sharded1 = shard.shard %res to %sharding : tensor<2xi8> + %res_shared2 = shard.shard %res_sharded1 to %sharding annotate_for_users : tensor<2xi8> // CHECK: return %[[RES]] : tensor<1xi8> return %res_shared2 : tensor<2xi8> } // ----- -mesh.mesh @mesh_1d(shape = 4) +shard.grid @grid_1d(shape = 4) -// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding -func.func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding( +// CHECK-LABEL: func @matmul_1d_grid_static_tensors_parallel_iterator_sharding +func.func @matmul_1d_grid_static_tensors_parallel_iterator_sharding( // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1x3xi8>, %in1: tensor<4x3xi8>, // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x8xi8>, @@ -59,32 +59,32 @@ func.func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding( %dps_out: tensor<4x8xi8> // CHECK-SAME: -> tensor<1x8xi8> { ) -> tensor<4x8xi8> { - %sharding = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x3xi8> - %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x3xi8> - %sharding2 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %in2_shared1 = 
mesh.shard %in2 to %sharding2 : tensor<3x8xi8> - %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<3x8xi8> - %dps_out_shared1 = mesh.shard %dps_out to %sharding : tensor<4x8xi8> - %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding annotate_for_users : tensor<4x8xi8> + %sharding = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %in1_shared1 = shard.shard %in1 to %sharding : tensor<4x3xi8> + %in1_shared2 = shard.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x3xi8> + %sharding2 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %in2_shared1 = shard.shard %in2 to %sharding2 : tensor<3x8xi8> + %in2_shared2 = shard.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<3x8xi8> + %dps_out_shared1 = shard.shard %dps_out to %sharding : tensor<4x8xi8> + %dps_out_shared2 = shard.shard %dps_out_shared1 to %sharding annotate_for_users : tensor<4x8xi8> // CHECK: %[[RES:.*]] = linalg.matmul // CHECK-SAME: ins(%[[IN1]], %[[IN2]] : tensor<1x3xi8>, tensor<3x8xi8>) // CHECK-SAME: outs(%[[DPS_OUT]] : tensor<1x8xi8>) // CHECK-SAME: -> tensor<1x8xi8> %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x3xi8>, tensor<3x8xi8>) outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8> - %res_shared1 = mesh.shard %res to %sharding : tensor<4x8xi8> - %res_shared2 = mesh.shard %res_shared1 to %sharding annotate_for_users : tensor<4x8xi8> + %res_shared1 = shard.shard %res to %sharding : tensor<4x8xi8> + %res_shared2 = shard.shard %res_shared1 to %sharding annotate_for_users : tensor<4x8xi8> // CHECK: return %[[RES]] : tensor<1x8xi8> return %res_shared2 : tensor<4x8xi8> } // ----- -mesh.mesh @mesh_1d(shape = 3) +shard.grid @grid_1d(shape = 3) -// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding -func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding( +// CHECK-LABEL: func @matmul_1d_grid_static_tensors_reduction_iterator_sharding +func.func 
@matmul_1d_grid_static_tensors_reduction_iterator_sharding( // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x2xi8>, %in1: tensor<4x6xi8>, // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<2x8xi8>, @@ -93,19 +93,19 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding( %dps_out: tensor<4x8xi8> // CHECK-SAME: -> tensor<4x8xi8> { ) -> tensor<4x8xi8> { - %sharding = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding - %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x6xi8> - %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x6xi8> - %sharding2 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<6x8xi8> - %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<6x8xi8> - %sharding3 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %dps_out_shared1 = mesh.shard %dps_out to %sharding3 : tensor<4x8xi8> - %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8> + %sharding = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding + %in1_shared1 = shard.shard %in1 to %sharding : tensor<4x6xi8> + %in1_shared2 = shard.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x6xi8> + %sharding2 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %in2_shared1 = shard.shard %in2 to %sharding2 : tensor<6x8xi8> + %in2_shared2 = shard.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<6x8xi8> + %sharding3 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %dps_out_shared1 = shard.shard %dps_out to %sharding3 : tensor<4x8xi8> + %dps_out_shared2 = shard.shard %dps_out_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8 - // CHECK-DAG: %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index - // CHECK-DAG: 
%[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index + // CHECK-DAG: %[[PROCESS_IDX:.*]] = shard.process_multi_index on @grid_1d axes = [0] : index + // CHECK-DAG: %[[SHARD_SIZE:.*]] = shard.grid_shape @grid_1d axes = [0] : index // CHECK: %[[DPS_INIT_OPERAND_CONDITION:.*]] = arith.cmpi eq, %[[PROCESS_IDX]], %[[C0]] : index // CHECK: %[[DPS_INIT_OPERAND:.*]] = scf.if %[[DPS_INIT_OPERAND_CONDITION]] -> (tensor<4x8xi8>) { // CHECK: scf.yield %[[DPS_OUT]] : tensor<4x8xi8> @@ -117,21 +117,21 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding( // CHECK: } // CHECK: %[[SHARDED_MATMUL:.*]] = linalg.matmul ins(%[[IN1]], %[[IN2]] : tensor<4x2xi8>, tensor<2x8xi8>) // CHECK-SAME: outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8> - // CHECK: %[[ALL_REDUCED:.*]] = mesh.all_reduce %[[SHARDED_MATMUL]] on @mesh_1d mesh_axes = [0] : tensor<4x8xi8> -> tensor<4x8xi8> + // CHECK: %[[ALL_REDUCED:.*]] = shard.all_reduce %[[SHARDED_MATMUL]] on @grid_1d grid_axes = [0] : tensor<4x8xi8> -> tensor<4x8xi8> %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>) outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8> - %res_shared1 = mesh.shard %res to %sharding3 : tensor<4x8xi8> - %res_shared2 = mesh.shard %res_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8> + %res_shared1 = shard.shard %res to %sharding3 : tensor<4x8xi8> + %res_shared2 = shard.shard %res_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8> // CHECK: return %[[ALL_REDUCED]] : tensor<4x8xi8> return %res_shared2 : tensor<4x8xi8> } // ----- -mesh.mesh @mesh_1d(shape = 4) +shard.grid @grid_1d(shape = 4) -// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis -func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis( +// CHECK-LABEL: func @matmul_1d_grid_static_tensors_parallel_iterator_unsplit_last_axis +func.func @matmul_1d_grid_static_tensors_parallel_iterator_unsplit_last_axis( // 
CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x6xi8>, %in1: tensor<4x6xi8>, // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<6x8xi8>, @@ -140,25 +140,25 @@ func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis( %dps_out: tensor<4x8xi8> // CHECK-SAME: -> tensor<4x8xi8> { ) -> tensor<4x8xi8> { - %sharding1 = mesh.sharding @mesh_1d split_axes = [[], []] : !mesh.sharding - %in1_replicated1 = mesh.shard %in1 to %sharding1 : tensor<4x6xi8> - %in1_replicated2 = mesh.shard %in1_replicated1 to %sharding1 annotate_for_users : tensor<4x6xi8> - // CHECK: %[[ALL_SLICE1:.*]] = mesh.all_slice %[[IN2]] on @mesh_1d mesh_axes = [0] slice_axis = 1 - %in2_replicated = mesh.shard %in2 to %sharding1 : tensor<6x8xi8> - %sharding2 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding - %in2_sharded = mesh.shard %in2_replicated to %sharding2 annotate_for_users : tensor<6x8xi8> - // CHECK: %[[ALL_SLICE2:.*]] = mesh.all_slice %[[DPS_OUT]] on @mesh_1d mesh_axes = [0] slice_axis = 1 - %dps_out_replicated = mesh.shard %dps_out to %sharding1 : tensor<4x8xi8> - %dps_out_sharded = mesh.shard %dps_out_replicated to %sharding2 annotate_for_users : tensor<4x8xi8> + %sharding1 = shard.sharding @grid_1d split_axes = [[], []] : !shard.sharding + %in1_replicated1 = shard.shard %in1 to %sharding1 : tensor<4x6xi8> + %in1_replicated2 = shard.shard %in1_replicated1 to %sharding1 annotate_for_users : tensor<4x6xi8> + // CHECK: %[[ALL_SLICE1:.*]] = shard.all_slice %[[IN2]] on @grid_1d grid_axes = [0] slice_axis = 1 + %in2_replicated = shard.shard %in2 to %sharding1 : tensor<6x8xi8> + %sharding2 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding + %in2_sharded = shard.shard %in2_replicated to %sharding2 annotate_for_users : tensor<6x8xi8> + // CHECK: %[[ALL_SLICE2:.*]] = shard.all_slice %[[DPS_OUT]] on @grid_1d grid_axes = [0] slice_axis = 1 + %dps_out_replicated = shard.shard %dps_out to %sharding1 : tensor<4x8xi8> + %dps_out_sharded = shard.shard 
%dps_out_replicated to %sharding2 annotate_for_users : tensor<4x8xi8> // CHECK: %[[MATMUL_RES:.*]] = linalg.matmul // CHECK-SAME: ins(%[[IN1]], %[[ALL_SLICE1]] : tensor<4x6xi8>, tensor<6x2xi8>) // CHECK-SAME: outs(%[[ALL_SLICE2]] : tensor<4x2xi8>) // CHECK-SAME: -> tensor<4x2xi8> %res = linalg.matmul ins(%in1_replicated2, %in2_sharded : tensor<4x6xi8>, tensor<6x8xi8>) outs(%dps_out_sharded : tensor<4x8xi8>) -> tensor<4x8xi8> - // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[MATMUL_RES]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<4x2xi8> -> tensor<4x8xi8> - %res_sharded = mesh.shard %res to %sharding2 : tensor<4x8xi8> - %res_replicated = mesh.shard %res_sharded to %sharding1 annotate_for_users : tensor<4x8xi8> + // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[MATMUL_RES]] on @grid_1d grid_axes = [0] gather_axis = 1 : tensor<4x2xi8> -> tensor<4x8xi8> + %res_sharded = shard.shard %res to %sharding2 : tensor<4x8xi8> + %res_replicated = shard.shard %res_sharded to %sharding1 annotate_for_users : tensor<4x8xi8> // CHECK: return %[[ALL_GATHER]] : tensor<4x8xi8> return %res_replicated : tensor<4x8xi8> } diff --git a/mlir/test/Dialect/Linalg/sharding-propagation.mlir b/mlir/test/Dialect/Linalg/sharding-propagation.mlir new file mode 100644 index 0000000000000..e0ecefcf2d6bd --- /dev/null +++ b/mlir/test/Dialect/Linalg/sharding-propagation.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-opt \ +// RUN: --verify-each \ +// RUN: --pass-pipeline="builtin.module(func.func(sharding-propagation))" \ +// RUN: %s | FileCheck %s + +shard.grid @grid_2(shape = 2) + +// CHECK-LABEL: func @matmul_shard_parallel_axis +func.func @matmul_shard_parallel_axis( + // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<2x3xf32>, + %arg0 : tensor<2x3xf32>, + // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x2xf32>, + %arg1 : tensor<3x2xf32>, + // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<2x2xf32> + %out_dps: tensor<2x2xf32> +) -> tensor<2x2xf32> { + // CHECK: %[[SIN1_ANNOTATED_0:.*]] = shard.sharding
@grid_2 split_axes = {{\[}}[0]] : !shard.sharding + // CHECK-NEXT: %[[IN1_ANNOTATED_0:.*]] = shard.shard %[[IN1]] to %[[SIN1_ANNOTATED_0]] : tensor<2x3xf32> + // CHECK: %[[SIN1_ANNOTATED_1:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[0]] : !shard.sharding + // CHECK-NEXT: %[[IN1_ANNOTATED_1:.*]] = shard.shard %[[IN1_ANNOTATED_0]] to %[[SIN1_ANNOTATED_1]] annotate_for_users : tensor<2x3xf32> + // CHECK: %[[SIN2_ANNOTATED:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[]] : !shard.sharding + // CHECK-NEXT: %[[IN2_ANNOTATED:.*]] = shard.shard %[[IN2]] to %[[SIN2_ANNOTATED]] annotate_for_users : tensor<3x2xf32> + // CHECK: %[[SDPS_OUT_ANNOTATED:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[0]] : !shard.sharding + // CHECK-NEXT: %[[DPS_OUT_ANNOTATED:.*]] = shard.shard %[[DPS_OUT]] to %[[SDPS_OUT_ANNOTATED]] annotate_for_users : tensor<2x2xf32> + %sarg0_sharded = shard.sharding @grid_2 split_axes = [[0]] : !shard.sharding + %arg0_sharded = shard.shard %arg0 to %sarg0_sharded : tensor<2x3xf32> + + // CHECK: %[[RES:.*]] = linalg.matmul ins(%[[IN1_ANNOTATED_1]], %[[IN2_ANNOTATED]] : tensor<2x3xf32>, tensor<3x2xf32>) + // CHECK-SAME: outs(%[[DPS_OUT_ANNOTATED]] : tensor<2x2xf32>) -> tensor<2x2xf32> + %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>) + outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32> + + // CHECK: %[[SRES_ANNOTATED_0:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[0]] : !shard.sharding + // CHECK-NEXT: %[[RES_ANNOTATED_0:.*]] = shard.shard %[[RES]] to %[[SRES_ANNOTATED_0]] : tensor<2x2xf32> + // CHECK: %[[SRES_ANNOTATED_1:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[]] : !shard.sharding + // CHECK-NEXT: %[[RES_ANNOTATED_1:.*]] = shard.shard %[[RES_ANNOTATED_0]] to %[[SRES_ANNOTATED_1]] annotate_for_users : tensor<2x2xf32> + %sres_sharded = shard.sharding @grid_2 split_axes = [[]] : !shard.sharding + %res_sharded = shard.shard %res to %sres_sharded annotate_for_users : tensor<2x2xf32> + + // CHECK: return 
%[[RES_ANNOTATED_1]] : tensor<2x2xf32> + return %res_sharded : tensor<2x2xf32> +} diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir deleted file mode 100644 index aff07bbf8a214..0000000000000 --- a/mlir/test/Dialect/Mesh/canonicalization.mlir +++ /dev/null @@ -1,248 +0,0 @@ -// RUN: mlir-opt --canonicalize %s | FileCheck %s - -mesh.mesh @mesh0(shape = 2x4) - -// CHECK-LABEL: func @all_reduce_empty_mesh_axes -func.func @all_reduce_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> - %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NOT: mesh.all_reduce - %0 = mesh.all_reduce %arg0 on @mesh0 - mesh_axes = [] - : tensor<4xf32> -> tensor<4xf32> -// CHECK: return %[[ARG]] - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @all_reduce_empty_mesh_axes_different_return_type -func.func @all_reduce_empty_mesh_axes_different_return_type( - %arg0 : tensor<4xf32>) -> tensor<4xf64> { -// CHECK: mesh.all_reduce - %0 = mesh.all_reduce %arg0 on @mesh0 -// CHECK-NOT: mesh_axes - mesh_axes = [] - : tensor<4xf32> -> tensor<4xf64> - return %0 : tensor<4xf64> -} - -// CHECK-LABEL: func @all_reduce_default_reduction -func.func @all_reduce_default_reduction( - %arg0 : tensor<4xf32>) -> tensor<4xf64> { - %0 = mesh.all_reduce %arg0 on @mesh0 - mesh_axes = [0] -// CHECK-NOT: reduction - reduction = sum - : tensor<4xf32> -> tensor<4xf64> - return %0 : tensor<4xf64> -} - -// CHECK-LABEL: func @all_to_all_empty_mesh_axes -func.func @all_to_all_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<8xf32> - %arg0 : tensor<8xf32>) -> tensor<8xf32> { -// CHECK-NOT: mesh.all_to_all - %0 = mesh.all_to_all %arg0 on @mesh0 - mesh_axes = [] - split_axis = 0 - concat_axis = 0 - : tensor<8xf32> -> tensor<8xf32> -// CHECK: return %[[ARG]] - return %0 : tensor<8xf32> -} - -// CHECK-LABEL: func @all_gather_empty_mesh_axes -func.func @all_gather_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> - %arg0 : tensor<4xf32>) -> 
tensor<4xf32> { -// CHECK-NOT: mesh.all_gather - %0 = mesh.all_gather %arg0 on @mesh0 - mesh_axes = [] - gather_axis = 0 - : tensor<4xf32> -> tensor<4xf32> -// CHECK: return %[[ARG]] - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @all_slice_empty_mesh_axes -func.func @all_slice_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> - %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NOT: mesh.scatter - %0 = mesh.all_slice %arg0 on @mesh0 - mesh_axes = [] - slice_axis = 0 - : tensor<4xf32> -> tensor<4xf32> -// CHECK: return %[[ARG]] - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @broadcast_empty_mesh_axes -func.func @broadcast_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> - %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NOT: mesh.broadcast - %0 = mesh.broadcast %arg0 on @mesh0 - mesh_axes = [] - root = [] - : (tensor<4xf32>) -> tensor<4xf32> -// CHECK: return %[[ARG]] - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @gather_empty_mesh_axes -func.func @gather_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> - %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NOT: mesh.gather - %0 = mesh.gather %arg0 on @mesh0 - mesh_axes = [] - gather_axis = 0 - root = [] - : (tensor<4xf32>) -> tensor<4xf32> -// CHECK: return %[[ARG]] - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @receive_empty_mesh_axes -func.func @receive_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> - %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NOT: mesh.recv - %0 = mesh.recv %arg0 on @mesh0 - mesh_axes = [] - : (tensor<4xf32>) -> tensor<4xf32> -// CHECK: return %[[ARG]] - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @reduce_empty_mesh_axes -func.func @reduce_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> - %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NOT: mesh.reduce - %0 = mesh.reduce %arg0 on @mesh0 - mesh_axes = [] - root = [] - : (tensor<4xf32>) -> tensor<4xf32> -// CHECK: return %[[ARG]] - 
return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes -func.func @reduce_scatter_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> - %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NOT: mesh.reduce_scatter - %0 = mesh.reduce_scatter %arg0 on @mesh0 - mesh_axes = [] - scatter_axis = 0 - : tensor<4xf32> -> tensor<4xf32> -// CHECK: return %[[ARG]] - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_different_return_type -func.func @reduce_scatter_empty_mesh_axes_different_return_type( - %arg0 : tensor<4xf32>) -> tensor<4xf64> { -// CHECK: mesh.reduce_scatter - %0 = mesh.reduce_scatter %arg0 on @mesh0 -// CHECK-NOT: mesh_axes - mesh_axes = [] - scatter_axis = 0 - : tensor<4xf32> -> tensor<4xf64> - return %0 : tensor<4xf64> -} - -// CHECK-LABEL: func @reduce_scatter_default_reduction -func.func @reduce_scatter_default_reduction( - %arg0 : tensor<4xf32>) -> tensor<2xf64> { - %0 = mesh.reduce_scatter %arg0 on @mesh0 - mesh_axes = [0] -// CHECK-NOT: reduction - reduction = sum - scatter_axis = 0 - : tensor<4xf32> -> tensor<2xf64> - return %0 : tensor<2xf64> -} - -// CHECK-LABEL: func @scatter_empty_mesh_axes -func.func @scatter_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> - %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NOT: mesh.scatter - %0 = mesh.scatter %arg0 on @mesh0 - mesh_axes = [] - scatter_axis = 0 - root = [] - : (tensor<4xf32>) -> tensor<4xf32> -// CHECK: return %[[ARG]] - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @send_empty_mesh_axes -func.func @send_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> - %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NOT: mesh.send - %0 = mesh.send %arg0 on @mesh0 - mesh_axes = [] - destination = [] - : (tensor<4xf32>) -> tensor<4xf32> -// CHECK: return %[[ARG]] - return %0 : tensor<4xf32> -} - -mesh.mesh @mesh4x4(shape = 4x4) -// CHECK-LABEL: func @test_halo_sizes -func.func @test_halo_sizes() -> 
!mesh.sharding { - %c2_i64 = arith.constant 2 : i64 - // CHECK mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 2, 22] : !mesh.sharding - %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, %c2_i64, %c2_i64, 22] : !mesh.sharding - return %sharding : !mesh.sharding -} - -// CHECK-LABEL: func @test_shard_offs -func.func @test_shard_offs() -> !mesh.sharding { - %c2_i64 = arith.constant 2 : i64 - // CHECK mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !mesh.sharding - %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !mesh.sharding - return %sharding : !mesh.sharding -} - -// CHECK-LABEL: func @test_duplicate_shardops -func.func @test_duplicate_shardops() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { - // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0, 1]] : !mesh.sharding - %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding - %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - %sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding - %sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32> - %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - %sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding - %sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32> - // CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32> - %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32> - // CHECK-NEXT: return [[vsharding_annotated]], [[vsharding_annotated]] : 
tensor<1024x1024xf32>, tensor<1024x1024xf32> - return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32> -} - -// CHECK-LABEL: func @test_duplicate_shardops_diff -func.func @test_duplicate_shardops_diff() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { - // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding - %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding - %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - // CHECK-NEXT: [[vsharding_0:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0, 1]] : !mesh.sharding - %sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding - // CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding_0]] : tensor<1024x1024xf32> - %sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32> - %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - %sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding - %sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32> - // CHECK-NEXT: [[vsharding_annotated_1:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding]] : tensor<1024x1024xf32> - %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32> - // CHECK-NEXT: return [[vsharding_annotated_1]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32> - return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32> -} diff --git a/mlir/test/Dialect/Mesh/folding.mlir b/mlir/test/Dialect/Mesh/folding.mlir deleted file mode 100644 index 369f316d0f797..0000000000000 --- a/mlir/test/Dialect/Mesh/folding.mlir +++ /dev/null @@ -1,22 +0,0 
@@ -// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s - -mesh.mesh @mesh0(shape = 4x?x2) -mesh.mesh @mesh1(shape = 2x3) - -// CHECK-LABEL: func.func @mesh_shape_op_folding -func.func @mesh_shape_op_folding() -> (index, index) { - // CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index - // CHECK: %[[AXIS_1_SIZE:.*]] = mesh.mesh_shape @mesh0 axes = [1] : index - %0:2 = mesh.mesh_shape @mesh0 axes = [2, 1] : index, index - // CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]] - return %0#0, %0#1 : index, index -} - -// CHECK-LABEL: func.func @mesh_shape_op_folding_all_axes_static_mesh -func.func @mesh_shape_op_folding_all_axes_static_mesh() -> (index, index) { - // CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index - // CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index - %0:2 = mesh.mesh_shape @mesh1 : index, index - // CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]] - return %0#0, %0#1 : index, index -} diff --git a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir deleted file mode 100644 index 6ab711b1b653c..0000000000000 --- a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir +++ /dev/null @@ -1,49 +0,0 @@ -// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s - -#map = affine_map<(d0, d1) -> (d0, d1)> -module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} { - mesh.mesh @mesh(shape = 1) {sym_visibility = "private"} - func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} { - %c1_i32 = arith.constant 1 : i32 - // CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32> - %0 = tensor.empty() : tensor<6x6xi32> - // CHECK: [[v1:%.*]] = linalg.fill ins - // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding - // CHECK: [[vsharding_annotated_1:%.*]] = mesh.shard [[v1]] to
[[vsharding_0]] : tensor<6x6xi32> - %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32> - %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding - %sharding_annotated = mesh.shard %1 to %sharding : tensor<6x6xi32> - // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32> - // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding - // CHECK: [[vsharding_annotated_3:%.*]] = mesh.shard [[vsharding_annotated_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32> - %3 = tensor.empty() : tensor<6x6xi32> - // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding - // CHECK: [[vsharding_annotated_5:%.*]] = mesh.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32> - // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} - // CHECK-SAME: ins([[vsharding_annotated_3]], [[vsharding_annotated_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharding_annotated_5]] : tensor<6x6xi32>) { - // CHECK: [[vsharding_6:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding - // CHECK: [[vsharding_annotated_7:%.*]] = mesh.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32> - %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharding_annotated, %sharding_annotated - : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) { - ^bb0(%in: i32, %in_2: i32, %out: i32): - %9 = arith.addi %in, %in_2 : i32 - linalg.yield %9 : i32 - } -> tensor<6x6xi32> - %c0_i32 = arith.constant 0 : i32 - %6 = tensor.empty() : tensor<i32> - %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32> - // CHECK: [[vreduced:%.*]] = linalg.reduce ins - // CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] : !mesh.sharding - // CHECK: [[vsharding_annotated_13:%.*]] = mesh.shard [[vreduced]] to [[vsharding_12]] : tensor<i32> - %reduced =
linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1] - (%in: i32, %init: i32) { - %9 = arith.addi %in, %init : i32 - linalg.yield %9 : i32 - } - // CHECK: [[vsharding_14:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}]] : !mesh.sharding - %sharding_0 = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding - // CHECK: [[vsharding_annotated_15:%.*]] = mesh.shard [[vsharding_annotated_13]] to [[vsharding_14]] annotate_for_users : tensor<i32> - %sharding_annotated_1 = mesh.shard %reduced to %sharding_0 annotate_for_users : tensor<i32> - return %sharding_annotated, %4, %sharding_annotated_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32> - } -} diff --git a/mlir/test/Dialect/Mesh/inlining.mlir b/mlir/test/Dialect/Mesh/inlining.mlir deleted file mode 100644 index c41a709e1a4eb..0000000000000 --- a/mlir/test/Dialect/Mesh/inlining.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: mlir-opt -inline %s | FileCheck %s - -mesh.mesh @mesh0(shape = 4x?x2) - -func.func private @mesh_to_inline() -> (index, index) { - %0:2 = mesh.mesh_shape @mesh0 axes = [2, 1] : index, index - return %0#0, %0#1 : index, index -} -// CHECK-LABEL: func.func @main -func.func @main() -> (index, index) { - // CHECK-NEXT: %[[AXIS_SIZE:.*]]:2 = mesh.mesh_shape @mesh0 axes = [2, 1] : index - %0:2 = func.call @mesh_to_inline() : () -> (index, index) - // CHECK-NEXT: return %[[AXIS_SIZE]]#0, %[[AXIS_SIZE]]#1 - return %0#0, %0#1 : index, index -} diff --git a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir deleted file mode 100644 index e23cfd79a4274..0000000000000 --- a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir +++ /dev/null @@ -1,23 +0,0 @@ -// RUN: mlir-opt -test-mesh-process-multi-index-op-lowering %s | FileCheck %s - -mesh.mesh @mesh2d(shape = ?x?)
- -// CHECK-LABEL: func.func @multi_index_2d_mesh -func.func @multi_index_2d_mesh() -> (index, index) { - // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index - // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index - // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index - %0:2 = mesh.process_multi_index on @mesh2d : index, index - // CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index - return %0#0, %0#1 : index, index -} - -// CHECK-LABEL: func.func @multi_index_2d_mesh_single_inner_axis -func.func @multi_index_2d_mesh_single_inner_axis() -> index { - // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index - // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index - // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index - %0 = mesh.process_multi_index on @mesh2d axes = [0] : index - // CHECK: return %[[MULTI_IDX]]#0 : index - return %0 : index -} diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir deleted file mode 100644 index 5e62c929aa4ff..0000000000000 --- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir +++ /dev/null @@ -1,168 +0,0 @@ -// RUN: mlir-opt -test-mesh-resharding-spmdization %s | FileCheck %s - -mesh.mesh @mesh_1d(shape = 2) -mesh.mesh @mesh_1d_dynamic(shape = ?) 
- -// CHECK-LABEL: func @same_source_and_target_sharding -func.func @same_source_and_target_sharding( - // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32> - %arg0: tensor<2xf32> -) -> tensor<2xf32> { - %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<2xf32> - %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xf32> - // CHECK: return %[[ARG]] - return %1 : tensor<2xf32> -} - -// CHECK-LABEL: func @identical_source_and_target_sharding -func.func @identical_source_and_target_sharding( - // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32> - %arg0: tensor<2xf32> -) -> tensor<2xf32> { - %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<2xf32> - %1 = mesh.shard %0 to %s0 annotate_for_users : tensor<2xf32> - // CHECK: return %[[ARG]] - return %1 : tensor<2xf32> -} - -// CHECK-LABEL: func @split_replicated_tensor_axis -func.func @split_replicated_tensor_axis( - // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32> - %arg0: tensor<3x14xf32> -) -> tensor<3x14xf32> { - // CHECK: %[[ALL_SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 1 - // CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32> - // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32> - %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<3x14xf32> - %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<3x14xf32> - // CHECK: return %[[RESULT]] : tensor<3x14xf32> - return %1 : tensor<3x14xf32> -} - -// CHECK-LABEL: func @split_replicated_tensor_axis_dynamic -func.func @split_replicated_tensor_axis_dynamic( - // CHECK-SAME: %[[ARG:.*]]: tensor - %arg0: tensor -) -> tensor { - // CHECK: %[[RESULT:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d_dynamic mesh_axes = [0] 
slice_axis = 0 - // CHECK-SAME: tensor -> tensor - %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[], [], []] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor - %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor - // CHECK: return %[[RESULT]] : tensor - return %1 : tensor -} - -// CHECK-LABEL: func @move_split_axis -func.func @move_split_axis( - // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32> - %arg0: tensor<10x14xf32> -) -> tensor<10x14xf32> { - // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32> - // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32> - // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32> - %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32> - %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32> - // CHECK: return %[[RES]] : tensor<10x14xf32> - return %1 : tensor<10x14xf32> -} - -// CHECK-LABEL: func @move_split_axis_dynamic_mesh -func.func @move_split_axis_dynamic_mesh( - // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32> - %arg0: tensor<10x14xf32> -) -> tensor<10x14xf32> { - // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor - // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor -> tensor - // CHECK: %[[TARGET_SHARD:.*]] = tensor.cast %[[ALL_TO_ALL]] : tensor to tensor<10x?xf32> - // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32> - %s0 = mesh.sharding @mesh_1d_dynamic 
split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32> - %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[], [0]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32> - // CHECK: return %[[RES]] : tensor<10x14xf32> - return %1 : tensor<10x14xf32> -} - -// CHECK-LABEL: func @move_split_dynamic_axis -func.func @move_split_dynamic_axis( - // CHECK-SAME: %[[ARG:.*]]: tensor - %arg0: tensor -) -> tensor { - // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor -> tensor - // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor to tensor - %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor - %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor - // CHECK: return %[[RES]] : tensor - return %1 : tensor -} - -// CHECK-LABEL: func @unshard_static_axis -func.func @unshard_static_axis( - // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32> - %arg0: tensor<10x14xf32> -) -> tensor<10x14xf32> { - // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32> - // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32> - %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32> - %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32> - // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32> - return %1 : tensor<10x14xf32> -} - -// CHECK-LABEL: func @unshard_static_last_axis -func.func @unshard_static_last_axis( - // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32> - %arg0: tensor<10x14xf32> -) -> tensor<10x14xf32> { - // CHECK: 
%[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<10x7xf32> - // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<10x7xf32> -> tensor<10x14xf32> - %s0 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32> - %s1 = mesh.sharding @mesh_1d split_axes = [[], []] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32> - // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32> - return %1 : tensor<10x14xf32> -} - -// CHECK-LABEL: func @unshard_dynamic_axis -func.func @unshard_dynamic_axis( - // CHECK-SAME: %[[ARG:.*]]: tensor - %arg0: tensor -) -> tensor { - // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor -> tensor - %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor - %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor - // CHECK: return %[[ALL_GATHER]] : tensor - return %1 : tensor -} - -// CHECK-LABEL: func @unshard_static_axis_on_dynamic_mesh_axis -func.func @unshard_static_axis_on_dynamic_mesh_axis( -// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32> - %arg0: tensor<10x14xf32> -) -> tensor<10x14xf32> { - // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32> - // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32> - // CHECK: %[[RES:.*]] = tensor.cast %[[ALL_GATHER]] : tensor<?x14xf32> to tensor<10x14xf32> - %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32> - %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32> - //
CHECK: return %[[RES]] : tensor<10x14xf32> - return %1 : tensor<10x14xf32> -} diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir deleted file mode 100644 index 0881d994d60e7..0000000000000 --- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir +++ /dev/null @@ -1,301 +0,0 @@ -// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation,cse))" %s | FileCheck %s - -mesh.mesh @mesh_2(shape = 2) -mesh.mesh @mesh_1d(shape = ?) -mesh.mesh @mesh_2d(shape = 2x4) -mesh.mesh @mesh_3d(shape = ?x?x?) - -// CHECK-LABEL: func.func @element_wise_empty_sharding_info -func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { - // CHECK-NEXT: tosa.sigmoid - %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: return - return %0 : tensor<8x16xf32> -} - -// CHECK-LABEL: func.func @element_wise_on_def -// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> -func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { - // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding - // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]] - %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32> - %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 : tensor<8x16xf32> - // CHECK-NEXT: return %[[V2]] - return %1 : tensor<8x16xf32> -} - -// CHECK-LABEL: func.func @element_wise_on_use -// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> -func.func @element_wise_on_use(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { - // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding - // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : 
tensor<8x16xf32> - %s0 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]] - %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32> - // CHECK-NEXT: return %[[V2]] - return %1 : tensor<8x16xf32> -} - -// CHECK-LABEL: func.func @element_wise_on_graph_output -// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> -func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { - // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding - // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]] - %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32> - // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32> - %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: return %[[V3]] - return %1 : tensor<8x16xf32> -} - -// CHECK-LABEL: func.func @element_wise_on_graph_input -// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> -func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { - // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding - // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] : tensor<8x16xf32> - // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[V0]] to %[[S0]] annotate_for_users : tensor<8x16xf32> - %s0 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<8x16xf32> - // CHECK-NEXT: %[[V2:.*]] = tosa.sigmoid %[[V1]] - %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> 
tensor<8x16xf32> - // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] : tensor<8x16xf32> - // CHECK-NEXT: return %[[V3]] - return %1 : tensor<8x16xf32> -} - -// CHECK-LABEL: func.func @arrow_structure -// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> -func.func @arrow_structure(%arg0: tensor<8x16xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) { - // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding - // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG]] to %[[S1]] annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: %[[V2:.*]] = tosa.tanh %[[V1]] - // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S1]] : tensor<8x16xf32> - %0 = tosa.tanh %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[V4:.*]] = mesh.shard %[[V3]] to %[[S1]] annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: %[[V5:.*]] = tosa.abs %[[V4]] - // CHECK-NEXT: %[[V6:.*]] = mesh.shard %[[V5]] to %[[S1]] : tensor<8x16xf32> - %1 = tosa.abs %0: (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK-NEXT: %[[ZP1:.*]] = mesh.shard %arg1 to %[[S3]] annotate_for_users : tensor<1xf32> - // CHECK-NEXT: %[[ZP2:.*]] = mesh.shard %arg2 to %[[S3]] annotate_for_users : tensor<1xf32> - // CHECK-NEXT: %[[V7:.*]] = tosa.negate %[[V4]], %[[ZP1]], %[[ZP2]] - // CHECK-NEXT: %[[V8:.*]] = mesh.shard %[[V7]] to %[[S1]] : tensor<8x16xf32> - %2 = tosa.negate %0, %arg1, %arg2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32> - %s3 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding - %3 = mesh.shard %2 to %s3 : tensor<8x16xf32> - // CHECK-NEXT: return %[[V6]], %[[V8]] - return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32> -} - -// CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32> 
-func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> { - // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding - // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32> - // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0]] : !mesh.sharding - // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32> - // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK-NEXT: %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32> - // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]] - %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32> - // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] : tensor<2x16x32xf32> - %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 : tensor<2x16x32xf32> - // CHECK-NEXT: return %[[V3]] - return %1 : tensor<2x16x32xf32> -} - -// CHECK-LABEL: func.func @matmul_on_def_shard_m_and_n -// CHECK-SAME: [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32> -func.func @matmul_on_def_shard_m_and_n(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> { - // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding - // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32> - // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [], [1]] : !mesh.sharding - // CHECK: [[vsharded_1:%.*]] = mesh.shard [[varg1]] to [[vsharding_0]] annotate_for_users : tensor<2x8x32xf32> - // CHECK: 
[[vsharding_2:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK: [[vsharded_3:%.*]] = mesh.shard [[varg2]] to [[vsharding_2]] annotate_for_users : tensor<1xf32> - // CHECK: [[v0:%.*]] = tosa.matmul - %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32> - // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0], [1]] : !mesh.sharding - // CHECK: [[vsharded_5:%.*]] = mesh.shard [[v0]] to [[vsharding_4]] : tensor<2x16x32xf32> - %s1 = mesh.sharding @mesh_2d split_axes = [[], [0], [1]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 : tensor<2x16x32xf32> - // CHECK-NEXT: return [[vsharded_5]] - return %1 : tensor<2x16x32xf32> -} - -// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k -// CHECK-SAME: [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32> -func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> { - // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0], [1]] : !mesh.sharding - %s0 = mesh.sharding @mesh_2d split_axes = [[], [0], [1]] : !mesh.sharding - // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] : tensor<2x16x8xf32> - %arg0_s = mesh.shard %arg0 to %s0 : tensor<2x16x8xf32> - // CHECK: [[vsharded_0:%.*]] = mesh.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32> - // CHECK: [[vsharding_1:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] : !mesh.sharding - // CHECK: [[vsharded_2:%.*]] = mesh.shard [[varg1]] to [[vsharding_1]] annotate_for_users : tensor<2x8x32xf32> - // CHECK: [[vsharding_3:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK: [[vsharded_4:%.*]] = mesh.shard [[varg2]] to [[vsharding_3]] annotate_for_users : tensor<1xf32> - // CHECK: [[v0:%.*]] = tosa.matmul - // 
CHECK: [[vsharding_5:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding - // CHECK: [[vsharded_6:%.*]] = mesh.shard [[v0]] to [[vsharding_5]] : tensor<2x16x32xf32> - %0 = tosa.matmul %arg0_s, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32> - // CHECK: return [[vsharded_6]] - return %0 : tensor<2x16x32xf32> -} - -// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicted_k -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32> -func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> { - // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding - // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32> - %s0 = mesh.sharding @mesh_2d split_axes = [[], [1], [0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<2x16x8xf32> - // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding - // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32> - %s1 = mesh.sharding @mesh_2d split_axes = [[], [0]] : !mesh.sharding - %1 = mesh.shard %arg1 to %s1 annotate_for_users : tensor<2x8x32xf32> - // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK-NEXT: %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32> - // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]] - %2 = tosa.matmul %0, %1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32> - // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] : !mesh.sharding - // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : 
tensor<2x16x32xf32> - // CHECK-NEXT: return %[[V3]] - return %2 : tensor<2x16x32xf32> -} - -// CHECK-LABEL: func.func @resolve_conflicting_annotations -func.func @resolve_conflicting_annotations( - // CHECK-SAME: %[[IN1:.*]]: tensor<2x3xf32>, - %arg0: tensor<2x3xf32>, - // CHECK-SAME: %[[IN2:.*]]: tensor<3x2xf32>, - %arg1: tensor<3x2xf32>, - // CHECK-SAME: %[[OUT_DPS:.*]]: tensor<2x2xf32> - %out_dps: tensor<2x2xf32> -// CHECK-SAME: ) -> tensor<2x2xf32> { -) -> tensor<2x2xf32> { - // CHECK: %[[SIN1_SHARDED1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}0]] : !mesh.sharding - // CHECK-NEXT: %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to %[[SIN1_SHARDED1]] : tensor<2x3xf32> - // CHECK: %[[SIN2_SHARDED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK-NEXT: %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x3xf32> - // CHECK-NEXT: %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<3x2xf32> - // CHECK-NEXT: %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x2xf32> - %sarg0_sharded = mesh.sharding @mesh_2 split_axes = [[0]] : !mesh.sharding - %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded : tensor<2x3xf32> - // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[IN1_SHARDED2]], %[[IN2_SHARDED]] : tensor<2x3xf32>, tensor<3x2xf32>) - // CHECK-SAME: outs(%[[OUT_DPS_SHARDED]] : tensor<2x2xf32>) -> tensor<2x2xf32> - %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>) - outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: %[[RES:.*]] = mesh.shard %[[MATMUL]] to %[[SIN2_SHARDED]] : tensor<2x2xf32> - %sres_sharded = mesh.sharding @mesh_2 split_axes = [[]] : !mesh.sharding - %res_sharded = mesh.shard %res to %sres_sharded : tensor<2x2xf32> - // CHECK: return %[[RES]] : tensor<2x2xf32> - return %res_sharded : tensor<2x2xf32> -} - -// https://arxiv.org/abs/2211.05102 
Figure 2(a) -// The sharding propagation results in unnecessary reshards, -// an optimization pass should be able to remove them. -// CHECK-LABEL: func.func @mlp_1d_weight_stationary -// CHECK-SAME: [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32> -func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> { - %s0 = mesh.sharding @mesh_1d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding - %sharded0 = mesh.shard %arg0 to %s0 : tensor<2x4x8xf32> - %sharded1 = mesh.shard %arg1 to %s0 : tensor<2x8x32xf32> - // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], [0, 1, 2]] : !mesh.sharding - // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32> - // CHECK: [[vsharded_0:%.*]] = mesh.shard [[varg1]] to [[vsharding]] : tensor<2x8x32xf32> - // CHECK: [[vsharded_1:%.*]] = mesh.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32> - // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [0, 1, 2]] : !mesh.sharding - // CHECK: [[vsharded_3:%.*]] = mesh.shard [[vsharded_0]] to [[vsharding_2]] annotate_for_users : tensor<2x8x32xf32> - // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK: [[vsharded_5:%.*]] = mesh.shard [[varg3]] to [[vsharding_4]] annotate_for_users : tensor<1xf32> - // CHECK: [[v0:%.*]] = tosa.matmul - %1 = tosa.matmul %sharded0, %sharded1, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32> - // CHECK: [[vsharded_6:%.*]] = mesh.shard [[v0]] to [[vsharding_4]] : tensor<2x4x32xf32> - // CHECK: [[vsharded_7:%.*]] = mesh.shard [[vsharded_6]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32> - // CHECK: [[v1:%.*]] = tosa.sigmoid [[vsharded_7]] : (tensor<2x4x32xf32>) 
-> tensor<2x4x32xf32> - // CHECK: [[vsharded_8:%.*]] = mesh.shard [[v1]] to [[vsharding_4]] : tensor<2x4x32xf32> - %2 = tosa.sigmoid %1 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32> - %sharding = mesh.sharding @mesh_1d split_axes = [[], [0, 1, 2]] : !mesh.sharding - // CHECK: [[vsharded_9:%.*]] = mesh.shard [[varg2]] to [[vsharding_2]] : tensor<2x32x8xf32> - %sharded2 = mesh.shard %arg2 to %sharding : tensor<2x32x8xf32> - // CHECK: [[vsharded_10:%.*]] = mesh.shard [[vsharded_8]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32> - // CHECK: [[vsharded_11:%.*]] = mesh.shard [[vsharded_9]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32> - // CHECK: [[v2:%.*]] = tosa.matmul - %3 = tosa.matmul %2, %sharded2, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32> - // CHECK: [[vsharded_12:%.*]] = mesh.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32> - %s4 = mesh.sharding @mesh_1d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding - %4 = mesh.shard %3 to %s4 : tensor<2x4x8xf32> - // CHECK: return [[vsharded_12]] - return %4 : tensor<2x4x8xf32> -} - -// https://arxiv.org/abs/2211.05102 Figure 2(b) -// The sharding propagation results in unnecessary reshards, -// an optimization pass should be able to remove them. 
-// CHECK-LABEL: func.func @mlp_2d_weight_stationary -// CHECK-SAME: [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32> -func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> { - // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [0, 1, 2]] : !mesh.sharding - %s0 = mesh.sharding @mesh_3d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding - // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32> - %arg0_s = mesh.shard %arg0 to %s0 : tensor<2x4x8xf32> - // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [0], [1, 2]] : !mesh.sharding - %s1 = mesh.sharding @mesh_3d split_axes = [[], [0], [1, 2]] : !mesh.sharding - // CHECK: [[vsharded_1:%.*]] = mesh.shard [[varg1]] to [[vsharding_0]] : tensor<2x8x32xf32> - %arg1_s = mesh.shard %arg1 to %s1 : tensor<2x8x32xf32> - // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK: [[vsharded_3:%.*]] = mesh.shard [[vsharded]] to [[vsharding_2]] annotate_for_users : tensor<2x4x8xf32> - // CHECK: [[vsharded_4:%.*]] = mesh.shard [[vsharded_1]] to [[vsharding]] annotate_for_users : tensor<2x8x32xf32> - // CHECK: [[vsharded_5:%.*]] = mesh.shard [[varg3]] to [[vsharding_2]] annotate_for_users : tensor<1xf32> - // CHECK: [[v0:%.*]] = tosa.matmul - %1 = tosa.matmul %arg0_s, %arg1_s, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32> - // CHECK: [[vsharded_6:%.*]] = mesh.shard [[v0]] to [[vsharding]] : tensor<2x4x32xf32> - %2 = mesh.shard %1 to %s0 : tensor<2x4x32xf32> - // CHECK: [[vsharded_7:%.*]] = mesh.shard [[vsharded_6]] to [[vsharding]] annotate_for_users : tensor<2x4x32xf32> - // CHECK: [[v1:%.*]] = tosa.sigmoid - // CHECK: [[vsharded_8:%.*]] = mesh.shard 
[[v1]] to [[vsharding]] : tensor<2x4x32xf32> - %3 = tosa.sigmoid %2 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32> - // CHECK: [[vsharding_9:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [1, 2], [0]] : !mesh.sharding - %s2 = mesh.sharding @mesh_3d split_axes = [[], [1, 2], [0]] : !mesh.sharding - // CHECK: [[vsharded_10:%.*]] = mesh.shard [[varg2]] to [[vsharding_9]] : tensor<2x32x8xf32> - %arg2_s = mesh.shard %arg2 to %s2 : tensor<2x32x8xf32> - // CHECK: [[vsharded_11:%.*]] = mesh.shard [[vsharded_8]] to [[vsharding_2]] annotate_for_users : tensor<2x4x32xf32> - // CHECK: [[vsharded_12:%.*]] = mesh.shard [[vsharded_10]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32> - // CHECK: [[v2:%.*]] = tosa.matmul - %4 = tosa.matmul %3, %arg2_s, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32> - // CHECK: [[vsharded_13:%.*]] = mesh.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32> - %5 = mesh.shard %4 to %s0 : tensor<2x4x8xf32> - // CHECK: [[vsharded_14:%.*]] = mesh.shard [[vsharded_13]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32> - %6 = mesh.shard %5 to %s0 annotate_for_users : tensor<2x4x8xf32> - // CHECK: return [[vsharded_14]] - return %6 : tensor<2x4x8xf32> -} - -// CHECK-LABEL: func.func @elementwise_duplicated_chain -// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> -func.func @elementwise_duplicated_chain(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { - // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]] - %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32> - // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: %[[V4:.*]] = tosa.sigmoid %[[V3]] - %1 = 
tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[V5:.*]] = mesh.shard %[[V4]] to %[[S0]] : tensor<8x16xf32> - %s0 = mesh.sharding @mesh_2d split_axes = [[]] : !mesh.sharding - %2 = mesh.shard %1 to %s0 : tensor<8x16xf32> - // CHECK-NEXT: return %[[V5]] - return %2 : tensor<8x16xf32> -} diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir deleted file mode 100644 index 701898cbdc74d..0000000000000 --- a/mlir/test/Dialect/Mesh/spmdization.mlir +++ /dev/null @@ -1,317 +0,0 @@ -// RUN: mlir-opt \ -// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \ -// RUN: %s | FileCheck %s - -mesh.mesh @mesh_1d(shape = 2) - -// CHECK-LABEL: func @return_sharding -func.func @return_sharding( - // CHECK-SAME: [[ARG:%.*]]: tensor<1xf32> - %arg0: tensor<2xf32> -// CHECK-SAME: ) -> (tensor<1xf32>, !mesh.sharding) { -) -> (tensor<2xf32>, !mesh.sharding) { - %ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated : tensor<2xf32> - // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}0]] : !mesh.sharding - %r = mesh.get_sharding %sharding_annotated : tensor<2xf32> -> !mesh.sharding - // CHECK-NEXT: return [[ARG]], [[vsharding]] : tensor<1xf32>, !mesh.sharding - return %sharding_annotated, %r : tensor<2xf32>, !mesh.sharding -} - -// CHECK-LABEL: func @full_replication -func.func @full_replication( - // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> - %arg0: tensor<2xi8> -// CHECK-SAME: -> tensor<2xi8> { -) -> tensor<2xi8> { - %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<2xi8> - %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8> - // CHECK: return %[[ARG]] : tensor<2xi8> - return %1 : tensor<2xi8> -} - -// CHECK-LABEL: func @sharding_triplet 
-func.func @sharding_triplet( - // CHECK-SAME: %[[ARG:.*]]: tensor<1xf32> - %arg0: tensor<2xf32> -// CHECK-SAME: ) -> tensor<2xf32> { -) -> tensor<2xf32> { - // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xf32> -> tensor<2xf32> - %ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated : tensor<2xf32> - %ssharding_annotated_0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %sharding_annotated_0 = mesh.shard %sharding_annotated to %ssharding_annotated_0 annotate_for_users : tensor<2xf32> - %ssharding_annotated_1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to %ssharding_annotated_1 : tensor<2xf32> - // CHECK: return %[[ALL_GATHER]] : tensor<2xf32> - return %sharding_annotated_1 : tensor<2xf32> -} - - -// CHECK-LABEL: func @move_split_axis -func.func @move_split_axis( - // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi8> - %arg0: tensor<2x2xi8> -// CHECK-SAME: -> tensor<2x1xi8> { -) -> tensor<2x2xi8> { - // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d - // CHECK-SAME: mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<1x2xi8> -> tensor<2x1xi8> - %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<2x2xi8> - %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2x2xi8> - // CHECK: return %[[ALL_TO_ALL]] : tensor<2x1xi8> - return %1 : tensor<2x2xi8> -} - -// CHECK-LABEL: func @non_tensor_value -func.func @non_tensor_value( - // CHECK-SAME: %[[ARG:.*]]: i8 - %arg0: i8 -// CHECK-SAME: -> i8 { -) -> i8 { - // CHECK: %[[RES:.*]] = arith.addi %[[ARG]], %[[ARG]] : i8 - %0 = arith.addi %arg0, %arg0 : i8 - // CHECK: return %[[RES]] : i8 - return %0 : i8 -} - -// CHECK-LABEL: func @unary_elementwise 
-func.func @unary_elementwise( - // CHECK-SAME: %[[ARG:.*]]: tensor<1xi8> - %arg0: tensor<2xi8> -// CHECK-SAME: -> tensor<1xi8> { -) -> tensor<2xi8> { - %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<2xi8> - %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8> - // CHECK: %[[RES:.*]] = tosa.abs %[[ARG]] : (tensor<1xi8>) -> tensor<1xi8> - %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8> - %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %3 = mesh.shard %2 to %s3 : tensor<2xi8> - %s4 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %4 = mesh.shard %3 to %s4 annotate_for_users : tensor<2xi8> - // CHECK: return %[[RES]] : tensor<1xi8> - return %4 : tensor<2xi8> -} - -// full replication -> shard axis -> abs -> shard axis -> full replication -// CHECK-LABEL: func @unary_elementwise_with_resharding -func.func @unary_elementwise_with_resharding( - // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> - %arg0: tensor<2xi8> -// CHECK-SAME: -> tensor<2xi8> { -) -> tensor<2xi8> { - // CHECK: %[[SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0 - // CHECK-SAME: tensor<2xi8> -> tensor<1xi8> - %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<2xi8> - %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8> - // CHECK: %[[ABS:.*]] = tosa.abs %[[SLICE]] : (tensor<1xi8>) -> tensor<1xi8> - %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8> - // CHECK: %[[RES:.*]] = mesh.all_gather %[[ABS]] on @mesh_1d - // CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8> - %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %3 = mesh.shard %2 to %s3 : tensor<2xi8> - %s4 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %4 = mesh.shard %3 to %s4 
annotate_for_users : tensor<2xi8> - // CHECK: return %[[RES]] : tensor<2xi8> - return %4 : tensor<2xi8> -} - -// CHECK-LABEL: func @binary_elementwise -func.func @binary_elementwise( - // CHECK-SAME: %[[ARG0:.*]]: tensor<1xi8>, - %arg0: tensor<2xi8>, - // CHECK-SAME: %[[ARG1:.*]]: tensor<1xi8> - %arg1: tensor<2xi8> -// CHECK-SAME: -> tensor<1xi8> { -) -> tensor<2xi8> { - %sarg0_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded : tensor<2xi8> - %sop_arg0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %op_arg0 = mesh.shard %arg0_sharded to %sop_arg0 annotate_for_users : tensor<2xi8> - %sarg1_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %arg1_sharded = mesh.shard %arg1 to %sarg1_sharded : tensor<2xi8> - %sop_arg1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %op_arg1 = mesh.shard %arg1_sharded to %sop_arg1 annotate_for_users : tensor<2xi8> - // CHECK: %[[RES:.*]] = tosa.add %[[ARG0]], %[[ARG1]] : (tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8> - %op_res = tosa.add %op_arg0, %op_arg1 : (tensor<2xi8>, tensor<2xi8>) -> tensor<2xi8> - %sop_res_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %op_res_sharded = mesh.shard %op_res to %sop_res_sharded : tensor<2xi8> - %sres = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %res = mesh.shard %op_res_sharded to %sres annotate_for_users : tensor<2xi8> - // CHECK: return %[[RES]] : tensor<1xi8> - return %res : tensor<2xi8> -} - -// reshard -// abs -// reshard -// abs -// reshard -// CHECK-LABEL: func @multiple_chained_ops -func.func @multiple_chained_ops( - // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> - %arg0: tensor<2xi8> -// CHECK-SAME: -> tensor<1xi8> { -) -> tensor<2xi8> { - // CHECK: %[[RESHARD1:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0 - // CHECK-SAME: tensor<2xi8> -> tensor<1xi8> - %s0 = mesh.sharding @mesh_1d split_axes = [[]] : 
!mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<2xi8> - %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8> - // CHECK: %[[ABS1:.*]] = tosa.abs %[[RESHARD1]] : (tensor<1xi8>) -> tensor<1xi8> - %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8> - // CHECK: %[[RESHARD2:.*]] = mesh.all_gather %[[ABS1]] on @mesh_1d - // CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8> - %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %3 = mesh.shard %2 to %s3 : tensor<2xi8> - %s4 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %4 = mesh.shard %3 to %s4 annotate_for_users : tensor<2xi8> - // CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8> - %5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8> - // CHECK: %[[RESHARD3:.*]] = mesh.all_slice %[[ABS2]] on @mesh_1d mesh_axes = [0] slice_axis = 0 : - // CHECK-SAME: tensor<2xi8> -> tensor<1xi8> - %s6 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %6 = mesh.shard %5 to %s6 : tensor<2xi8> - %s7 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %7 = mesh.shard %6 to %s7 annotate_for_users : tensor<2xi8> - // CHECK: return %[[RESHARD3]] : tensor<1xi8> - return %7 : tensor<2xi8> -} - -// CHECK-LABEL: func @incomplete_sharding -func.func @incomplete_sharding( - // CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32> - %arg0: tensor<8x16xf32> -// CHECK-SAME: -> tensor<4x16xf32> { -) -> tensor<8x16xf32> { - %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<8x16xf32> - // CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32> - %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32> - %s2 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %2 = mesh.shard %1 to %s2 : tensor<8x16xf32> - // CHECK: return %[[RES]] : tensor<4x16xf32> - return %2 : 
tensor<8x16xf32> -} - -mesh.mesh @mesh_1d_4(shape = 4) - -// CHECK-LABEL: func @ew_chain_with_halo -func.func @ew_chain_with_halo( - // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<5x16xf32> - %arg0: tensor<8x16xf32>, - // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xf32> - %arg1: tensor<1xf32>, - // CHECK-SAME: %[[IN3:[A-Za-z0-9_]+]]: tensor<1xf32> - %arg2: tensor<1xf32>) - // CHECK-SAME: -> tensor<5x16xf32> - -> tensor<8x16xf32> { - %ssharding_annotated = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding - %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated annotate_for_users : tensor<8x16xf32> - // CHECK: %[[TMP1:.*]] = tosa.tanh %[[IN1]] : (tensor<5x16xf32>) -> tensor<5x16xf32> - %0 = tosa.tanh %sharding_annotated : (tensor<8x16xf32>) -> tensor<8x16xf32> - %ssharding_annotated_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding - %sharding_annotated_0 = mesh.shard %0 to %ssharding_annotated_0 : tensor<8x16xf32> - %ssharding_annotated_1 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding - %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to %ssharding_annotated_1 annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: %[[TMP2:.*]] = tosa.abs %[[TMP1]] : (tensor<5x16xf32>) -> tensor<5x16xf32> - %1 = tosa.abs %sharding_annotated_1 : (tensor<8x16xf32>) -> tensor<8x16xf32> - %ssharding_annotated_2 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding - %sharding_annotated_2 = mesh.shard %1 to %ssharding_annotated_2 : tensor<8x16xf32> - %ssharding_annotated_4 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding - %sharding_annotated_4 = mesh.shard %sharding_annotated_2 to %ssharding_annotated_4 annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: %[[TMP3:.*]] = tosa.negate %[[TMP2]], %[[IN2]], %[[IN3]] : (tensor<5x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x16xf32> - %sharding_1 = 
mesh.sharding @mesh_1d_4 split_axes = [[]] : !mesh.sharding - %zero_point_1 = mesh.shard %arg1 to %sharding_1 annotate_for_users : tensor<1xf32> - %zero_point_2 = mesh.shard %arg2 to %sharding_1 annotate_for_users : tensor<1xf32> - %2 = tosa.negate %sharding_annotated_4, %zero_point_1, %zero_point_2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32> - %ssharding_annotated_5 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding - %sharding_annotated_5 = mesh.shard %2 to %ssharding_annotated_5 : tensor<8x16xf32> - %ssharding_annotated_6 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding - %sharding_annotated_6 = mesh.shard %sharding_annotated_5 to %ssharding_annotated_6 annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: return %[[TMP3]] : tensor<5x16xf32> - return %sharding_annotated_6 : tensor<8x16xf32> -} - -// CHECK-LABEL: func @test_shard_update_halo -// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x1200xi64> -func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> { - %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] : !mesh.sharding - // CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64> - // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64> - // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} halo_sizes = [2, 2] : tensor<304x1200xi64> - %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64> - %sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !mesh.sharding - %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64> - %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64> - // CHECK: return %[[UH]] : tensor<304x1200xi64> - return 
%sharding_annotated_3 : tensor<1200x1200xi64> -} - -mesh.mesh @mesh4x4(shape = 4x4) -// CHECK-LABEL: func @test_shard_update_halo2d -// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x300xi64> -func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> { - %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] : !mesh.sharding - // CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64> - // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64> - // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 3, 4] : tensor<303x307xi64> - %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64> - %sharding_0 = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !mesh.sharding - %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64> - %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64> - // CHECK: return %[[UH]] : tensor<303x307xi64> - return %sharding_annotated_3 : tensor<1200x1200xi64> -} - -mesh.mesh @mesh(shape = 2) -// CHECK-LABEL: func.func @test_reduce_0d( -// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32> -func.func @test_reduce_0d(%arg0: tensor<6x6xi32>) -> (tensor) { - %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding - %sharded = mesh.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32> - %4 = tensor.empty() : tensor - %sharding_out = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding - %sharded_out = mesh.shard %4 to %sharding_out : tensor - %sharded_in = mesh.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32> - // CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>) - %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor) dimensions = 
[0, 1] - (%in: i32, %init: i32) { - %6 = arith.addi %in, %init : i32 - linalg.yield %6 : i32 - } - // CHECK: %[[all_reduce:.*]] = mesh.all_reduce %[[reduced]] on @mesh mesh_axes = [0] : tensor -> tensor - %sharded_red = mesh.shard %reduced to %sharding_out : tensor - %sharded_ret = mesh.shard %sharded_red to %sharding_out annotate_for_users : tensor - // CHECK: return %[[all_reduce]] : tensor - return %sharded_ret : tensor -} - -// CHECK-LABEL: func.func @test_reduce_1d( -// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32> -func.func @test_reduce_1d(%arg0: tensor<6x6xi32>) -> (tensor<6xi32>) { - %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding - %sharded = mesh.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32> - %4 = tensor.empty() : tensor<6xi32> - %sharded_out = mesh.shard %4 to %sharding : tensor<6xi32> - %sharded_in = mesh.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32> - // CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>) - %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<6xi32>) dimensions = [1] - (%in: i32, %init: i32) { - %6 = arith.addi %in, %init : i32 - linalg.yield %6 : i32 - } - // CHECK-NOT: mesh.all_reduce - %sharded_red = mesh.shard %reduced to %sharding : tensor<6xi32> - %sharded_ret = mesh.shard %sharded_red to %sharding annotate_for_users : tensor<6xi32> - // CHECK: return %[[reduced]] : tensor<3xi32> - return %sharded_ret : tensor<6xi32> -} diff --git a/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir b/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir similarity index 72% rename from mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir rename to mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir index 4f54607a1c7ff..bc911215851aa 100644 --- a/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir +++ b/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir @@ -1,43 +1,43 @@ -// RUN: mlir-opt --split-input-file 
--test-mesh-all-slice-op-lowering --test-mesh-simplifications --cse %s | FileCheck %s +// RUN: mlir-opt --split-input-file --test-grid-all-slice-op-lowering --test-grid-simplifications --cse %s | FileCheck %s -mesh.mesh @mesh_1d(shape = ?) +shard.grid @grid_1d(shape = ?) -// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh -func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh( +// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_grid +func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_grid( // CHECK: %[[ARG:.*]]: tensor %arg0: tensor // CHECK-SAME: -> tensor { ) -> tensor { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index - // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index + // CHECK-DAG: %[[PROC_IDX:.*]] = shard.process_multi_index on @grid_1d axes = [0] : index + // CHECK-DAG: %[[SHARD_SIZE:.*]] = shard.grid_shape @grid_1d axes = [0] : index // CHECK: %[[TENSOR_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %c0 : tensor - // CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index + // CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[TENSOR_AXIS_SIZE]], %[[SHARD_SIZE]] : index // CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMINDER]], %[[C0]] : index // CHECK: cf.assert %[[AXIS_SIZE_CHECK]] - // CHECK: %[[RESULT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index + // CHECK: %[[RESULT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_AXIS_SIZE]], %[[SHARD_SIZE]] : index // CHECK: %[[SLICE_OFFSET:.*]] = arith.muli %[[PROC_IDX]], %[[RESULT_AXIS_SIZE]] : index // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][%[[SLICE_OFFSET]]] [%[[RESULT_AXIS_SIZE]]] [1] : tensor to tensor - %0 = mesh.all_slice %arg0 on @mesh_1d mesh_axes = [0] slice_axis = 0 : tensor -> 
tensor + %0 = shard.all_slice %arg0 on @grid_1d grid_axes = [0] slice_axis = 0 : tensor -> tensor // CHECK: return %[[RESULT]] : tensor return %0 : tensor } // ----- -mesh.mesh @mesh_1d(shape = 2) +shard.grid @grid_1d(shape = 2) -// CHECK-LABEL: func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh -func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh( +// CHECK-LABEL: func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_grid +func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_grid( // CHECK: %[[ARG:.*]]: tensor<2xf16> %arg0: tensor<2xf16> // CHECK-SAME: -> tensor<1xf16> { ) -> tensor<1xf16> { // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // CHECK: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index + // CHECK: %[[PROC_IDX:.*]] = shard.process_multi_index on @grid_1d axes = [0] : index // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[PROC_IDX]]] [%[[C1]]] [1] : tensor<2xf16> to tensor // CHECK: %[[RESULT:.*]] = tensor.cast %[[SLICE]] : tensor to tensor<1xf16> - %0 = mesh.all_slice %arg0 on @mesh_1d mesh_axes = [0] slice_axis = 0 : tensor<2xf16> -> tensor<1xf16> + %0 = shard.all_slice %arg0 on @grid_1d grid_axes = [0] slice_axis = 0 : tensor<2xf16> -> tensor<1xf16> // CHECK: return %[[RESULT]] : tensor<1xf16> return %0 : tensor<1xf16> } @@ -46,18 +46,18 @@ func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh( // CHECK: #map = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)> -mesh.mesh @mesh_4d(shape = ?x?x?x?) +shard.grid @grid_4d(shape = ?x?x?x?) 
-// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh -func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh( +// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_grid +func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_grid( // CHECK: %[[ARG:.*]]: tensor %arg0 : tensor // CHECK-SAME: -> tensor { ) -> tensor { // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = mesh.process_multi_index on @mesh_4d axes = [3, 1] : index, index - // CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = mesh.mesh_shape @mesh_4d axes = [3, 1] : index, index + // CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = shard.process_multi_index on @grid_4d axes = [3, 1] : index, index + // CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = shard.grid_shape @grid_4d axes = [3, 1] : index, index // CHECK: %[[PROC_GROUP_SIZE:.*]] = arith.muli %[[PROC_GROUP_SHAPE]]#0, %[[PROC_GROUP_SHAPE]]#1 : index // CHECK: %[[SCATTER_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor // CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index @@ -68,7 +68,7 @@ func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh( // CHECK: %[[AXIS_0_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor // CHECK: %[[SCATTER_AXIS_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_SCATTER_AXIS_SIZE]] : index // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[SCATTER_AXIS_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_SCATTER_AXIS_SIZE]]] [1, 1] : tensor to tensor - %0 = mesh.all_slice %arg0 on @mesh_4d mesh_axes = [3, 1] slice_axis = 1 : tensor -> tensor + %0 = shard.all_slice %arg0 on @grid_4d grid_axes = [3, 1] slice_axis = 1 : tensor -> tensor // CHECK: return %[[RESULT]] : tensor return %0 : tensor } diff --git 
a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir b/mlir/test/Dialect/Shard/backward-sharding-propagation.mlir similarity index 67% rename from mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir rename to mlir/test/Dialect/Shard/backward-sharding-propagation.mlir index 4223d01d65111..8894c4aee49c0 100644 --- a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir +++ b/mlir/test/Dialect/Shard/backward-sharding-propagation.mlir @@ -2,17 +2,17 @@ #map = affine_map<(d0, d1) -> (d0, d1)> module { - mesh.mesh @mesh(shape = 1) {sym_visibility = "private"} + shard.grid @grid(shape = 1) {sym_visibility = "private"} func.func @test_forward() -> tensor<6x6xi32> { %c1_i32 = arith.constant 1 : i32 // CHECK: tensor.empty() %0 = tensor.empty() : tensor<6x6xi32> - %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding - // CHECK-COUNT-2: mesh.shard - %sharding_annotated = mesh.shard %0 to %sharding : tensor<6x6xi32> - %1 = linalg.fill ins(%c1_i32 : i32) outs(%sharding_annotated : tensor<6x6xi32>) -> tensor<6x6xi32> + %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding + // CHECK-COUNT-2: shard.shard + %sharded = shard.shard %0 to %sharding : tensor<6x6xi32> + %1 = linalg.fill ins(%c1_i32 : i32) outs(%sharded : tensor<6x6xi32>) -> tensor<6x6xi32> // CHECK: tensor.empty() - // CHECK-NOT: mesh.shard @ + // CHECK-NOT: shard.shard @ %2 = tensor.empty() : tensor<6x6xi32> %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1, %1 : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) { diff --git a/mlir/test/Dialect/Shard/canonicalization.mlir b/mlir/test/Dialect/Shard/canonicalization.mlir new file mode 100644 index 0000000000000..ed40dfb7237da --- /dev/null +++ b/mlir/test/Dialect/Shard/canonicalization.mlir @@ -0,0 +1,248 @@ +// RUN: mlir-opt --canonicalize %s | FileCheck %s + +shard.grid @grid0(shape = 2x4) + +// CHECK-LABEL: func @all_reduce_empty_grid_axes +func.func 
@all_reduce_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> + %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NOT: shard.all_reduce + %0 = shard.all_reduce %arg0 on @grid0 + grid_axes = [] + : tensor<4xf32> -> tensor<4xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @all_reduce_empty_grid_axes_different_return_type +func.func @all_reduce_empty_grid_axes_different_return_type( + %arg0 : tensor<4xf32>) -> tensor<4xf64> { +// CHECK: shard.all_reduce + %0 = shard.all_reduce %arg0 on @grid0 +// CHECK-NOT: grid_axes + grid_axes = [] + : tensor<4xf32> -> tensor<4xf64> + return %0 : tensor<4xf64> +} + +// CHECK-LABEL: func @all_reduce_default_reduction +func.func @all_reduce_default_reduction( + %arg0 : tensor<4xf32>) -> tensor<4xf64> { + %0 = shard.all_reduce %arg0 on @grid0 + grid_axes = [0] +// CHECK-NOT: reduction + reduction = sum + : tensor<4xf32> -> tensor<4xf64> + return %0 : tensor<4xf64> +} + +// CHECK-LABEL: func @all_to_all_empty_grid_axes +func.func @all_to_all_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<8xf32> + %arg0 : tensor<8xf32>) -> tensor<8xf32> { +// CHECK-NOT: shard.all_to_all + %0 = shard.all_to_all %arg0 on @grid0 + grid_axes = [] + split_axis = 0 + concat_axis = 0 + : tensor<8xf32> -> tensor<8xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<8xf32> +} + +// CHECK-LABEL: func @all_gather_empty_grid_axes +func.func @all_gather_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> + %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NOT: shard.all_gather + %0 = shard.all_gather %arg0 on @grid0 + grid_axes = [] + gather_axis = 0 + : tensor<4xf32> -> tensor<4xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @all_slice_empty_grid_axes +func.func @all_slice_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> + %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NOT: shard.scatter + %0 = shard.all_slice %arg0 on @grid0 + grid_axes = [] + 
slice_axis = 0 + : tensor<4xf32> -> tensor<4xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @broadcast_empty_grid_axes +func.func @broadcast_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> + %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NOT: shard.broadcast + %0 = shard.broadcast %arg0 on @grid0 + grid_axes = [] + root = [] + : (tensor<4xf32>) -> tensor<4xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @gather_empty_grid_axes +func.func @gather_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> + %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NOT: shard.gather + %0 = shard.gather %arg0 on @grid0 + grid_axes = [] + gather_axis = 0 + root = [] + : (tensor<4xf32>) -> tensor<4xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @receive_empty_grid_axes +func.func @receive_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> + %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NOT: shard.recv + %0 = shard.recv %arg0 on @grid0 + grid_axes = [] + : (tensor<4xf32>) -> tensor<4xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @reduce_empty_grid_axes +func.func @reduce_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> + %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NOT: shard.reduce + %0 = shard.reduce %arg0 on @grid0 + grid_axes = [] + root = [] + : (tensor<4xf32>) -> tensor<4xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @reduce_scatter_empty_grid_axes +func.func @reduce_scatter_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> + %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NOT: shard.reduce_scatter + %0 = shard.reduce_scatter %arg0 on @grid0 + grid_axes = [] + scatter_axis = 0 + : tensor<4xf32> -> tensor<4xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func 
@reduce_scatter_empty_grid_axes_different_return_type +func.func @reduce_scatter_empty_grid_axes_different_return_type( + %arg0 : tensor<4xf32>) -> tensor<4xf64> { +// CHECK: shard.reduce_scatter + %0 = shard.reduce_scatter %arg0 on @grid0 +// CHECK-NOT: grid_axes + grid_axes = [] + scatter_axis = 0 + : tensor<4xf32> -> tensor<4xf64> + return %0 : tensor<4xf64> +} + +// CHECK-LABEL: func @reduce_scatter_default_reduction +func.func @reduce_scatter_default_reduction( + %arg0 : tensor<4xf32>) -> tensor<2xf64> { + %0 = shard.reduce_scatter %arg0 on @grid0 + grid_axes = [0] +// CHECK-NOT: reduction + reduction = sum + scatter_axis = 0 + : tensor<4xf32> -> tensor<2xf64> + return %0 : tensor<2xf64> +} + +// CHECK-LABEL: func @scatter_empty_grid_axes +func.func @scatter_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> + %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NOT: shard.scatter + %0 = shard.scatter %arg0 on @grid0 + grid_axes = [] + scatter_axis = 0 + root = [] + : (tensor<4xf32>) -> tensor<4xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @send_empty_grid_axes +func.func @send_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> + %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NOT: shard.send + %0 = shard.send %arg0 on @grid0 + grid_axes = [] + destination = [] + : (tensor<4xf32>) -> tensor<4xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<4xf32> +} + +shard.grid @grid4x4(shape = 4x4) +// CHECK-LABEL: func @test_halo_sizes +func.func @test_halo_sizes() -> !shard.sharding { + %c2_i64 = arith.constant 2 : i64 + // CHECK shard.sharding @grid4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 2, 22] : !shard.sharding + %sharding = shard.sharding @grid4x4 split_axes = [[0], [1]] halo_sizes = [1, %c2_i64, %c2_i64, 22] : !shard.sharding + return %sharding : !shard.sharding +} + +// CHECK-LABEL: func @test_shard_offs +func.func @test_shard_offs() -> !shard.sharding { + %c2_i64 = arith.constant 2 : i64 + 
// CHECK shard.sharding @grid4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !shard.sharding + %sharding = shard.sharding @grid4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !shard.sharding + return %sharding : !shard.sharding +} + +// CHECK-LABEL: func @test_duplicate_shardops +func.func @test_duplicate_shardops() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { + // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + // CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0, 1]] : !shard.sharding + %sharding_1 = shard.sharding @grid4x4 split_axes = [[0, 1]] : !shard.sharding + %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %sharding_2 = shard.sharding @grid4x4 split_axes = [[0, 1]] : !shard.sharding + %sharded_2 = shard.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %sharding_3 = shard.sharding @grid4x4 split_axes = [[0, 1]] : !shard.sharding + %sharded_3 = shard.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32> + // CHECK-NEXT: [[vsharded:%.*]] = shard.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32> + %sharded_1 = shard.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32> + // CHECK-NEXT: return [[vsharded]], [[vsharded]] : tensor<1024x1024xf32>, tensor<1024x1024xf32> + return %sharded_1, %sharded_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32> +} + +// CHECK-LABEL: func @test_duplicate_shardops_diff +func.func @test_duplicate_shardops_diff() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { + // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + // CHECK-NEXT: 
[[vsharding:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding + %sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding + %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + // CHECK-NEXT: [[vsharding_0:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0, 1]] : !shard.sharding + %sharding_2 = shard.sharding @grid4x4 split_axes = [[0, 1]] : !shard.sharding + // CHECK-NEXT: [[vsharded:%.*]] = shard.shard [[vcst]] to [[vsharding_0]] : tensor<1024x1024xf32> + %sharded_2 = shard.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %sharding_3 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding + %sharded_3 = shard.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32> + // CHECK-NEXT: [[vsharded_1:%.*]] = shard.shard [[vsharded]] to [[vsharding]] : tensor<1024x1024xf32> + %sharded_1 = shard.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32> + // CHECK-NEXT: return [[vsharded_1]], [[vsharded]] : tensor<1024x1024xf32>, tensor<1024x1024xf32> + return %sharded_1, %sharded_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32> +} diff --git a/mlir/test/Dialect/Shard/folding.mlir b/mlir/test/Dialect/Shard/folding.mlir new file mode 100644 index 0000000000000..5a0f35b53a129 --- /dev/null +++ b/mlir/test/Dialect/Shard/folding.mlir @@ -0,0 +1,22 @@ +// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s + +shard.grid @grid0(shape = 4x?x2) +shard.grid @grid1(shape = 2x3) + +// CHECK-LABEL: func.func @grid_shape_op_folding +func.func @grid_shape_op_folding() -> (index, index) { + // CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index + // CHECK: %[[AXIS_1_SIZE:.*]] = shard.grid_shape @grid0 axes = [1] : index + %0:2 = shard.grid_shape @grid0 axes = [2, 1] : index, index + // CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]] + return %0#0, %0#1 : index, index +} + +// CHECK-LABEL: func.func @grid_shape_op_folding_all_axes_static_grid 
+func.func @grid_shape_op_folding_all_axes_static_grid() -> (index, index) { + // CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index + // CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index + %0:2 = shard.grid_shape @grid1 : index, index + // CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]] + return %0#0, %0#1 : index, index +} diff --git a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir b/mlir/test/Dialect/Shard/forward-backward-sharding-propagation.mlir similarity index 63% rename from mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir rename to mlir/test/Dialect/Shard/forward-backward-sharding-propagation.mlir index dd2eee2f7def8..0d8d99752620a 100644 --- a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir +++ b/mlir/test/Dialect/Shard/forward-backward-sharding-propagation.mlir @@ -2,25 +2,25 @@ #map = affine_map<(d0, d1) -> (d0, d1)> module { - mesh.mesh @mesh(shape = 1) {sym_visibility = "private"} + shard.grid @grid(shape = 1) {sym_visibility = "private"} func.func @test_forward() -> tensor<6x6xi32> { %c1_i32 = arith.constant 1 : i32 // CHECK: tensor.empty() %0 = tensor.empty() : tensor<6x6xi32> - // CHECK-COUNT-3: mesh.sharding @mesh split_axes = {{\[\[0}}]] - %sharding_row = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding - %annotated_row = mesh.shard %0 to %sharding_row : tensor<6x6xi32> + // CHECK-COUNT-3: shard.sharding @grid split_axes = {{\[\[0}}]] + %sharding_row = shard.sharding @grid split_axes = [[0]] : !shard.sharding + %annotated_row = shard.shard %0 to %sharding_row : tensor<6x6xi32> %1 = linalg.fill ins(%c1_i32 : i32) outs(%annotated_row : tensor<6x6xi32>) -> tensor<6x6xi32> %2 = tensor.empty() : tensor<6x6xi32> - // CHECK-COUNT-4: mesh.sharding @mesh split_axes = {{\[\[1}}]] + // CHECK-COUNT-4: shard.sharding @grid split_axes = {{\[\[1}}]] %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%2, %1 : tensor<6x6xi32>, 
tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) { ^bb0(%in: i32, %in_2: i32, %out: i32): %9 = arith.addi %in, %in_2 : i32 linalg.yield %9 : i32 } -> tensor<6x6xi32> - %sharding_col = mesh.sharding @mesh split_axes = [[1]] : !mesh.sharding - %annotated_col = mesh.shard %3 to %sharding_col : tensor<6x6xi32> + %sharding_col = shard.sharding @grid split_axes = [[1]] : !shard.sharding + %annotated_col = shard.shard %3 to %sharding_col : tensor<6x6xi32> // CHECK: return return %annotated_col : tensor<6x6xi32> } diff --git a/mlir/test/Dialect/Shard/forward-sharding-propagation.mlir b/mlir/test/Dialect/Shard/forward-sharding-propagation.mlir new file mode 100644 index 0000000000000..3cda9eaa365fd --- /dev/null +++ b/mlir/test/Dialect/Shard/forward-sharding-propagation.mlir @@ -0,0 +1,49 @@ +// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s + +#map = affine_map<(d0, d1) -> (d0, d1)> +module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} { + shard.grid @grid(shape = 1) {sym_visibility = "private"} + func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor) attributes {llvm.emit_c_interface} { + %c1_i32 = arith.constant 1 : i32 + // CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32> + %0 = tensor.empty() : tensor<6x6xi32> + // CHECK: [[v1:%.*]] = linalg.fill ins + // CHECK: [[vsharding_0:%.*]] = shard.sharding @grid split_axes = {{\[\[}}0]] : !shard.sharding + // CHECK: [[vsharded_1:%.*]] = shard.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32> + %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32> + %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding + %sharded = shard.shard %1 to %sharding : tensor<6x6xi32> + // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32> + // CHECK: [[vsharding_2:%.*]] = shard.sharding @grid split_axes = {{\[\[}}0]] : !shard.sharding + // CHECK: [[vsharded_3:%.*]] = 
shard.shard [[vsharded_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32> + %3 = tensor.empty() : tensor<6x6xi32> + // CHECK: [[vsharding_4:%.*]] = shard.sharding @grid split_axes = {{\[\[}}0]] : !shard.sharding + // CHECK: [[vsharded_5:%.*]] = shard.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32> + // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} + // CHECK-SAME: ins([[vsharded_3]], [[vsharded_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharded_5]] : tensor<6x6xi32>) { + // CHECK: [[vsharding_6:%.*]] = shard.sharding @grid split_axes = {{\[\[}}0]] : !shard.sharding + // CHECK: [[vsharded_7:%.*]] = shard.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32> + %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharded, %sharded + : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) { + ^bb0(%in: i32, %in_2: i32, %out: i32): + %9 = arith.addi %in, %in_2 : i32 + linalg.yield %9 : i32 + } -> tensor<6x6xi32> + %c0_i32 = arith.constant 0 : i32 + %6 = tensor.empty() : tensor + %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor) -> tensor + // CHECK: [[vreduced:%.*]] = linalg.reduce ins + // CHECK: [[vsharding_12:%.*]] = shard.sharding @grid split_axes = [] : !shard.sharding + // CHECK: [[vsharded_13:%.*]] = shard.shard [[vreduced]] to [[vsharding_12]] : tensor + %reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor) dimensions = [0, 1] + (%in: i32, %init: i32) { + %9 = arith.addi %in, %init : i32 + linalg.yield %9 : i32 + } + // CHECK: [[vsharding_14:%.*]] = shard.sharding @grid split_axes = {{\[\[}}]] : !shard.sharding + %sharding_0 = shard.sharding @grid split_axes = [[]] : !shard.sharding + // CHECK: [[vsharded_15:%.*]] = shard.shard [[vsharded_13]] to [[vsharding_14]] annotate_for_users : tensor + %sharded_1 = shard.shard %reduced to %sharding_0 annotate_for_users : tensor + return %sharded, 
%4, %sharded_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor + } +} diff --git a/mlir/test/Dialect/Shard/inlining.mlir b/mlir/test/Dialect/Shard/inlining.mlir new file mode 100644 index 0000000000000..ce664b31abf7a --- /dev/null +++ b/mlir/test/Dialect/Shard/inlining.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt -inline %s | FileCheck %s + +shard.grid @grid0(shape = 4x?x2) + +func.func private @grid_to_inline() -> (index, index) { + %0:2 = shard.grid_shape @grid0 axes = [2, 1] : index, index + return %0#0, %0#1 : index, index +} +// CHECK-LABEL: func.func @main +func.func @main() -> (index, index) { + // CHECK-NEXT: %[[AXIS_SIZE:.*]]:2 = shard.grid_shape @grid0 axes = [2, 1] : index + %0:2 = func.call @grid_to_inline() : () -> (index, index) + // CHECK-NEXT: return %[[AXIS_SIZE]]#0, %[[AXIS_SIZE]]#1 + return %0#0, %0#1 : index, index +} diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Shard/invalid.mlir similarity index 57% rename from mlir/test/Dialect/Mesh/invalid.mlir rename to mlir/test/Dialect/Shard/invalid.mlir index 2656332942382..6acac971164ed 100644 --- a/mlir/test/Dialect/Mesh/invalid.mlir +++ b/mlir/test/Dialect/Shard/invalid.mlir @@ -1,55 +1,55 @@ // RUN: mlir-opt -split-input-file -verify-diagnostics %s -// expected-error@+1 {{rank of mesh is expected to be a positive integer}} -mesh.mesh @mesh0(shape = []) +// expected-error@+1 {{rank of grid is expected to be a positive integer}} +shard.grid @grid0(shape = []) // ----- -// expected-error@+1 {{custom op 'mesh.mesh' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}} -mesh.mesh @mesh0(shape = -1) +// expected-error@+1 {{custom op 'shard.grid' Failed parsing dimension list. Did you mean an empty list? 
It must be denoted by "[]".}} +shard.grid @grid0(shape = -1) // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @mesh_axis_duplicated_different_subarray( +func.func @grid_axis_duplicated_different_subarray( %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { - // expected-error@+1 {{mesh axis duplicated}} - %s = mesh.sharding @mesh0 split_axes = [[0], [0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> + // expected-error@+1 {{grid axis duplicated}} + %s = shard.sharding @grid0 split_axes = [[0], [0]] : !shard.sharding + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return %0 : tensor<4x8xf32> } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @mesh_axis_duplicated_same_subarray( +func.func @grid_axis_duplicated_same_subarray( %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { - // expected-error@+1 {{mesh axis duplicated}} - %s = mesh.sharding @mesh0 split_axes = [[0, 0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> + // expected-error@+1 {{grid axis duplicated}} + %s = shard.sharding @grid0 split_axes = [[0, 0]] : !shard.sharding + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return %0 : tensor<4x8xf32> } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @mesh_axis_negtive_in_split_part( +func.func @grid_axis_negtive_in_split_part( %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { - // expected-error@+1 {{mesh axis is expected to be non-negative}} - %s = mesh.sharding @mesh0 split_axes = [[-1]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> + // expected-error@+1 {{grid axis is expected to be non-negative}} + %s = shard.sharding @grid0 split_axes = [[-1]] : !shard.sharding + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return %0 : tensor<4x8xf32> } // ----- func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) { - // expected-error@+1 {{custom op 'mesh.sharding' invalid kind of attribute 
specified}} - %s = mesh.sharding @a::@b split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> + // expected-error@+1 {{custom op 'shard.sharding' invalid kind of attribute specified}} + %s = shard.sharding @a::@b split_axes = [[0]] : !shard.sharding + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return } @@ -57,8 +57,8 @@ func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) { func.func @sharding_attribute_invalid_halo(%arg0 : tensor<4x8xf32>) { // expected-error@+1 {{halo sizes must be specified for all split axes}} - %s = mesh.sharding @mesh0 split_axes = [[0], [1]] halo_sizes = [1, 2] : !mesh.sharding - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> + %s = shard.sharding @grid0 split_axes = [[0], [1]] halo_sizes = [1, 2] : !shard.sharding + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return } @@ -66,292 +66,292 @@ func.func @sharding_attribute_invalid_halo(%arg0 : tensor<4x8xf32>) { func.func @sharding_attribute_invalid_sizes(%arg0 : tensor<4x8xf32>) { // expected-error@+1 {{halo sizes and shard offsets are mutually exclusive}} - %s = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_offsets = [0, 2, 2] : !mesh.sharding - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> + %s = shard.sharding @grid0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_offsets = [0, 2, 2] : !shard.sharding + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return } // ----- -mesh.mesh @mesh_dyn(shape = ?x?) -func.func @sharding_dyn_mesh_and_sizes(%arg0 : tensor<4x8xf32>) { - // expected-error@+1 {{sharded dims offsets are not allowed for devices meshes with dynamic shape}} - %s = mesh.sharding @mesh_dyn split_axes = [[0]] sharded_dims_offsets = [0, 2, 2] : !mesh.sharding - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> +shard.grid @grid_dyn(shape = ?x?) 
+func.func @sharding_dyn_grid_and_sizes(%arg0 : tensor<4x8xf32>) { + // expected-error@+1 {{sharded dims offsets are not allowed for device grids with dynamic shape}} + %s = shard.sharding @grid_dyn split_axes = [[0]] sharded_dims_offsets = [0, 2, 2] : !shard.sharding + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) func.func @sharding_sizes_count(%arg0 : tensor<4x8xf32>) { // expected-error@+1 {{sharded dims offsets has wrong size}} - %s = mesh.sharding @mesh0 split_axes = [[0], [1]] sharded_dims_offsets = [0, 2, 4, 0, 2, 4, 6] : !mesh.sharding - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> + %s = shard.sharding @grid0 split_axes = [[0], [1]] sharded_dims_offsets = [0, 2, 4, 0, 2, 4, 6] : !shard.sharding + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return } // ----- -mesh.mesh @mesh0(shape = 4) +shard.grid @grid0(shape = 4) func.func @sharding_sizes_decreasing(%arg0 : tensor<4x8xf32>) { // expected-error@+1 {{sharded dims offsets must be non-decreasing}} - %s = mesh.sharding @mesh0 split_axes = [[0]] sharded_dims_offsets = [0, 2, 3, 2] : !mesh.sharding - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> + %s = shard.sharding @grid0 split_axes = [[0]] sharded_dims_offsets = [0, 2, 3, 2] : !shard.sharding + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @mesh_shape_mesh_axis_out_of_bounds() -> (index, index) { - // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}} - %0:2 = mesh.mesh_shape @mesh0 axes = [0, 2] : index, index +func.func @grid_shape_grid_axis_out_of_bounds() -> (index, index) { + // expected-error@+1 {{0-based grid axis index 2 is out of bounds. 
The referenced grid "grid0" is of rank 2.}} + %0:2 = shard.grid_shape @grid0 axes = [0, 2] : index, index return %0#0, %0#1 : index, index } // ----- -mesh.mesh @mesh0(shape = 1x2x3) +shard.grid @grid0(shape = 1x2x3) -func.func @mesh_shape_duplicate_mesh_axis() -> (index, index, index) { - // expected-error@+1 {{Mesh axes contains duplicate elements.}} - %0:3 = mesh.mesh_shape @mesh0 axes = [0, 2, 0] : index, index, index +func.func @grid_shape_duplicate_grid_axis() -> (index, index, index) { + // expected-error@+1 {{Grid axes contains duplicate elements.}} + %0:3 = shard.grid_shape @grid0 axes = [0, 2, 0] : index, index, index return %0#0, %0#1, %0#2 : index, index, index } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @mesh_shape_wrong_number_of_results() -> (index, index) { +func.func @grid_shape_wrong_number_of_results() -> (index, index) { // expected-error@+1 {{Unexpected number of results 2. Expected 1.}} - %0:2 = mesh.mesh_shape @mesh0 axes = [0] : index, index + %0:2 = shard.grid_shape @grid0 axes = [0] : index, index return %0#0, %0#1 : index, index } // ----- -mesh.mesh @mesh0(shape = 1x2x3) +shard.grid @grid0(shape = 1x2x3) -func.func @mesh_shape_wrong_number_of_results_empty_mesh_axes() -> (index, index) { +func.func @grid_shape_wrong_number_of_results_empty_grid_axes() -> (index, index) { // expected-error@+1 {{Unexpected number of results 2. 
Expected 3.}} - %0:2 = mesh.mesh_shape @mesh0 : index, index + %0:2 = shard.grid_shape @grid0 : index, index return %0#0, %0#1 : index, index } // ----- -func.func @mesh_shape_invalid_mesh_name() -> (index) { - // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}} - %0 = mesh.mesh_shape @this_mesh_symbol_does_not_exist : index +func.func @grid_shape_invalid_grid_name() -> (index) { + // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}} + %0 = shard.grid_shape @this_grid_symbol_does_not_exist : index return %0#0 : index } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) { - // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}} - %0:2 = mesh.process_multi_index on @mesh0 axes = [0, 2] : index, index +func.func @process_multi_index_grid_axis_out_of_bounds() -> (index, index) { + // expected-error@+1 {{0-based grid axis index 2 is out of bounds. 
The referenced grid "grid0" is of rank 2.}} + %0:2 = shard.process_multi_index on @grid0 axes = [0, 2] : index, index return %0#0, %0#1 : index, index } // ----- -mesh.mesh @mesh0(shape = 1x2x3) +shard.grid @grid0(shape = 1x2x3) -func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) { - // expected-error@+1 {{Mesh axes contains duplicate elements.}} - %0:3 = mesh.process_multi_index on @mesh0 axes = [0, 2, 0] : index, index, index +func.func @process_multi_index_duplicate_grid_axis() -> (index, index, index) { + // expected-error@+1 {{Grid axes contains duplicate elements.}} + %0:3 = shard.process_multi_index on @grid0 axes = [0, 2, 0] : index, index, index return %0#0, %0#1, %0#2 : index, index, index } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) func.func @process_multi_index_wrong_number_of_results() -> (index, index) { // expected-error@+1 {{Unexpected number of results 2. Expected 1.}} - %0:2 = mesh.process_multi_index on @mesh0 axes = [0] : index, index + %0:2 = shard.process_multi_index on @grid0 axes = [0] : index, index return %0#0, %0#1 : index, index } // ----- -mesh.mesh @mesh0(shape = 1x2x3) +shard.grid @grid0(shape = 1x2x3) -func.func @process_multi_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) { +func.func @process_multi_index_wrong_number_of_results_empty_grid_axes() -> (index, index) { // expected-error@+1 {{Unexpected number of results 2. 
Expected 3.}} - %0:2 = mesh.process_multi_index on @mesh0 : index, index + %0:2 = shard.process_multi_index on @grid0 : index, index return %0#0, %0#1 : index, index } // ----- -func.func @process_multi_index_invalid_mesh_name() -> (index) { - // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}} - %0 = mesh.process_multi_index on @this_mesh_symbol_does_not_exist : index +func.func @process_multi_index_invalid_grid_name() -> (index) { + // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}} + %0 = shard.process_multi_index on @this_grid_symbol_does_not_exist : index return %0 : index } // ----- -func.func @process_linear_index_invalid_mesh_name() -> (index) { - // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}} - %0 = mesh.process_linear_index on @this_mesh_symbol_does_not_exist : index +func.func @process_linear_index_invalid_grid_name() -> (index) { + // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}} + %0 = shard.process_linear_index on @this_grid_symbol_does_not_exist : index return %0 : index } // ----- -func.func @all_reduce_invalid_mesh_symbol( +func.func @all_reduce_invalid_grid_symbol( %arg0 : tensor<4xf32>) -> tensor<4xf64> { - // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}} - %0 = mesh.all_reduce %arg0 on @this_mesh_symbol_does_not_exist reduction = sum + // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}} + %0 = shard.all_reduce %arg0 on @this_grid_symbol_does_not_exist reduction = sum : tensor<4xf32> -> tensor<4xf64> return %0 : tensor<4xf64> } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @all_reduce_invalid_mesh_axis( +func.func @all_reduce_invalid_grid_axis( %arg0 : tensor<4xf32>) -> tensor<4xf64> { - // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. 
The referenced mesh "mesh0" is of rank 2.}} - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [2] reduction = sum + // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}} + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [2] reduction = sum : tensor<4xf32> -> tensor<4xf64> return %0 : tensor<4xf64> } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @all_reduce_duplicate_mesh_axis( +func.func @all_reduce_duplicate_grid_axis( %arg0 : tensor<4xf32>) -> tensor<4xf64> { - // expected-error@+1 {{Mesh axes contains duplicate elements.}} - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1, 0] reduction = sum + // expected-error@+1 {{Grid axes contains duplicate elements.}} + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1, 0] reduction = sum : tensor<4xf32> -> tensor<4xf64> return %0 : tensor<4xf64> } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) func.func @all_reduce_invalid_tensor_dimension_size( %arg0 : tensor<4xf32>) -> tensor<5xf64> { - // expected-error@+1 {{'mesh.all_reduce' op requires the same shape for all operands and results}} - %0 = mesh.all_reduce %arg0 on @mesh0 : tensor<4xf32> -> tensor<5xf64> + // expected-error@+1 {{'shard.all_reduce' op requires the same shape for all operands and results}} + %0 = shard.all_reduce %arg0 on @grid0 : tensor<4xf32> -> tensor<5xf64> return %0 : tensor<5xf64> } // ----- -func.func @all_gather_invalid_mesh_symbol( +func.func @all_gather_invalid_grid_symbol( %arg0 : tensor<4xf32>) -> tensor<4xf32> { - // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}} - %0 = mesh.all_gather %arg0 on @this_mesh_symbol_does_not_exist gather_axis = 0 + // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}} + %0 = shard.all_gather %arg0 on @this_grid_symbol_does_not_exist gather_axis = 0 : tensor<4xf32> -> tensor<4xf32> return %0 : 
tensor<4xf32> } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @all_gather_invalid_mesh_axis( +func.func @all_gather_invalid_grid_axis( %arg0 : tensor<4xf32>) -> tensor<4xf32> { - // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}} - %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 0 + // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}} + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 0 : tensor<4xf32> -> tensor<4xf32> return %0 : tensor<4xf32> } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @all_reduce_duplicate_mesh_axis( +func.func @all_reduce_duplicate_grid_axis( %arg0 : tensor<4xf32>) -> tensor<4xf32> { - // expected-error@+1 {{Mesh axes contains duplicate elements.}} - %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2, 2] gather_axis = 0 + // expected-error@+1 {{Grid axes contains duplicate elements.}} + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2, 2] gather_axis = 0 : tensor<4xf32> -> tensor<4xf32> return %0 : tensor<4xf32> } // ----- -mesh.mesh @mesh0(shape = 1) +shard.grid @grid0(shape = 1) func.func @all_gather_invalid_non_gather_axis_dimension_size( %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}} - %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0 : tensor<3x4xf32> -> tensor<3x5xf32> return %0 : tensor<3x5xf32> } // ----- -mesh.mesh @mesh0(shape = 1x2) +shard.grid @grid0(shape = 1x2) func.func @all_gather_invalid_gather_axis_dimension_size( %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 1. 
Expected 8, but got 5.}} - %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1 + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [1] gather_axis = 1 : tensor<3x4xf32> -> tensor<3x5xf32> return %0 : tensor<3x5xf32> } // ----- -mesh.mesh @mesh0(shape = 1) +shard.grid @grid0(shape = 1) func.func @all_gather_invalid_gather_axis_dynamic_dimension( %arg0 : tensor<?xf32>) -> tensor<3xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}} - %0 = mesh.all_gather %arg0 on @mesh0 gather_axis = 0 + %0 = shard.all_gather %arg0 on @grid0 gather_axis = 0 : tensor<?xf32> -> tensor<3xf32> return %0 : tensor<3xf32> } // ----- -mesh.mesh @mesh0(shape = 1) +shard.grid @grid0(shape = 1) func.func @all_gather_invalid_gather_axis( %arg0 : tensor<3xf32>) -> tensor<3xf32> { // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}} - %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1 + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [0] gather_axis = 1 : tensor<3xf32> -> tensor<3xf32> return %0 : tensor<3xf32> } // ----- -mesh.mesh @mesh0(shape = 1) +shard.grid @grid0(shape = 1) func.func @all_gather_invalid_negative_gather_axis( %arg0 : tensor<3xf32>) -> tensor<3xf32> { // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}} - %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1 + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [0] gather_axis = -1 : tensor<3xf32> -> tensor<3xf32> return %0 : tensor<3xf32> } // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) -func.func @all_slice_duplicate_mesh_axis( +func.func @all_slice_duplicate_grid_axis( %arg0 : tensor<?xf32>) -> tensor<?xf32> { - // expected-error@+1 {{Mesh axes contains duplicate elements.}} - %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0, 0] + // expected-error@+1 {{Grid axes contains duplicate elements.}} + %0 = shard.all_slice %arg0 on @grid0 grid_axes = [0, 0] slice_axis = 0 : tensor<?xf32> -> tensor<?xf32> return %0 : tensor<?xf32> @@
-359,12 +359,12 @@ func.func @all_slice_duplicate_mesh_axis( // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @all_slice_invalid_dynamic_dimension( %arg0 : tensor<?xf32>) -> tensor<2xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}} - %0 = mesh.all_slice %arg0 on @mesh0 + %0 = shard.all_slice %arg0 on @grid0 slice_axis = 0 : tensor<?xf32> -> tensor<2xf32> return %0 : tensor<2xf32> @@ -372,12 +372,12 @@ func.func @all_slice_invalid_dynamic_dimension( // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @all_slice_invalid_static_dimension_size( %arg0 : tensor<3xf32>) -> tensor<2xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}} - %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.all_slice %arg0 on @grid0 grid_axes = [0] slice_axis = 0 : tensor<3xf32> -> tensor<2xf32> return %0 : tensor<2xf32> @@ -385,12 +385,12 @@ func.func @all_slice_invalid_static_dimension_size( // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @all_slice_invalid_operand_static_dimension_size( %arg0 : tensor<4xf32>) -> tensor<?xf32> { // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}} - %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.all_slice %arg0 on @grid0 grid_axes = [0] slice_axis = 0 : tensor<4xf32> -> tensor<?xf32> return %0 : tensor<?xf32> @@ -398,10 +398,10 @@ func.func @all_slice_invalid_operand_static_dimension_size( // ----- -func.func @all_to_all_invalid_mesh_symbol( +func.func @all_to_all_invalid_grid_symbol( %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> { - // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}} - %0 = mesh.all_to_all %arg0 on @this_mesh_symbol_does_not_exist + // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}} + %0 = shard.all_to_all %arg0 on
@this_grid_symbol_does_not_exist split_axis = 1 concat_axis = 0 : tensor<3x6xi8> -> tensor<3x6xi8> return %0 : tensor<3x6xi8> @@ -409,12 +409,12 @@ func.func @all_to_all_invalid_mesh_symbol( // ----- -mesh.mesh @mesh0(shape = 1) +shard.grid @grid0(shape = 1) -func.func @all_to_all_duplicate_mesh_axis( +func.func @all_to_all_duplicate_grid_axis( %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> { - // expected-error@+1 {{Mesh axes contains duplicate elements.}} - %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 0] + // expected-error@+1 {{Grid axes contains duplicate elements.}} + %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0, 0] split_axis = 0 concat_axis = 0 : tensor<3x6xi8> -> tensor<3x6xi8> return %0 : tensor<3x6xi8> @@ -422,12 +422,12 @@ func.func @all_to_all_duplicate_mesh_axis( // ----- -mesh.mesh @mesh0(shape = ?x1) +shard.grid @grid0(shape = ?x1) func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group( %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> { // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 6.}} - %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0] split_axis = 0 concat_axis = 1 : tensor<3x6xi8> -> tensor<3x6xi8> return %0 : tensor<3x6xi8> @@ -435,12 +435,12 @@ func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_de // ----- -mesh.mesh @mesh0(shape = 1x1) +shard.grid @grid0(shape = 1x1) func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension( %arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> { // expected-error@+1 {{Dimension size mismatch for result axis 0.
Expected dynamic, but got 3.}} - %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1] + %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [1] split_axis = 0 concat_axis = 1 : tensor<?x6xi8> -> tensor<3x?xi8> return %0 : tensor<3x?xi8> @@ -448,12 +448,12 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dyna // ----- -mesh.mesh @mesh0(shape = 1x1) +shard.grid @grid0(shape = 1x1) func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension( %arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> { // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 3.}} - %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1] + %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [1] split_axis = 0 concat_axis = 1 : tensor<3x?xi8> -> tensor<?x3xi8> return %0 : tensor<?x3xi8> @@ -461,12 +461,12 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dyn // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size( %arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> { // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 6, but got 7.}} - %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0] split_axis = 0 concat_axis = 1 : tensor<3x2xi8> -> tensor<1x7xi8> return %0 : tensor<1x7xi8> @@ -474,12 +474,12 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size( // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size( %arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> { // expected-error@+1 {{Dimension size mismatch for result axis 0.
Expected 1, but got 2.}} - %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0] split_axis = 0 concat_axis = 1 : tensor<3x2xi8> -> tensor<2x6xi8> return %0 : tensor<2x6xi8> @@ -487,12 +487,12 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @broadcast_root_dimension_out_of_bounds( %arg0 : tensor<2xi8>) -> tensor<2xi8> { // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}} - %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0] root = [3] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -500,12 +500,12 @@ func.func @broadcast_root_dimension_out_of_bounds( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @broadcast_root_wrong_number_dimensions( %arg0 : tensor<2xi8>) -> tensor<2xi8> { // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}} - %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0] root = [2, 2] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -513,12 +513,12 @@ func.func @broadcast_root_wrong_number_dimensions( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) 
func.func @broadcast_different_input_and_result_type( %arg0 : tensor<2xi8>) -> tensor<2xi16> { - // expected-error@+1 {{'mesh.broadcast' op failed to verify that all of {input, result} have same element type}} - %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0] + // expected-error@+1 {{'shard.broadcast' op failed to verify that all of {input, result} have same element type}} + %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0] root = [2] : (tensor<2xi8>) -> tensor<2xi16> return %0 : tensor<2xi16> @@ -526,84 +526,84 @@ func.func @broadcast_different_input_and_result_type( // ----- -mesh.mesh @mesh0(shape = 1) +shard.grid @grid0(shape = 1) func.func @gather_wrong_return_element_type( %arg0 : tensor<1xf32>) -> tensor<1xi8> { - // expected-error@+1 {{'mesh.gather' op failed to verify that all of {input, result} have same element type}} - %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0] + // expected-error@+1 {{'shard.gather' op failed to verify that all of {input, result} have same element type}} + %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0 root = [0] : (tensor<1xf32>) -> tensor<1xi8> return %0 : tensor<1xi8> } // ----- -mesh.mesh @mesh0(shape = 1) +shard.grid @grid0(shape = 1) func.func @gather_invalid_non_gather_axis_dimension_size( %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}} - %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0] + %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0 root = [0] : (tensor<3x4xf32>) -> tensor<3x5xf32> return %0 : tensor<3x5xf32> } // ----- -mesh.mesh @mesh0(shape = 1x2) +shard.grid @grid0(shape = 1x2) func.func @gather_invalid_gather_axis_dimension_size( %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 1. 
Expected 8, but got 5.}} - %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1 root = [0] + %0 = shard.gather %arg0 on @grid0 grid_axes = [1] gather_axis = 1 root = [0] : (tensor<3x4xf32>) -> tensor<3x5xf32> return %0 : tensor<3x5xf32> } // ----- -mesh.mesh @mesh0(shape = 1) +shard.grid @grid0(shape = 1) func.func @gather_invalid_gather_axis_dynamic_dimension( %arg0 : tensor<?xf32>) -> tensor<3xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}} - %0 = mesh.gather %arg0 on @mesh0 gather_axis = 0 root = [] + %0 = shard.gather %arg0 on @grid0 gather_axis = 0 root = [] : (tensor<?xf32>) -> tensor<3xf32> return %0 : tensor<3xf32> } // ----- -mesh.mesh @mesh0(shape = 1) +shard.grid @grid0(shape = 1) func.func @gather_invalid_gather_axis( %arg0 : tensor<3xf32>) -> tensor<3xf32> { // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}} - %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1 root = [0] + %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 1 root = [0] : (tensor<3xf32>) -> tensor<3xf32> return %0 : tensor<3xf32> } // ----- -mesh.mesh @mesh0(shape = 1) +shard.grid @grid0(shape = 1) func.func @gather_invalid_negative_gather_axis( %arg0 : tensor<3xf32>) -> tensor<3xf32> { // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}} - %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1 root = [0] + %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = -1 root = [0] : (tensor<3xf32>) -> tensor<3xf32> return %0 : tensor<3xf32> } // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @gather_root_dimension_out_of_bounds( %arg0 : tensor<2xi8>) -> tensor<6xi8> { // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root".
Got 3, but expected value in the range [0, 2].}} - %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 + %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0 root = [3] : (tensor<2xi8>) -> tensor<6xi8> return %0 : tensor<6xi8> @@ -611,12 +611,12 @@ func.func @gather_root_dimension_out_of_bounds( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @gather_root_wrong_number_dimensions( %arg0 : tensor<2xi8>) -> tensor<2xi8> { // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}} - %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 + %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0 root = [2, 2] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -624,12 +624,12 @@ func.func @gather_root_wrong_number_dimensions( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @receive_source_dimension_out_of_bounds( %arg0 : tensor<2xi8>) -> tensor<2xi8> { // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "source". Got 3, but expected value in the range [0, 2].}} - %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.recv %arg0 on @grid0 grid_axes = [0] source = [3] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -637,12 +637,12 @@ func.func @receive_source_dimension_out_of_bounds( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @receive_source_wrong_number_dimensions( %arg0 : tensor<2xi8>) -> tensor<2xi8> { // expected-error@+1 {{In-group device "source" has unexpected multi-index size 2. Expected 1.}} - %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.recv %arg0 on @grid0 grid_axes = [0] source = [2, 2] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -650,12 +650,12 @@ func.func @receive_source_wrong_number_dimensions( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) 
func.func @receive_different_input_and_result_type( %arg0 : tensor<2xi8>) -> tensor<2xi16> { - // expected-error@+1 {{'mesh.recv' op failed to verify that all of {input, result} have same element type}} - %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0] + // expected-error@+1 {{'shard.recv' op failed to verify that all of {input, result} have same element type}} + %0 = shard.recv %arg0 on @grid0 grid_axes = [0] source = [2] : (tensor<2xi8>) -> tensor<2xi16> return %0 : tensor<2xi16> @@ -663,12 +663,12 @@ func.func @receive_different_input_and_result_type( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @reduce_root_dimension_out_of_bounds( %arg0 : tensor<2xi8>) -> tensor<2xi8> { // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}} - %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.reduce %arg0 on @grid0 grid_axes = [0] root = [3] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -676,12 +676,12 @@ func.func @reduce_root_dimension_out_of_bounds( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @reduce_root_wrong_number_dimensions( %arg0 : tensor<2xi8>) -> tensor<2xi8> { // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}} - %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.reduce %arg0 on @grid0 grid_axes = [0] root = [2, 2] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -689,12 +689,12 @@ func.func @reduce_root_wrong_number_dimensions( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) 
func.func @reduce_different_input_and_result_shape( %arg0 : tensor<2xi8>) -> tensor<3xi16> { - // expected-error@+1 {{'mesh.reduce' op failed to verify that all of {input, result} have same shape}} - %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0] + // expected-error@+1 {{'shard.reduce' op failed to verify that all of {input, result} have same shape}} + %0 = shard.reduce %arg0 on @grid0 grid_axes = [0] root = [2] : (tensor<2xi8>) -> tensor<3xi16> return %0 : tensor<3xi16> @@ -702,60 +702,60 @@ func.func @reduce_different_input_and_result_shape( // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) -func.func @reduce_scatter_duplicate_mesh_axis( +func.func @reduce_scatter_duplicate_grid_axis( %arg0 : tensor<?xf32>) -> tensor<?xf64> { - // expected-error@+1 {{Mesh axes contains duplicate elements.}} - %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0, 0] scatter_axis = 0 + // expected-error@+1 {{Grid axes contains duplicate elements.}} + %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0, 0] scatter_axis = 0 : tensor<?xf32> -> tensor<?xf64> return %0 : tensor<?xf64> } // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @reduce_scatter_invalid_dynamic_dimension( %arg0 : tensor<?xf32>) -> tensor<2xf64> { // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}} - %0 = mesh.reduce_scatter %arg0 on @mesh0 scatter_axis = 0 + %0 = shard.reduce_scatter %arg0 on @grid0 scatter_axis = 0 : tensor<?xf32> -> tensor<2xf64> return %0 : tensor<2xf64> } // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @reduce_scatter_invalid_static_dimension_size( %arg0 : tensor<3xf32>) -> tensor<2xf64> { // expected-error@+1 {{Dimension size mismatch for result axis 0.
Expected 1, but got 2.}} - %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0 + %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0 : tensor<3xf32> -> tensor<2xf64> return %0 : tensor<2xf64> } // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @reduce_scatter_invalid_operand_static_dimension_size( %arg0 : tensor<4xf32>) -> tensor<?xf64> { // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}} - %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0 + %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0 : tensor<4xf32> -> tensor<?xf64> return %0 : tensor<?xf64> } // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) -func.func @scatter_duplicate_mesh_axis( +func.func @scatter_duplicate_grid_axis( %arg0 : tensor<?xf32>) -> tensor<?xf32> { - // expected-error@+1 {{Mesh axes contains duplicate elements.}} - %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0, 0] + // expected-error@+1 {{Grid axes contains duplicate elements.}} + %0 = shard.scatter %arg0 on @grid0 grid_axes = [0, 0] scatter_axis = 0 root = [0, 0] : (tensor<?xf32>) -> tensor<?xf32> return %0 : tensor<?xf32> @@ -763,12 +763,12 @@ func.func @scatter_duplicate_mesh_axis( // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @scatter_invalid_dynamic_dimension( %arg0 : tensor<?xf32>) -> tensor<2xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}} - %0 = mesh.scatter %arg0 on @mesh0 + %0 = shard.scatter %arg0 on @grid0 scatter_axis = 0 root = [] : (tensor<?xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> @@ -776,12 +776,12 @@ func.func @scatter_invalid_dynamic_dimension( // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @scatter_invalid_static_dimension_size( %arg0 : tensor<3xf32>) -> tensor<2xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 0.
Expected 1, but got 2.}} - %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0 root = [1] : (tensor<3xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> @@ -789,12 +789,12 @@ func.func @scatter_invalid_static_dimension_size( // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @scatter_invalid_operand_static_dimension_size( %arg0 : tensor<4xf32>) -> tensor<?xf32> { // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}} - %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0 root = [1] : (tensor<4xf32>) -> tensor<?xf32> return %0 : tensor<?xf32> @@ -802,12 +802,12 @@ func.func @scatter_invalid_operand_static_dimension_size( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @scatter_root_dimension_out_of_bounds( %arg0 : tensor<3xi8>) -> tensor<1xi8> { // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}} - %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0 root = [3] : (tensor<3xi8>) -> tensor<1xi8> return %0 : tensor<1xi8> @@ -815,12 +815,12 @@ func.func @scatter_root_dimension_out_of_bounds( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @scatter_root_wrong_number_dimensions( %arg0 : tensor<3xi8>) -> tensor<1xi8> { // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}} - %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0 root = [2, 2] : (tensor<3xi8>) -> tensor<1xi8> return %0 : tensor<1xi8> @@ -828,12 +828,12 @@ func.func @scatter_root_wrong_number_dimensions( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?)
func.func @send_destination_dimension_out_of_bounds( %arg0 : tensor<2xi8>) -> tensor<2xi8> { // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "destination". Got 3, but expected value in the range [0, 2].}} - %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.send %arg0 on @grid0 grid_axes = [0] destination = [3] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -841,12 +841,12 @@ func.func @send_destination_dimension_out_of_bounds( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @send_destination_wrong_number_dimensions( %arg0 : tensor<2xi8>) -> tensor<2xi8> { // expected-error@+1 {{In-group device "destination" has unexpected multi-index size 2. Expected 1.}} - %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.send %arg0 on @grid0 grid_axes = [0] destination = [2, 2] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -854,12 +854,12 @@ func.func @send_destination_wrong_number_dimensions( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) 
func.func @send_different_input_and_result_type( %arg0 : tensor<2xi8>) -> tensor<2xi16> { - // expected-error@+1 {{'mesh.send' op failed to verify that all of {input, result} have same element type}} - %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0] + // expected-error@+1 {{'shard.send' op failed to verify that all of {input, result} have same element type}} + %0 = shard.send %arg0 on @grid0 grid_axes = [0] destination = [2] : (tensor<2xi8>) -> tensor<2xi16> return %0 : tensor<2xi16> @@ -867,10 +867,10 @@ func.func @send_different_input_and_result_type( // ----- -func.func @shift_invalid_mesh_symbol( +func.func @shift_invalid_grid_symbol( %arg0 : tensor<4xi8>) -> tensor<4xi8> { - // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}} - %0 = mesh.shift %arg0 on @this_mesh_symbol_does_not_exist + // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}} + %0 = shard.shift %arg0 on @this_grid_symbol_does_not_exist shift_axis = 0 offset = -2 : tensor<4xi8> -> tensor<4xi8> return %0 : tensor<4xi8> @@ -878,12 +878,12 @@ func.func @shift_invalid_mesh_symbol( // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @shift_invalid_mesh_axis( +func.func @shift_invalid_grid_axis( %arg0 : tensor<4xi8>) -> tensor<4xi8> { - // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}} - %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [2] + // expected-error@+1 {{0-based grid axis index 2 is out of bounds. 
The referenced grid "grid0" is of rank 2.}} + %0 = shard.shift %arg0 on @grid0 grid_axes = [2] shift_axis = 2 offset = -2 : tensor<4xi8> -> tensor<4xi8> return %0 : tensor<4xi8> @@ -891,12 +891,12 @@ func.func @shift_invalid_mesh_axis( // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @shift_duplicate_mesh_axis( +func.func @shift_duplicate_grid_axis( %arg0 : tensor<4xi8>) -> tensor<4xi8> { - // expected-error@+1 {{Mesh axes contains duplicate elements.}} - %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0, 1, 0] + // expected-error@+1 {{Grid axes contains duplicate elements.}} + %0 = shard.shift %arg0 on @grid0 grid_axes = [0, 1, 0] shift_axis = 0 offset = -2 : tensor<4xi8> -> tensor<4xi8> return %0 : tensor<4xi8> @@ -904,12 +904,12 @@ func.func @shift_duplicate_mesh_axis( // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) func.func @shift_invalid_tensor_dimension_size( %arg0 : tensor<4xi8>) -> tensor<5xi8> { - // expected-error@+1 {{'mesh.shift' op requires the same shape for all operands and results}} - %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0] + // expected-error@+1 {{'shard.shift' op requires the same shape for all operands and results}} + %0 = shard.shift %arg0 on @grid0 grid_axes = [0] shift_axis = 0 offset = 2 : tensor<4xi8> -> tensor<5xi8> return %0 : tensor<5xi8> @@ -917,12 +917,12 @@ func.func @shift_invalid_tensor_dimension_size( // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) func.func @shift_invalid_shift_axis( %arg0 : tensor<4xi8>) -> tensor<4xi8> { - // expected-error@+1 {{Invalid shift axis 1. It must be one of the grouping mesh axes.}} - %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0] + // expected-error@+1 {{Invalid shift axis 1. 
It must be one of the grouping grid axes.}} + %0 = shard.shift %arg0 on @grid0 grid_axes = [0] shift_axis = 1 offset = 2 : tensor<4xi8> -> tensor<4xi8> return %0 : tensor<4xi8> diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Shard/ops.mlir similarity index 55% rename from mlir/test/Dialect/Mesh/ops.mlir rename to mlir/test/Dialect/Shard/ops.mlir index c354de514fba8..5265dadd2a845 100644 --- a/mlir/test/Dialect/Mesh/ops.mlir +++ b/mlir/test/Dialect/Shard/ops.mlir @@ -1,176 +1,176 @@ // RUN: mlir-opt %s | mlir-opt | FileCheck %s -// CHECK: mesh.mesh @mesh0 -mesh.mesh @mesh0(shape = 2x2x4) +// CHECK: shard.grid @grid0 +shard.grid @grid0(shape = 2x2x4) -// CHECK: mesh.mesh @mesh1(shape = 4x?) -mesh.mesh @mesh1(shape = 4x?) +// CHECK: shard.grid @grid1(shape = 4x?) +shard.grid @grid1(shape = 4x?) -// CHECK: mesh.mesh @mesh2(shape = ?x4) -mesh.mesh @mesh2(shape = ?x4) +// CHECK: shard.grid @grid2(shape = ?x4) +shard.grid @grid2(shape = ?x4) -// CHECK: mesh.mesh @mesh3(shape = ?x?) -mesh.mesh @mesh3(shape = ?x?) +// CHECK: shard.grid @grid3(shape = ?x?) +shard.grid @grid3(shape = ?x?) -mesh.mesh @mesh4(shape = 3) +shard.grid @grid4(shape = 3) -// CHECK: mesh.mesh @mesh5(shape = ?) -mesh.mesh @mesh5(shape = ?) +// CHECK: shard.grid @grid5(shape = ?) +shard.grid @grid5(shape = ?) 
-// CHECK-LABEL: func @mesh_shard_op_fully_replicated +// CHECK-LABEL: func @grid_shard_op_fully_replicated // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> -func.func @mesh_shard_op_fully_replicated(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { - // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding - %s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding - // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32> - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> +func.func @grid_shard_op_fully_replicated(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { + // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid0 split_axes = {{\[\[}}]] : !shard.sharding + %s = shard.sharding @grid0 split_axes = [[]] : !shard.sharding + // CHECK-NEXT: shard.shard %[[ARG]] to %[[S]] : tensor<4x8xf32> + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return %0 : tensor<4x8xf32> } -// CHECK-LABEL: func @mesh_shard_op_1st_dim +// CHECK-LABEL: func @grid_shard_op_1st_dim // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> -func.func @mesh_shard_op_1st_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { - // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}0]] : !mesh.sharding - %s = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding +func.func @grid_shard_op_1st_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { + // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid0 split_axes = {{\[\[}}0]] : !shard.sharding + %s = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return %0 : tensor<4x8xf32> } -// CHECK-LABEL: func @mesh_shard_op_2nd_dim +// CHECK-LABEL: func @grid_shard_op_2nd_dim // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> -func.func @mesh_shard_op_2nd_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { - // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh1 split_axes = {{\[\[}}], [0]] : !mesh.sharding - %s = mesh.sharding @mesh1 split_axes = [[], [0]] : 
!mesh.sharding - // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32> - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> +func.func @grid_shard_op_2nd_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { + // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid1 split_axes = {{\[\[}}], [0]] : !shard.sharding + %s = shard.sharding @grid1 split_axes = [[], [0]] : !shard.sharding + // CHECK-NEXT: shard.shard %[[ARG]] to %[[S]] : tensor<4x8xf32> + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return %0 : tensor<4x8xf32> } -// CHECK-LABEL: func @mesh_shard_op_1st_and_3rd_dim -func.func @mesh_shard_op_1st_and_3rd_dim( +// CHECK-LABEL: func @grid_shard_op_1st_and_3rd_dim +func.func @grid_shard_op_1st_and_3rd_dim( // CHECK-SAME: %[[ARG:.*]]: tensor<4x8x16xf32> %arg0 : tensor<4x8x16xf32>) -> tensor<4x8x16xf32> { - // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0], [], [1]] : !mesh.sharding - %s = mesh.sharding @mesh3 split_axes = [[0], [], [1]] : !mesh.sharding - // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8x16xf32> - %0 = mesh.shard %arg0 to %s : tensor<4x8x16xf32> + // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid3 split_axes = {{\[\[}}0], [], [1]] : !shard.sharding + %s = shard.sharding @grid3 split_axes = [[0], [], [1]] : !shard.sharding + // CHECK-NEXT: shard.shard %[[ARG]] to %[[S]] : tensor<4x8x16xf32> + %0 = shard.shard %arg0 to %s : tensor<4x8x16xf32> return %0 : tensor<4x8x16xf32> } -// CHECK-LABEL: func @mesh_shard_op_two_users +// CHECK-LABEL: func @grid_shard_op_two_users // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> -func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) -> +func.func @grid_shard_op_two_users(%arg0 : tensor<4x8xf32>) -> (tensor<4x8xf32>, tensor<4x8xf32>) { - // CHECK-NEXT: %[[V0:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}0]] : !mesh.sharding - %s0 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<4x8xf32> - // CHECK-DAG: mesh.sharding @mesh0 split_axes = 
{{\[\[}}1]] : !mesh.sharding - %s1 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<4x8xf32> - // CHECK-DAG: mesh.sharding @mesh0 split_axes = {{\[\[}}2]] : !mesh.sharding - %s2 = mesh.sharding @mesh0 split_axes = [[2]] : !mesh.sharding - %2 = mesh.shard %0 to %s2 annotate_for_users : tensor<4x8xf32> + // CHECK-NEXT: %[[V0:.*]] = shard.sharding @grid0 split_axes = {{\[\[}}0]] : !shard.sharding + %s0 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<4x8xf32> + // CHECK-DAG: shard.sharding @grid0 split_axes = {{\[\[}}1]] : !shard.sharding + %s1 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<4x8xf32> + // CHECK-DAG: shard.sharding @grid0 split_axes = {{\[\[}}2]] : !shard.sharding + %s2 = shard.sharding @grid0 split_axes = [[2]] : !shard.sharding + %2 = shard.shard %0 to %s2 annotate_for_users : tensor<4x8xf32> return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32> } -// CHECK-LABEL: func @mesh_shard_halo_sizes -func.func @mesh_shard_halo_sizes() -> () { +// CHECK-LABEL: func @grid_shard_halo_sizes +func.func @grid_shard_halo_sizes() -> () { // CHECK: %[[C3:.*]] = arith.constant 3 : i64 %c3 = arith.constant 3 : i64 - // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] halo_sizes = [1, 4] : !mesh.sharding - %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] halo_sizes = [1, 4] : !mesh.sharding - // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] halo_sizes = [4, %[[C3]]] : !mesh.sharding - %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] halo_sizes = [4, %c3] : !mesh.sharding + // CHECK: shard.sharding @grid4 split_axes = {{\[\[}}0]] halo_sizes = [1, 4] : !shard.sharding + %sharding1 = shard.sharding @grid4 split_axes = [[0]] halo_sizes = [1, 4] : !shard.sharding + // CHECK: shard.sharding @grid4 split_axes = {{\[\[}}0]] halo_sizes = [4, %[[C3]]] : !shard.sharding + 
%sharding2 = shard.sharding @grid4 split_axes = [[0]] halo_sizes = [4, %c3] : !shard.sharding return } -// CHECK-LABEL: func @mesh_shard_dims_sizes -func.func @mesh_shard_dims_sizes() -> () { +// CHECK-LABEL: func @grid_shard_dims_sizes +func.func @grid_shard_dims_sizes() -> () { // CHECK: %[[C3:.*]] = arith.constant 3 : i64 %c3 = arith.constant 3 : i64 - // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 6] : !mesh.sharding - %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 6] : !mesh.sharding - // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 2, %[[C3]], 5] : !mesh.sharding - %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [0, 2, %c3, 5] : !mesh.sharding + // CHECK: shard.sharding @grid4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 6] : !shard.sharding + %sharding1 = shard.sharding @grid4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 6] : !shard.sharding + // CHECK: shard.sharding @grid4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 2, %[[C3]], 5] : !shard.sharding + %sharding2 = shard.sharding @grid4 split_axes = [[0]] sharded_dims_offsets = [0, 2, %c3, 5] : !shard.sharding return } -// CHECK-LABEL: func @mesh_shard_shape -func.func @mesh_shard_shape() { +// CHECK-LABEL: func @grid_shard_shape +func.func @grid_shard_shape() { // CHECK: %[[C3:.*]] = arith.constant 3 : index %c3 = arith.constant 3 : index - // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding - %s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding - // CHECK-NEXT: mesh.shard_shape dims = [8, %[[C3]] + // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid0 split_axes = {{\[\[}}]] : !shard.sharding + %s = shard.sharding @grid0 split_axes = [[]] : !shard.sharding + // CHECK-NEXT: shard.shard_shape dims = [8, %[[C3]] // CHECK-SAME: ] sharding = %[[S]] device = [%[[C3]] // CHECK-SAME: ] : index, index - 
%shp:2 = mesh.shard_shape dims = [8, %c3] sharding = %s device = [%c3] : index, index - // CHECK-NEXT: mesh.shard_shape dims = [8, 4] sharding = %[[S]] device = [3] : index, index - %shp1:2 = mesh.shard_shape dims = [8, 4] sharding = %s device = [3] : index, index + %shp:2 = shard.shard_shape dims = [8, %c3] sharding = %s device = [%c3] : index, index + // CHECK-NEXT: shard.shard_shape dims = [8, 4] sharding = %[[S]] device = [3] : index, index + %shp1:2 = shard.shard_shape dims = [8, 4] sharding = %s device = [3] : index, index return } -// CHECK-LABEL: func @mesh_get_sharding +// CHECK-LABEL: func @grid_get_sharding // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> -func.func @mesh_get_sharding(%arg0 : tensor<4x8xf32>) -> !mesh.sharding { - // CHECK-NEXT: mesh.get_sharding %[[ARG]] : tensor<4x8xf32> -> !mesh.sharding - %0 = mesh.get_sharding %arg0 : tensor<4x8xf32> -> !mesh.sharding - return %0 : !mesh.sharding +func.func @grid_get_sharding(%arg0 : tensor<4x8xf32>) -> !shard.sharding { + // CHECK-NEXT: shard.get_sharding %[[ARG]] : tensor<4x8xf32> -> !shard.sharding + %0 = shard.get_sharding %arg0 : tensor<4x8xf32> -> !shard.sharding + return %0 : !shard.sharding } -// CHECK-LABEL: func @mesh_shape -func.func @mesh_shape() -> (index, index) { - // CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index - %0:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index +// CHECK-LABEL: func @grid_shape +func.func @grid_shape() -> (index, index) { + // CHECK: %[[RES:.*]]:2 = shard.grid_shape @grid0 axes = [0, 1] : index, index + %0:2 = shard.grid_shape @grid0 axes = [0, 1] : index, index // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index return %0#0, %0#1 : index, index } -// CHECK-LABEL: func @mesh_shape_default_axes -func.func @mesh_shape_default_axes() -> (index, index, index) { - // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index - %0:3 = mesh.mesh_shape @mesh0 : index, index, index +// CHECK-LABEL: func @grid_shape_default_axes 
+func.func @grid_shape_default_axes() -> (index, index, index) { + // CHECK: %[[RES:.*]]:3 = shard.grid_shape @grid0 : index, index, index + %0:3 = shard.grid_shape @grid0 : index, index, index // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index return %0#0, %0#1, %0#2 : index, index, index } -// CHECK-LABEL: func @mesh_shape_empty_axes -func.func @mesh_shape_empty_axes() -> (index, index, index) { - // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index - %0:3 = mesh.mesh_shape @mesh0 axes = [] : index, index, index +// CHECK-LABEL: func @grid_shape_empty_axes +func.func @grid_shape_empty_axes() -> (index, index, index) { + // CHECK: %[[RES:.*]]:3 = shard.grid_shape @grid0 : index, index, index + %0:3 = shard.grid_shape @grid0 axes = [] : index, index, index // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index return %0#0, %0#1, %0#2 : index, index, index } // CHECK-LABEL: func @process_multi_index func.func @process_multi_index() -> (index, index) { - // CHECK: %[[RES:.*]]:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index - %0:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index + // CHECK: %[[RES:.*]]:2 = shard.process_multi_index on @grid0 axes = [0, 1] : index, index + %0:2 = shard.process_multi_index on @grid0 axes = [0, 1] : index, index // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index return %0#0, %0#1 : index, index } // CHECK-LABEL: func @process_multi_index_default_axes func.func @process_multi_index_default_axes() -> (index, index, index) { - // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index - %0:3 = mesh.process_multi_index on @mesh0 : index, index, index + // CHECK: %[[RES:.*]]:3 = shard.process_multi_index on @grid0 : index, index, index + %0:3 = shard.process_multi_index on @grid0 : index, index, index // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index return %0#0, %0#1, %0#2 : index, index, index 
} // CHECK-LABEL: func @process_multi_index_empty_axes func.func @process_multi_index_empty_axes() -> (index, index, index) { - // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index - %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index + // CHECK: %[[RES:.*]]:3 = shard.process_multi_index on @grid0 : index, index, index + %0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index return %0#0, %0#1, %0#2 : index, index, index } // CHECK-LABEL: func @process_linear_index func.func @process_linear_index() -> index { - // CHECK: %[[RES:.*]] = mesh.process_linear_index on @mesh0 : index - %0 = mesh.process_linear_index on @mesh0 : index + // CHECK: %[[RES:.*]] = shard.process_linear_index on @grid0 : index + %0 = shard.process_linear_index on @grid0 : index // CHECK: return %[[RES]] : index return %0 : index } @@ -179,9 +179,9 @@ func.func @process_linear_index() -> index { func.func @all_reduce( // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32> %arg0 : tensor<3x4xf32>) -> tensor<3x4xf64> { - // CHECK-NEXT: mesh.all_reduce %[[ARG]] on @mesh0 mesh_axes = [1, 0] reduction = max + // CHECK-NEXT: shard.all_reduce %[[ARG]] on @grid0 grid_axes = [1, 0] reduction = max // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x4xf64> - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [1, 0] reduction = max + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [1, 0] reduction = max : tensor<3x4xf32> -> tensor<3x4xf64> return %0 : tensor<3x4xf64> } @@ -190,9 +190,9 @@ func.func @all_reduce( func.func @all_gather( // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32> %arg0 : tensor<3x4xf32>) -> tensor<3x16xf32> { - // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1 + // CHECK-NEXT: shard.all_gather %[[ARG]] on @grid0 grid_axes = [2] gather_axis = 1 // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x16xf32> - %0 = mesh.all_gather %arg0 on @mesh0 
mesh_axes = [2] gather_axis = 1 + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : tensor<3x4xf32> -> tensor<3x16xf32> return %0 : tensor<3x16xf32> } @@ -201,20 +201,20 @@ func.func @all_gather( func.func @all_gather_dynamic_dims_in_tensor( // CHECK-SAME: %[[ARG:.*]]: tensor %arg0 : tensor) -> tensor { - // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1 + // CHECK-NEXT: shard.all_gather %[[ARG]] on @grid0 grid_axes = [2] gather_axis = 1 // CHECK-SAME: : tensor -> tensor - %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1 + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : tensor -> tensor return %0 : tensor } -// CHECK-LABEL: func @all_gather_dynamic_dims_in_mesh -func.func @all_gather_dynamic_dims_in_mesh( +// CHECK-LABEL: func @all_gather_dynamic_dims_in_grid +func.func @all_gather_dynamic_dims_in_grid( // CHECK-SAME: %[[ARG:.*]]: tensor<5x6xf32> %arg0 : tensor<5x6xf32>) -> tensor<5x?xf32> { - // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh3 mesh_axes = [1] gather_axis = 1 + // CHECK-NEXT: shard.all_gather %[[ARG]] on @grid3 grid_axes = [1] gather_axis = 1 // CHECK-SAME: : tensor<5x6xf32> -> tensor<5x?xf32> - %0 = mesh.all_gather %arg0 on @mesh3 mesh_axes = [1] gather_axis = 1 + %0 = shard.all_gather %arg0 on @grid3 grid_axes = [1] gather_axis = 1 : tensor<5x6xf32> -> tensor<5x?xf32> return %0 : tensor<5x?xf32> } @@ -223,10 +223,10 @@ func.func @all_gather_dynamic_dims_in_mesh( func.func @all_slice_static_dimensions( // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32> %arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> { - // CHECK-NEXT: mesh.all_slice %[[ARG]] - // CHECK-SAME: on @mesh0 mesh_axes = [2] slice_axis = 1 + // CHECK-NEXT: shard.all_slice %[[ARG]] + // CHECK-SAME: on @grid0 grid_axes = [2] slice_axis = 1 // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf32> - %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [2] slice_axis = 1 + %0 = shard.all_slice %arg0 on @grid0 grid_axes 
= [2] slice_axis = 1 : tensor<3x4xf32> -> tensor<3x1xf32> return %0 : tensor<3x1xf32> } @@ -235,10 +235,10 @@ func.func @all_slice_static_dimensions( func.func @all_slice_dynamic_dimensions( // CHECK-SAME: %[[ARG:.*]]: tensor %arg0 : tensor) -> tensor { - // CHECK-NEXT: mesh.all_slice %[[ARG]] - // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] slice_axis = 0 + // CHECK-NEXT: shard.all_slice %[[ARG]] + // CHECK-SAME: on @grid3 grid_axes = [0, 1] slice_axis = 0 // CHECK-SAME: : tensor -> tensor - %0 = mesh.all_slice %arg0 on @mesh3 mesh_axes = [0, 1] slice_axis = 0 + %0 = shard.all_slice %arg0 on @grid3 grid_axes = [0, 1] slice_axis = 0 : tensor -> tensor return %0 : tensor } @@ -247,10 +247,10 @@ func.func @all_slice_dynamic_dimensions( func.func @all_to_all( // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8> %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> { - // CHECK-NEXT: mesh.all_to_all %[[ARG]] - // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0 + // CHECK-NEXT: shard.all_to_all %[[ARG]] + // CHECK-SAME: on @grid4 split_axis = 1 concat_axis = 0 // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x6xi8> - %0 = mesh.all_to_all %arg0 on @mesh4 + %0 = shard.all_to_all %arg0 on @grid4 split_axis = 1 concat_axis = 0 : tensor<3x6xi8> -> tensor<3x6xi8> return %0 : tensor<3x6xi8> @@ -260,10 +260,10 @@ func.func @all_to_all( func.func @all_to_all_dynamic_dims_in_result( // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8> %arg0 : tensor<3x6xi8>) -> tensor<3x?xi8> { - // CHECK-NEXT: mesh.all_to_all %[[ARG]] - // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0 + // CHECK-NEXT: shard.all_to_all %[[ARG]] + // CHECK-SAME: on @grid4 split_axis = 1 concat_axis = 0 // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x?xi8> - %0 = mesh.all_to_all %arg0 on @mesh4 + %0 = shard.all_to_all %arg0 on @grid4 split_axis = 1 concat_axis = 0 : tensor<3x6xi8> -> tensor<3x?xi8> return %0 : tensor<3x?xi8> @@ -273,10 +273,10 @@ func.func @all_to_all_dynamic_dims_in_result( func.func 
@all_to_all_same_split_concat_dim_with_dynamic_device_group_size( // CHECK-SAME: %[[ARG:.*]]: tensor<3xi8> %arg0 : tensor<3xi8>) -> tensor<3xi8> { - // CHECK-NEXT: mesh.all_to_all %[[ARG]] - // CHECK-SAME: @mesh4 split_axis = 0 concat_axis = 0 + // CHECK-NEXT: shard.all_to_all %[[ARG]] + // CHECK-SAME: @grid4 split_axis = 0 concat_axis = 0 // CHECK-SAME: : tensor<3xi8> -> tensor<3xi8> - %0 = mesh.all_to_all %arg0 on @mesh4 + %0 = shard.all_to_all %arg0 on @grid4 split_axis = 0 concat_axis = 0 : tensor<3xi8> -> tensor<3xi8> return %0 : tensor<3xi8> @@ -286,10 +286,10 @@ func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size( func.func @all_to_all_non_divisible_split_axis_size( // CHECK-SAME: %[[ARG:.*]]: tensor<2x3xi8> %arg0 : tensor<2x3xi8>) -> tensor { - // CHECK-NEXT: mesh.all_to_all %[[ARG]] - // CHECK-SAME: @mesh0 mesh_axes = [0, 1] split_axis = 0 concat_axis = 1 + // CHECK-NEXT: shard.all_to_all %[[ARG]] + // CHECK-SAME: @grid0 grid_axes = [0, 1] split_axis = 0 concat_axis = 1 // CHECK-SAME: : tensor<2x3xi8> -> tensor - %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 1] + %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0, 1] split_axis = 0 concat_axis = 1 : tensor<2x3xi8> -> tensor return %0 : tensor @@ -299,11 +299,11 @@ func.func @all_to_all_non_divisible_split_axis_size( func.func @broadcast_static_root( // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8> %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> { - // CHECK-NEXT: mesh.broadcast %[[ARG]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.broadcast %[[ARG]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: root = [0, 1] // CHECK-SAME: : (tensor<3x6xi8>) -> tensor<3x6xi8> - %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0, 2] root = [0, 1] : (tensor<3x6xi8>) -> tensor<3x6xi8> return %0 : tensor<3x6xi8> @@ -316,11 +316,11 @@ func.func @broadcast_dynamic_root( // CHECK-SAME: %[[ARG1:.*]]: index %arg1 : 
index ) -> tensor<3x6xi8> { - // CHECK-NEXT: mesh.broadcast %[[ARG0]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.broadcast %[[ARG0]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: root = [1, %[[ARG1]]] // CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<3x6xi8> - %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0, 2] root = [1, %arg1] : (tensor<3x6xi8>, index) -> tensor<3x6xi8> return %0 : tensor<3x6xi8> @@ -330,12 +330,12 @@ func.func @broadcast_dynamic_root( func.func @gather_static_root( // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8> %arg0 : tensor<3x6xi8>) -> tensor<24x6xi8> { - // CHECK-NEXT: mesh.gather %[[ARG]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.gather %[[ARG]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: gather_axis = 0 // CHECK-SAME: root = [0, 1] // CHECK-SAME: : (tensor<3x6xi8>) -> tensor<24x6xi8> - %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.gather %arg0 on @grid0 grid_axes = [0, 2] gather_axis = 0 root = [0, 1] : (tensor<3x6xi8>) -> tensor<24x6xi8> @@ -349,12 +349,12 @@ func.func @gather_dynamic_root( // CHECK-SAME: %[[ARG1:.*]]: index %arg1 : index ) -> tensor<24x6xi8> { - // CHECK-NEXT: mesh.gather %[[ARG0]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.gather %[[ARG0]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: gather_axis = 0 // CHECK-SAME: root = [1, %[[ARG1]]] // CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<24x6xi8> - %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.gather %arg0 on @grid0 grid_axes = [0, 2] gather_axis = 0 root = [1, %arg1] : (tensor<3x6xi8>, index) -> tensor<24x6xi8> @@ -365,11 +365,11 @@ func.func @gather_dynamic_root( func.func @receive_static_source( // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> %arg0 : tensor<2xi8>) -> tensor<2xi8> { - // CHECK-NEXT: mesh.recv %[[ARG]] - // CHECK-SAME: on @mesh0 mesh_axes = 
[0, 2] + // CHECK-NEXT: shard.recv %[[ARG]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: source = [0, 1] // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8> - %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.recv %arg0 on @grid0 grid_axes = [0, 2] source = [0, 1] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -382,11 +382,11 @@ func.func @receive_dynamic_source( // CHECK-SAME: %[[ARG1:.*]]: index %arg1 : index ) -> tensor<2xi8> { - // CHECK-NEXT: mesh.recv %[[ARG0]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.recv %[[ARG0]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: source = [1, %[[ARG1]]] // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8> - %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.recv %arg0 on @grid0 grid_axes = [0, 2] source = [1, %arg1] : (tensor<2xi8>, index) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -396,9 +396,9 @@ func.func @receive_dynamic_source( func.func @receive_no_source( // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> %arg0 : tensor<2xi8>) -> tensor<2xi8> { - // CHECK-NEXT: mesh.recv %[[ARG]] + // CHECK-NEXT: shard.recv %[[ARG]] // CHECK-NOT: source - %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.recv %arg0 on @grid0 grid_axes = [0, 2] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> } @@ -407,11 +407,11 @@ func.func @receive_no_source( func.func @reduce_static_root( // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> %arg0 : tensor<2xi8>) -> tensor<2xi8> { - // CHECK-NEXT: mesh.reduce %[[ARG]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.reduce %[[ARG]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: root = [0, 1] // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8> - %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.reduce %arg0 on @grid0 grid_axes = [0, 2] root = [0, 1] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -424,11 +424,11 @@ func.func @reduce_dynamic_root( // 
CHECK-SAME: %[[ARG1:.*]]: index %arg1 : index ) -> tensor<2xi8> { - // CHECK-NEXT: mesh.reduce %[[ARG0]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.reduce %[[ARG0]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: root = [1, %[[ARG1]]] // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8> - %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.reduce %arg0 on @grid0 grid_axes = [0, 2] root = [1, %arg1] : (tensor<2xi8>, index) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -438,11 +438,11 @@ func.func @reduce_dynamic_root( func.func @reduce_different_return_element_type( // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> %arg0 : tensor<2xi8>) -> tensor<2xi16> { - // CHECK-NEXT: mesh.reduce %[[ARG]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.reduce %[[ARG]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: root = [0, 1] // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi16> - %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.reduce %arg0 on @grid0 grid_axes = [0, 2] root = [0, 1] : (tensor<2xi8>) -> tensor<2xi16> return %0 : tensor<2xi16> @@ -452,10 +452,10 @@ func.func @reduce_different_return_element_type( func.func @reduce_scatter_static_dimensions( // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32> %arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> { - // CHECK-NEXT: mesh.reduce_scatter %[[ARG]] - // CHECK-SAME: on @mesh0 mesh_axes = [2] reduction = max scatter_axis = 1 + // CHECK-NEXT: shard.reduce_scatter %[[ARG]] + // CHECK-SAME: on @grid0 grid_axes = [2] reduction = max scatter_axis = 1 // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64> - %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [2] + %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [2] reduction = max scatter_axis = 1 : tensor<3x4xf32> -> tensor<3x1xf64> return %0 : tensor<3x1xf64> @@ -465,10 +465,10 @@ func.func @reduce_scatter_static_dimensions( func.func @reduce_scatter_dynamic_dimensions( // CHECK-SAME: %[[ARG:.*]]: 
tensor %arg0 : tensor) -> tensor { - // CHECK-NEXT: mesh.reduce_scatter %[[ARG]] - // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] scatter_axis = 0 + // CHECK-NEXT: shard.reduce_scatter %[[ARG]] + // CHECK-SAME: on @grid3 grid_axes = [0, 1] scatter_axis = 0 // CHECK-SAME: : tensor -> tensor - %0 = mesh.reduce_scatter %arg0 on @mesh3 mesh_axes = [0, 1] scatter_axis = 0 + %0 = shard.reduce_scatter %arg0 on @grid3 grid_axes = [0, 1] scatter_axis = 0 : tensor -> tensor return %0 : tensor } @@ -477,11 +477,11 @@ func.func @reduce_scatter_dynamic_dimensions( func.func @scatter_static_dimensions( // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32> %arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> { - // CHECK-NEXT: mesh.scatter %[[ARG]] - // CHECK-SAME: on @mesh0 mesh_axes = [2] + // CHECK-NEXT: shard.scatter %[[ARG]] + // CHECK-SAME: on @grid0 grid_axes = [2] // CHECK-SAME: scatter_axis = 1 root = [1] // CHECK-SAME: : (tensor<3x4xf32>) -> tensor<3x1xf32> - %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [2] + %0 = shard.scatter %arg0 on @grid0 grid_axes = [2] scatter_axis = 1 root = [1] : (tensor<3x4xf32>) -> tensor<3x1xf32> return %0 : tensor<3x1xf32> @@ -491,11 +491,11 @@ func.func @scatter_static_dimensions( func.func @scatter_dynamic_dimensions( // CHECK-SAME: %[[ARG:.*]]: tensor %arg0 : tensor) -> tensor { - // CHECK-NEXT: mesh.scatter %[[ARG]] - // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] + // CHECK-NEXT: shard.scatter %[[ARG]] + // CHECK-SAME: on @grid3 grid_axes = [0, 1] // CHECK-SAME: scatter_axis = 0 root = [1, 2] // CHECK-SAME: : (tensor) -> tensor - %0 = mesh.scatter %arg0 on @mesh3 mesh_axes = [0, 1] + %0 = shard.scatter %arg0 on @grid3 grid_axes = [0, 1] scatter_axis = 0 root = [1, 2] : (tensor) -> tensor return %0 : tensor @@ -508,12 +508,12 @@ func.func @scatter_dynamic_root( // CHECK-SAME: %[[ARG1:.*]]: index %arg1 : index ) -> tensor<1xi8> { - // CHECK-NEXT: mesh.scatter %[[ARG0]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.scatter %[[ARG0]] + // 
CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: scatter_axis = 0 // CHECK-SAME: root = [1, %[[ARG1]]] // CHECK-SAME: : (tensor<8xi8>, index) -> tensor<1xi8> - %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.scatter %arg0 on @grid0 grid_axes = [0, 2] scatter_axis = 0 root = [1, %arg1] : (tensor<8xi8>, index) -> tensor<1xi8> @@ -524,11 +524,11 @@ func.func @scatter_dynamic_root( func.func @send_static_destination( // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> %arg0 : tensor<2xi8>) -> tensor<2xi8> { - // CHECK-NEXT: mesh.send %[[ARG]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.send %[[ARG]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: destination = [0, 1] // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8> - %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.send %arg0 on @grid0 grid_axes = [0, 2] destination = [0, 1] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -541,11 +541,11 @@ func.func @send_dynamic_destination( // CHECK-SAME: %[[ARG1:.*]]: index %arg1 : index ) -> tensor<2xi8> { - // CHECK-NEXT: mesh.send %[[ARG0]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.send %[[ARG0]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: destination = [1, %[[ARG1]]] // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8> - %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.send %arg0 on @grid0 grid_axes = [0, 2] destination = [1, %arg1] : (tensor<2xi8>, index) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -555,11 +555,11 @@ func.func @send_dynamic_destination( func.func @shift( // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> %arg0 : tensor<2xi8>) -> tensor<2xi8> { - // CHECK-NEXT: mesh.shift %[[ARG]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.shift %[[ARG]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: shift_axis = 2 offset = -2 rotate // CHECK-SAME: : tensor<2xi8> -> tensor<2xi8> - %0 = mesh.shift %arg0 on @mesh0 
mesh_axes = [0, 2] + %0 = shard.shift %arg0 on @grid0 grid_axes = [0, 2] shift_axis = 2 offset = -2 rotate : tensor<2xi8> -> tensor<2xi8> return %0 : tensor<2xi8> @@ -570,16 +570,16 @@ func.func @update_halo( // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8> %arg0 : memref<12x12xi8>) { // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64 - // CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] on @mesh0 + // CHECK-NEXT: %[[UH1:.*]] = shard.update_halo %[[ARG]] on @grid0 // CHECK-SAME: split_axes = {{\[\[}}0]] // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8> %c2 = arith.constant 2 : i64 - %uh1 = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] + %uh1 = shard.update_halo %arg0 on @grid0 split_axes = [[0]] halo_sizes = [2, %c2] : memref<12x12xi8> - // CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[UH1]] on @mesh0 + // CHECK-NEXT: %[[UH2:.*]] = shard.update_halo %[[UH1]] on @grid0 // CHECK-SAME: split_axes = {{\[\[}}0], [1]] // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8> - %uh2 = mesh.update_halo %uh1 on @mesh0 split_axes = [[0], [1]] + %uh2 = shard.update_halo %uh1 on @grid0 split_axes = [[0], [1]] halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8> return } diff --git a/mlir/test/Dialect/Shard/partition.mlir b/mlir/test/Dialect/Shard/partition.mlir new file mode 100644 index 0000000000000..c2572cc3b987b --- /dev/null +++ b/mlir/test/Dialect/Shard/partition.mlir @@ -0,0 +1,317 @@ +// RUN: mlir-opt \ +// RUN: --pass-pipeline="builtin.module(func.func(shard-partition,test-single-fold))" \ +// RUN: %s | FileCheck %s + +shard.grid @grid_1d(shape = 2) + +// CHECK-LABEL: func @return_sharding +func.func @return_sharding( + // CHECK-SAME: [[ARG:%.*]]: tensor<1xf32> + %arg0: tensor<2xf32> +// CHECK-SAME: ) -> (tensor<1xf32>, !shard.sharding) { +) -> (tensor<2xf32>, !shard.sharding) { + %ssharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %sharded = shard.shard %arg0 to %ssharded : tensor<2xf32> + // CHECK-NEXT: [[vsharding:%.*]] = 
shard.sharding @grid_1d split_axes = {{\[\[}}0]] : !shard.sharding + %r = shard.get_sharding %sharded : tensor<2xf32> -> !shard.sharding + // CHECK-NEXT: return [[ARG]], [[vsharding]] : tensor<1xf32>, !shard.sharding + return %sharded, %r : tensor<2xf32>, !shard.sharding +} + +// CHECK-LABEL: func @full_replication +func.func @full_replication( + // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> + %arg0: tensor<2xi8> +// CHECK-SAME: -> tensor<2xi8> { +) -> tensor<2xi8> { + %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<2xi8> + %s1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2xi8> + // CHECK: return %[[ARG]] : tensor<2xi8> + return %1 : tensor<2xi8> +} + +// CHECK-LABEL: func @sharding_triplet +func.func @sharding_triplet( + // CHECK-SAME: %[[ARG:.*]]: tensor<1xf32> + %arg0: tensor<2xf32> +// CHECK-SAME: ) -> tensor<2xf32> { +) -> tensor<2xf32> { + // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[ARG]] on @grid_1d grid_axes = [0] gather_axis = 0 : tensor<1xf32> -> tensor<2xf32> + %ssharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %sharded = shard.shard %arg0 to %ssharded : tensor<2xf32> + %ssharded_0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %sharded_0 = shard.shard %sharded to %ssharded_0 annotate_for_users : tensor<2xf32> + %ssharded_1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %sharded_1 = shard.shard %sharded_0 to %ssharded_1 : tensor<2xf32> + // CHECK: return %[[ALL_GATHER]] : tensor<2xf32> + return %sharded_1 : tensor<2xf32> +} + + +// CHECK-LABEL: func @move_split_axis +func.func @move_split_axis( + // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi8> + %arg0: tensor<2x2xi8> +// CHECK-SAME: -> tensor<2x1xi8> { +) -> tensor<2x2xi8> { + // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[ARG]] on @grid_1d + // CHECK-SAME: grid_axes = [0] split_axis = 1 concat_axis = 0 : 
tensor<1x2xi8> -> tensor<2x1xi8> + %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<2x2xi8> + %s1 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2x2xi8> + // CHECK: return %[[ALL_TO_ALL]] : tensor<2x1xi8> + return %1 : tensor<2x2xi8> +} + +// CHECK-LABEL: func @non_tensor_value +func.func @non_tensor_value( + // CHECK-SAME: %[[ARG:.*]]: i8 + %arg0: i8 +// CHECK-SAME: -> i8 { +) -> i8 { + // CHECK: %[[RES:.*]] = arith.addi %[[ARG]], %[[ARG]] : i8 + %0 = arith.addi %arg0, %arg0 : i8 + // CHECK: return %[[RES]] : i8 + return %0 : i8 +} + +// CHECK-LABEL: func @unary_elementwise +func.func @unary_elementwise( + // CHECK-SAME: %[[ARG:.*]]: tensor<1xi8> + %arg0: tensor<2xi8> +// CHECK-SAME: -> tensor<1xi8> { +) -> tensor<2xi8> { + %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<2xi8> + %s1 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2xi8> + // CHECK: %[[RES:.*]] = tosa.abs %[[ARG]] : (tensor<1xi8>) -> tensor<1xi8> + %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8> + %s3 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %3 = shard.shard %2 to %s3 : tensor<2xi8> + %s4 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %4 = shard.shard %3 to %s4 annotate_for_users : tensor<2xi8> + // CHECK: return %[[RES]] : tensor<1xi8> + return %4 : tensor<2xi8> +} + +// full replication -> shard axis -> abs -> shard axis -> full replication +// CHECK-LABEL: func @unary_elementwise_with_resharding +func.func @unary_elementwise_with_resharding( + // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> + %arg0: tensor<2xi8> +// CHECK-SAME: -> tensor<2xi8> { +) -> tensor<2xi8> { + // CHECK: %[[SLICE:.*]] = shard.all_slice %[[ARG]] on @grid_1d grid_axes = [0] slice_axis = 0 + // CHECK-SAME: tensor<2xi8> -> 
tensor<1xi8> + %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<2xi8> + %s1 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2xi8> + // CHECK: %[[ABS:.*]] = tosa.abs %[[SLICE]] : (tensor<1xi8>) -> tensor<1xi8> + %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8> + // CHECK: %[[RES:.*]] = shard.all_gather %[[ABS]] on @grid_1d + // CHECK-SAME: grid_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8> + %s3 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %3 = shard.shard %2 to %s3 : tensor<2xi8> + %s4 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %4 = shard.shard %3 to %s4 annotate_for_users : tensor<2xi8> + // CHECK: return %[[RES]] : tensor<2xi8> + return %4 : tensor<2xi8> +} + +// CHECK-LABEL: func @binary_elementwise +func.func @binary_elementwise( + // CHECK-SAME: %[[ARG0:.*]]: tensor<1xi8>, + %arg0: tensor<2xi8>, + // CHECK-SAME: %[[ARG1:.*]]: tensor<1xi8> + %arg1: tensor<2xi8> +// CHECK-SAME: -> tensor<1xi8> { +) -> tensor<2xi8> { + %sarg0_sharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %arg0_sharded = shard.shard %arg0 to %sarg0_sharded : tensor<2xi8> + %sop_arg0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %op_arg0 = shard.shard %arg0_sharded to %sop_arg0 annotate_for_users : tensor<2xi8> + %sarg1_sharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %arg1_sharded = shard.shard %arg1 to %sarg1_sharded : tensor<2xi8> + %sop_arg1 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %op_arg1 = shard.shard %arg1_sharded to %sop_arg1 annotate_for_users : tensor<2xi8> + // CHECK: %[[RES:.*]] = tosa.add %[[ARG0]], %[[ARG1]] : (tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8> + %op_res = tosa.add %op_arg0, %op_arg1 : (tensor<2xi8>, tensor<2xi8>) -> tensor<2xi8> + %sop_res_sharded = shard.sharding @grid_1d split_axes = [[0]] : 
!shard.sharding + %op_res_sharded = shard.shard %op_res to %sop_res_sharded : tensor<2xi8> + %sres = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %res = shard.shard %op_res_sharded to %sres annotate_for_users : tensor<2xi8> + // CHECK: return %[[RES]] : tensor<1xi8> + return %res : tensor<2xi8> +} + +// reshard +// abs +// reshard +// abs +// reshard +// CHECK-LABEL: func @multiple_chained_ops +func.func @multiple_chained_ops( + // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> + %arg0: tensor<2xi8> +// CHECK-SAME: -> tensor<1xi8> { +) -> tensor<2xi8> { + // CHECK: %[[RESHARD1:.*]] = shard.all_slice %[[ARG]] on @grid_1d grid_axes = [0] slice_axis = 0 + // CHECK-SAME: tensor<2xi8> -> tensor<1xi8> + %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<2xi8> + %s1 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2xi8> + // CHECK: %[[ABS1:.*]] = tosa.abs %[[RESHARD1]] : (tensor<1xi8>) -> tensor<1xi8> + %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8> + // CHECK: %[[RESHARD2:.*]] = shard.all_gather %[[ABS1]] on @grid_1d + // CHECK-SAME: grid_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8> + %s3 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %3 = shard.shard %2 to %s3 : tensor<2xi8> + %s4 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %4 = shard.shard %3 to %s4 annotate_for_users : tensor<2xi8> + // CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8> + %5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8> + // CHECK: %[[RESHARD3:.*]] = shard.all_slice %[[ABS2]] on @grid_1d grid_axes = [0] slice_axis = 0 : + // CHECK-SAME: tensor<2xi8> -> tensor<1xi8> + %s6 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %6 = shard.shard %5 to %s6 : tensor<2xi8> + %s7 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %7 = shard.shard %6 to %s7 
annotate_for_users : tensor<2xi8> + // CHECK: return %[[RESHARD3]] : tensor<1xi8> + return %7 : tensor<2xi8> +} + +// CHECK-LABEL: func @incomplete_sharding +func.func @incomplete_sharding( + // CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32> + %arg0: tensor<8x16xf32> +// CHECK-SAME: -> tensor<4x16xf32> { +) -> tensor<8x16xf32> { + %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 annotate_for_users : tensor<8x16xf32> + // CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32> + %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32> + %s2 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %2 = shard.shard %1 to %s2 : tensor<8x16xf32> + // CHECK: return %[[RES]] : tensor<4x16xf32> + return %2 : tensor<8x16xf32> +} + +shard.grid @grid_1d_4(shape = 4) + +// CHECK-LABEL: func @ew_chain_with_halo +func.func @ew_chain_with_halo( + // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<5x16xf32> + %arg0: tensor<8x16xf32>, + // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xf32> + %arg1: tensor<1xf32>, + // CHECK-SAME: %[[IN3:[A-Za-z0-9_]+]]: tensor<1xf32> + %arg2: tensor<1xf32>) + // CHECK-SAME: -> tensor<5x16xf32> + -> tensor<8x16xf32> { + %ssharded = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding + %sharded = shard.shard %arg0 to %ssharded annotate_for_users : tensor<8x16xf32> + // CHECK: %[[TMP1:.*]] = tosa.tanh %[[IN1]] : (tensor<5x16xf32>) -> tensor<5x16xf32> + %0 = tosa.tanh %sharded : (tensor<8x16xf32>) -> tensor<8x16xf32> + %ssharded_0 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding + %sharded_0 = shard.shard %0 to %ssharded_0 : tensor<8x16xf32> + %ssharded_1 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding + %sharded_1 = shard.shard %sharded_0 to %ssharded_1 annotate_for_users : tensor<8x16xf32> + // CHECK-NEXT: %[[TMP2:.*]] = tosa.abs %[[TMP1]] : (tensor<5x16xf32>) -> 
tensor<5x16xf32> + %1 = tosa.abs %sharded_1 : (tensor<8x16xf32>) -> tensor<8x16xf32> + %ssharded_2 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding + %sharded_2 = shard.shard %1 to %ssharded_2 : tensor<8x16xf32> + %ssharded_4 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding + %sharded_4 = shard.shard %sharded_2 to %ssharded_4 annotate_for_users : tensor<8x16xf32> + // CHECK-NEXT: %[[TMP3:.*]] = tosa.negate %[[TMP2]], %[[IN2]], %[[IN3]] : (tensor<5x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x16xf32> + %sharding_1 = shard.sharding @grid_1d_4 split_axes = [[]] : !shard.sharding + %zero_point_1 = shard.shard %arg1 to %sharding_1 annotate_for_users : tensor<1xf32> + %zero_point_2 = shard.shard %arg2 to %sharding_1 annotate_for_users : tensor<1xf32> + %2 = tosa.negate %sharded_4, %zero_point_1, %zero_point_2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32> + %ssharded_5 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding + %sharded_5 = shard.shard %2 to %ssharded_5 : tensor<8x16xf32> + %ssharded_6 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding + %sharded_6 = shard.shard %sharded_5 to %ssharded_6 annotate_for_users : tensor<8x16xf32> + // CHECK-NEXT: return %[[TMP3]] : tensor<5x16xf32> + return %sharded_6 : tensor<8x16xf32> +} + +// CHECK-LABEL: func @test_shard_update_halo +// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x1200xi64> +func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> { + %sharding = shard.sharding @grid_1d_4 split_axes = [[0]] : !shard.sharding + // CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64> + // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64> + // CHECK: %[[UH:.*]] = shard.update_halo %[[inserted_slice]] on @grid_1d_4 split_axes = 
{{\[\[0]]}} halo_sizes = [2, 2] : tensor<304x1200xi64> + %sharded = shard.shard %arg0 to %sharding : tensor<1200x1200xi64> + %sharding_0 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !shard.sharding + %sharded_1 = shard.shard %sharded to %sharding_0 : tensor<1200x1200xi64> + %sharded_3 = shard.shard %sharded_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64> + // CHECK: return %[[UH]] : tensor<304x1200xi64> + return %sharded_3 : tensor<1200x1200xi64> +} + +shard.grid @grid4x4(shape = 4x4) +// CHECK-LABEL: func @test_shard_update_halo2d +// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x300xi64> +func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> { + %sharding = shard.sharding @grid4x4 split_axes = [[0], [1]] : !shard.sharding + // CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64> + // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64> + // CHECK: %[[UH:.*]] = shard.update_halo %[[inserted_slice]] on @grid4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 3, 4] : tensor<303x307xi64> + %sharded = shard.shard %arg0 to %sharding : tensor<1200x1200xi64> + %sharding_0 = shard.sharding @grid4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !shard.sharding + %sharded_1 = shard.shard %sharded to %sharding_0 : tensor<1200x1200xi64> + %sharded_3 = shard.shard %sharded_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64> + // CHECK: return %[[UH]] : tensor<303x307xi64> + return %sharded_3 : tensor<1200x1200xi64> +} + +shard.grid @grid(shape = 2) +// CHECK-LABEL: func.func @test_reduce_0d( +// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32> +func.func @test_reduce_0d(%arg0: tensor<6x6xi32>) -> (tensor) { + %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding + %sharded = shard.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32> + %4 = tensor.empty() : tensor + 
%sharding_out = shard.sharding @grid split_axes = [[]] : !shard.sharding + %sharded_out = shard.shard %4 to %sharding_out : tensor + %sharded_in = shard.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32> + // CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>) + %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor) dimensions = [0, 1] + (%in: i32, %init: i32) { + %6 = arith.addi %in, %init : i32 + linalg.yield %6 : i32 + } + // CHECK: %[[all_reduce:.*]] = shard.all_reduce %[[reduced]] on @grid grid_axes = [0] : tensor -> tensor + %sharded_red = shard.shard %reduced to %sharding_out : tensor + %sharded_ret = shard.shard %sharded_red to %sharding_out annotate_for_users : tensor + // CHECK: return %[[all_reduce]] : tensor + return %sharded_ret : tensor +} + +// CHECK-LABEL: func.func @test_reduce_1d( +// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32> +func.func @test_reduce_1d(%arg0: tensor<6x6xi32>) -> (tensor<6xi32>) { + %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding + %sharded = shard.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32> + %4 = tensor.empty() : tensor<6xi32> + %sharded_out = shard.shard %4 to %sharding : tensor<6xi32> + %sharded_in = shard.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32> + // CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>) + %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<6xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %6 = arith.addi %in, %init : i32 + linalg.yield %6 : i32 + } + // CHECK-NOT: shard.all_reduce + %sharded_red = shard.shard %reduced to %sharding : tensor<6xi32> + %sharded_ret = shard.shard %sharded_red to %sharding annotate_for_users : tensor<6xi32> + // CHECK: return %[[reduced]] : tensor<3xi32> + return %sharded_ret : tensor<6xi32> +} diff --git a/mlir/test/Dialect/Shard/process-multi-index-op-lowering.mlir 
b/mlir/test/Dialect/Shard/process-multi-index-op-lowering.mlir
new file mode 100644
index 0000000000000..33c7a8f96464d
--- /dev/null
+++ b/mlir/test/Dialect/Shard/process-multi-index-op-lowering.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -test-grid-process-multi-index-op-lowering %s | FileCheck %s
+
+shard.grid @grid2d(shape = ?x?)
+
+// CHECK-LABEL: func.func @multi_index_2d_grid
+func.func @multi_index_2d_grid() -> (index, index) {
+  // CHECK: %[[LINEAR_IDX:.*]] = shard.process_linear_index on @grid2d : index
+  // CHECK: %[[SHARD_SHAPE:.*]]:2 = shard.grid_shape @grid2d : index, index
+  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[SHARD_SHAPE]]#0, %[[SHARD_SHAPE]]#1) : index, index
+  %0:2 = shard.process_multi_index on @grid2d : index, index
+  // CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func.func @multi_index_2d_grid_single_inner_axis
+func.func @multi_index_2d_grid_single_inner_axis() -> index {
+  // CHECK: %[[LINEAR_IDX:.*]] = shard.process_linear_index on @grid2d : index
+  // CHECK: %[[SHARD_SHAPE:.*]]:2 = shard.grid_shape @grid2d : index, index
+  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[SHARD_SHAPE]]#0, %[[SHARD_SHAPE]]#1) : index, index
+  %0 = shard.process_multi_index on @grid2d axes = [0] : index
+  // CHECK: return %[[MULTI_IDX]]#0 : index
+  return %0 : index
+}
diff --git a/mlir/test/Dialect/Shard/resharding-partition.mlir b/mlir/test/Dialect/Shard/resharding-partition.mlir
new file mode 100644
index 0000000000000..ff9e8408aa7fd
--- /dev/null
+++ b/mlir/test/Dialect/Shard/resharding-partition.mlir
@@ -0,0 +1,168 @@
+// RUN: mlir-opt -test-grid-resharding-partition %s | FileCheck %s
+
+shard.grid @grid_1d(shape = 2)
+shard.grid @grid_1d_dynamic(shape = ?)
+
+// CHECK-LABEL: func @same_source_and_target_sharding
+func.func @same_source_and_target_sharding(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
+  %arg0: tensor<2xf32>
+) -> tensor<2xf32> {
+  %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<2xf32>
+  %s1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2xf32>
+  // CHECK: return %[[ARG]]
+  return %1 : tensor<2xf32>
+}
+
+// CHECK-LABEL: func @identical_source_and_target_sharding
+func.func @identical_source_and_target_sharding(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
+  %arg0: tensor<2xf32>
+) -> tensor<2xf32> {
+  %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<2xf32>
+  %1 = shard.shard %0 to %s0 annotate_for_users : tensor<2xf32>
+  // CHECK: return %[[ARG]]
+  return %1 : tensor<2xf32>
+}
+
+// CHECK-LABEL: func @split_replicated_tensor_axis
+func.func @split_replicated_tensor_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32>
+  %arg0: tensor<3x14xf32>
+) -> tensor<3x14xf32> {
+  // CHECK: %[[ALL_SLICE:.*]] = shard.all_slice %[[ARG]] on @grid_1d grid_axes = [0] slice_axis = 1
+  // CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32>
+  // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
+  %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<3x14xf32>
+  %s1 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor<3x14xf32>
+  // CHECK: return %[[RESULT]] : tensor<3x14xf32>
+  return %1 : tensor<3x14xf32>
+}
+
+// CHECK-LABEL: func @split_replicated_tensor_axis_dynamic
+func.func @split_replicated_tensor_axis_dynamic(
+  // CHECK-SAME: %[[ARG:.*]]: tensor
+  %arg0: tensor
+) -> tensor {
+  // CHECK: %[[RESULT:.*]] = shard.all_slice %[[ARG]] on @grid_1d_dynamic grid_axes = [0] slice_axis = 0
+  // CHECK-SAME: tensor -> tensor
+  %s0 = shard.sharding @grid_1d_dynamic split_axes = [[], [], []] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor
+  %s1 = shard.sharding @grid_1d_dynamic split_axes = [[0]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor
+  // CHECK: return %[[RESULT]] : tensor
+  return %1 : tensor
+}
+
+// CHECK-LABEL: func @move_split_axis
+func.func @move_split_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
+  // CHECK: %[[TARGET_SHARD:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_1d grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32>
+  %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32>
+  %s1 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[RES]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @move_split_axis_dynamic_grid
+func.func @move_split_axis_dynamic_grid(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor
+  // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_1d_dynamic grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor -> tensor
+  // CHECK: %[[TARGET_SHARD:.*]] = tensor.cast %[[ALL_TO_ALL]] : tensor to tensor<10x?xf32>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32>
+  %s0 = shard.sharding @grid_1d_dynamic split_axes = [[0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32>
+  %s1 = shard.sharding @grid_1d_dynamic split_axes = [[], [0]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[RES]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @move_split_dynamic_axis
+func.func @move_split_dynamic_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor
+  %arg0: tensor
+) -> tensor {
+  // CHECK: %[[TARGET_SHARD:.*]] = shard.all_to_all %[[ARG]] on @grid_1d grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor -> tensor
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor to tensor
+  %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor
+  %s1 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor
+  // CHECK: return %[[RES]] : tensor
+  return %1 : tensor
+}
+
+// CHECK-LABEL: func @unshard_static_axis
+func.func @unshard_static_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
+  // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[SOURCE_SHARD]] on @grid_1d grid_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32>
+  %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32>
+  %s1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_last_axis
+func.func @unshard_static_last_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<10x7xf32>
+  // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[SOURCE_SHARD]] on @grid_1d grid_axes = [0] gather_axis = 1 : tensor<10x7xf32> -> tensor<10x14xf32>
+  %s0 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32>
+  %s1 = shard.sharding @grid_1d split_axes = [[], []] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_dynamic_axis
+func.func @unshard_dynamic_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor
+  %arg0: tensor
+) -> tensor {
+  // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[ARG]] on @grid_1d grid_axes = [0] gather_axis = 0 : tensor -> tensor
+  %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor
+  %s1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor
+  // CHECK: return %[[ALL_GATHER]] : tensor
+  return %1 : tensor
+}
+
+// CHECK-LABEL: func @unshard_static_axis_on_dynamic_grid_axis
+func.func @unshard_static_axis_on_dynamic_grid_axis(
+// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor
+  // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[SOURCE_SHARD]] on @grid_1d_dynamic grid_axes = [0] gather_axis = 0 : tensor -> tensor
+  // CHECK: %[[RES:.*]] = tensor.cast %[[ALL_GATHER]] : tensor to tensor<10x14xf32>
+  %s0 = shard.sharding @grid_1d_dynamic split_axes = [[0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32>
+  %s1 = shard.sharding @grid_1d_dynamic split_axes = [[]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32> + // CHECK: return %[[RES]] : tensor<10x14xf32> + return %1 : tensor<10x14xf32> +} diff --git a/mlir/test/Dialect/Mesh/sharding-propagation-failed.mlir b/mlir/test/Dialect/Shard/sharding-propagation-failed.mlir similarity index 100% rename from mlir/test/Dialect/Mesh/sharding-propagation-failed.mlir rename to mlir/test/Dialect/Shard/sharding-propagation-failed.mlir diff --git a/mlir/test/Dialect/Shard/sharding-propagation.mlir b/mlir/test/Dialect/Shard/sharding-propagation.mlir new file mode 100644 index 0000000000000..34aaf0598b3f0 --- /dev/null +++ b/mlir/test/Dialect/Shard/sharding-propagation.mlir @@ -0,0 +1,301 @@ +// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation,cse))" %s | FileCheck %s + +shard.grid @grid_2(shape = 2) +shard.grid @grid_1d(shape = ?) +shard.grid @grid_2d(shape = 2x4) +shard.grid @grid_3d(shape = ?x?x?) + +// CHECK-LABEL: func.func @element_wise_empty_sharding_info +func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + // CHECK-NEXT: tosa.sigmoid + %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: return + return %0 : tensor<8x16xf32> +} + +// CHECK-LABEL: func.func @element_wise_on_def +// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> +func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding + // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32> + // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]] + %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[V2:.*]] = shard.shard %[[V1]] to %[[S0]] : tensor<8x16xf32> + %s1 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding + %1 = shard.shard %0 to %s1 : tensor<8x16xf32> + // CHECK-NEXT: return %[[V2]] + return %1 : 
tensor<8x16xf32> +} + +// CHECK-LABEL: func.func @element_wise_on_use +// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> +func.func @element_wise_on_use(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding + // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32> + %s0 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 annotate_for_users : tensor<8x16xf32> + // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]] + %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[V2:.*]] = shard.shard %[[V1]] to %[[S0]] : tensor<8x16xf32> + // CHECK-NEXT: return %[[V2]] + return %1 : tensor<8x16xf32> +} + +// CHECK-LABEL: func.func @element_wise_on_graph_output +// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> +func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding + // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32> + // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]] + %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[V2:.*]] = shard.shard %[[V1]] to %[[S0]] : tensor<8x16xf32> + // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32> + %s1 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<8x16xf32> + // CHECK-NEXT: return %[[V3]] + return %1 : tensor<8x16xf32> +} + +// CHECK-LABEL: func.func @element_wise_on_graph_input +// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> +func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding + // 
CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] : tensor<8x16xf32> + // CHECK-NEXT: %[[V1:.*]] = shard.shard %[[V0]] to %[[S0]] annotate_for_users : tensor<8x16xf32> + %s0 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<8x16xf32> + // CHECK-NEXT: %[[V2:.*]] = tosa.sigmoid %[[V1]] + %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S0]] : tensor<8x16xf32> + // CHECK-NEXT: return %[[V3]] + return %1 : tensor<8x16xf32> +} + +// CHECK-LABEL: func.func @arrow_structure +// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> +func.func @arrow_structure(%arg0: tensor<8x16xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) { + // CHECK-NEXT: %[[S1:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding + // CHECK-NEXT: %[[V1:.*]] = shard.shard %[[ARG]] to %[[S1]] annotate_for_users : tensor<8x16xf32> + // CHECK-NEXT: %[[V2:.*]] = tosa.tanh %[[V1]] + // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S1]] : tensor<8x16xf32> + %0 = tosa.tanh %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[V4:.*]] = shard.shard %[[V3]] to %[[S1]] annotate_for_users : tensor<8x16xf32> + // CHECK-NEXT: %[[V5:.*]] = tosa.abs %[[V4]] + // CHECK-NEXT: %[[V6:.*]] = shard.shard %[[V5]] to %[[S1]] : tensor<8x16xf32> + %1 = tosa.abs %0: (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[S3:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding + // CHECK-NEXT: %[[ZP1:.*]] = shard.shard %arg1 to %[[S3]] annotate_for_users : tensor<1xf32> + // CHECK-NEXT: %[[ZP2:.*]] = shard.shard %arg2 to %[[S3]] annotate_for_users : tensor<1xf32> + // CHECK-NEXT: %[[V7:.*]] = tosa.negate %[[V4]], %[[ZP1]], %[[ZP2]] + // CHECK-NEXT: %[[V8:.*]] = shard.shard %[[V7]] to %[[S1]] : tensor<8x16xf32> + %2 = tosa.negate %0, %arg1, %arg2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> 
tensor<8x16xf32> + %s3 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding + %3 = shard.shard %2 to %s3 : tensor<8x16xf32> + // CHECK-NEXT: return %[[V6]], %[[V8]] + return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32> +} + +// CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32> +func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> { + // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding + // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32> + // CHECK-NEXT: %[[S1:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0]] : !shard.sharding + // CHECK-NEXT: %[[V1:.*]] = shard.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32> + // CHECK-NEXT: %[[S2:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding + // CHECK-NEXT: %[[ZP:.*]] = shard.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32> + // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]] + %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32> + // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S0]] : tensor<2x16x32xf32> + %s1 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding + %1 = shard.shard %0 to %s1 : tensor<2x16x32xf32> + // CHECK-NEXT: return %[[V3]] + return %1 : tensor<2x16x32xf32> +} + +// CHECK-LABEL: func.func @matmul_on_def_shard_m_and_n +// CHECK-SAME: [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32> +func.func @matmul_on_def_shard_m_and_n(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> { + // CHECK: [[vsharding:%.*]] = 
shard.sharding @grid_2d split_axes = {{\[\[}}], [0]] : !shard.sharding + // CHECK: [[vsharded:%.*]] = shard.shard [[varg0]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32> + // CHECK: [[vsharding_0:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [], [1]] : !shard.sharding + // CHECK: [[vsharded_1:%.*]] = shard.shard [[varg1]] to [[vsharding_0]] annotate_for_users : tensor<2x8x32xf32> + // CHECK: [[vsharding_2:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding + // CHECK: [[vsharded_3:%.*]] = shard.shard [[varg2]] to [[vsharding_2]] annotate_for_users : tensor<1xf32> + // CHECK: [[v0:%.*]] = tosa.matmul + %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32> + // CHECK: [[vsharding_4:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0], [1]] : !shard.sharding + // CHECK: [[vsharded_5:%.*]] = shard.shard [[v0]] to [[vsharding_4]] : tensor<2x16x32xf32> + %s1 = shard.sharding @grid_2d split_axes = [[], [0], [1]] : !shard.sharding + %1 = shard.shard %0 to %s1 : tensor<2x16x32xf32> + // CHECK-NEXT: return [[vsharded_5]] + return %1 : tensor<2x16x32xf32> +} + +// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k +// CHECK-SAME: [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32> +func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> { + // CHECK: [[vsharding:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0], [1]] : !shard.sharding + %s0 = shard.sharding @grid_2d split_axes = [[], [0], [1]] : !shard.sharding + // CHECK: [[vsharded:%.*]] = shard.shard [[varg0]] to [[vsharding]] : tensor<2x16x8xf32> + %arg0_s = shard.shard %arg0 to %s0 : tensor<2x16x8xf32> + // CHECK: [[vsharded_0:%.*]] = shard.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32> + // CHECK: [[vsharding_1:%.*]] = 
shard.sharding @grid_2d split_axes = {{\[\[}}], [1]] : !shard.sharding + // CHECK: [[vsharded_2:%.*]] = shard.shard [[varg1]] to [[vsharding_1]] annotate_for_users : tensor<2x8x32xf32> + // CHECK: [[vsharding_3:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding + // CHECK: [[vsharded_4:%.*]] = shard.shard [[varg2]] to [[vsharding_3]] annotate_for_users : tensor<1xf32> + // CHECK: [[v0:%.*]] = tosa.matmul + // CHECK: [[vsharding_5:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0]] : !shard.sharding + // CHECK: [[vsharded_6:%.*]] = shard.shard [[v0]] to [[vsharding_5]] : tensor<2x16x32xf32> + %0 = tosa.matmul %arg0_s, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32> + // CHECK: return [[vsharded_6]] + return %0 : tensor<2x16x32xf32> +} + +// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicted_k +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32> +func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> { + // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [1], [0]] : !shard.sharding + // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32> + %s0 = shard.sharding @grid_2d split_axes = [[], [1], [0]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 annotate_for_users : tensor<2x16x8xf32> + // CHECK-NEXT: %[[S1:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0]] : !shard.sharding + // CHECK-NEXT: %[[V1:.*]] = shard.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32> + %s1 = shard.sharding @grid_2d split_axes = [[], [0]] : !shard.sharding + %1 = shard.shard %arg1 to %s1 annotate_for_users : tensor<2x8x32xf32> + // CHECK-NEXT: %[[S2:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding + // CHECK-NEXT: 
%[[ZP:.*]] = shard.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32> + // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]] + %2 = tosa.matmul %0, %1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32> + // CHECK-NEXT: %[[S3:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [1]] : !shard.sharding + // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32> + // CHECK-NEXT: return %[[V3]] + return %2 : tensor<2x16x32xf32> +} + +// CHECK-LABEL: func.func @resolve_conflicting_annotations +func.func @resolve_conflicting_annotations( + // CHECK-SAME: %[[IN1:.*]]: tensor<2x3xf32>, + %arg0: tensor<2x3xf32>, + // CHECK-SAME: %[[IN2:.*]]: tensor<3x2xf32>, + %arg1: tensor<3x2xf32>, + // CHECK-SAME: %[[OUT_DPS:.*]]: tensor<2x2xf32> + %out_dps: tensor<2x2xf32> +// CHECK-SAME: ) -> tensor<2x2xf32> { +) -> tensor<2x2xf32> { + // CHECK: %[[SIN1_SHARDED1:.*]] = shard.sharding @grid_2 split_axes = {{\[\[}}0]] : !shard.sharding + // CHECK-NEXT: %[[IN1_SHARDED1:.*]] = shard.shard %[[IN1]] to %[[SIN1_SHARDED1]] : tensor<2x3xf32> + // CHECK: %[[SIN2_SHARDED:.*]] = shard.sharding @grid_2 split_axes = {{\[\[}}]] : !shard.sharding + // CHECK-NEXT: %[[IN1_SHARDED2:.*]] = shard.shard %[[IN1_SHARDED1]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x3xf32> + // CHECK-NEXT: %[[IN2_SHARDED:.*]] = shard.shard %[[IN2]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<3x2xf32> + // CHECK-NEXT: %[[OUT_DPS_SHARDED:.*]] = shard.shard %[[OUT_DPS]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x2xf32> + %sarg0_sharded = shard.sharding @grid_2 split_axes = [[0]] : !shard.sharding + %arg0_sharded = shard.shard %arg0 to %sarg0_sharded : tensor<2x3xf32> + // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[IN1_SHARDED2]], %[[IN2_SHARDED]] : tensor<2x3xf32>, tensor<3x2xf32>) + // CHECK-SAME: outs(%[[OUT_DPS_SHARDED]] : tensor<2x2xf32>) -> tensor<2x2xf32> + %res = linalg.matmul 
ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>) + outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK-NEXT: %[[RES:.*]] = shard.shard %[[MATMUL]] to %[[SIN2_SHARDED]] : tensor<2x2xf32> + %sres_sharded = shard.sharding @grid_2 split_axes = [[]] : !shard.sharding + %res_sharded = shard.shard %res to %sres_sharded : tensor<2x2xf32> + // CHECK: return %[[RES]] : tensor<2x2xf32> + return %res_sharded : tensor<2x2xf32> +} + +// https://arxiv.org/abs/2211.05102 Figure 2(a) +// The sharding propagation results in unnecessary reshards, +// an optimization pass should be able to remove them. +// CHECK-LABEL: func.func @mlp_1d_weight_stationary +// CHECK-SAME: [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32> +func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> { + %s0 = shard.sharding @grid_1d split_axes = [[], [], [0, 1, 2]] : !shard.sharding + %sharded0 = shard.shard %arg0 to %s0 : tensor<2x4x8xf32> + %sharded1 = shard.shard %arg1 to %s0 : tensor<2x8x32xf32> + // CHECK: [[vsharding:%.*]] = shard.sharding @grid_1d split_axes = {{\[\[}}], [], [0, 1, 2]] : !shard.sharding + // CHECK: [[vsharded:%.*]] = shard.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32> + // CHECK: [[vsharded_0:%.*]] = shard.shard [[varg1]] to [[vsharding]] : tensor<2x8x32xf32> + // CHECK: [[vsharded_1:%.*]] = shard.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32> + // CHECK: [[vsharding_2:%.*]] = shard.sharding @grid_1d split_axes = {{\[\[}}], [0, 1, 2]] : !shard.sharding + // CHECK: [[vsharded_3:%.*]] = shard.shard [[vsharded_0]] to [[vsharding_2]] annotate_for_users : tensor<2x8x32xf32> + // CHECK: [[vsharding_4:%.*]] = shard.sharding @grid_1d split_axes = {{\[\[}}]] : !shard.sharding + // CHECK: [[vsharded_5:%.*]] = shard.shard [[varg3]] to [[vsharding_4]] 
annotate_for_users : tensor<1xf32> + // CHECK: [[v0:%.*]] = tosa.matmul + %1 = tosa.matmul %sharded0, %sharded1, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32> + // CHECK: [[vsharded_6:%.*]] = shard.shard [[v0]] to [[vsharding_4]] : tensor<2x4x32xf32> + // CHECK: [[vsharded_7:%.*]] = shard.shard [[vsharded_6]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32> + // CHECK: [[v1:%.*]] = tosa.sigmoid [[vsharded_7]] : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32> + // CHECK: [[vsharded_8:%.*]] = shard.shard [[v1]] to [[vsharding_4]] : tensor<2x4x32xf32> + %2 = tosa.sigmoid %1 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32> + %sharding = shard.sharding @grid_1d split_axes = [[], [0, 1, 2]] : !shard.sharding + // CHECK: [[vsharded_9:%.*]] = shard.shard [[varg2]] to [[vsharding_2]] : tensor<2x32x8xf32> + %sharded2 = shard.shard %arg2 to %sharding : tensor<2x32x8xf32> + // CHECK: [[vsharded_10:%.*]] = shard.shard [[vsharded_8]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32> + // CHECK: [[vsharded_11:%.*]] = shard.shard [[vsharded_9]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32> + // CHECK: [[v2:%.*]] = tosa.matmul + %3 = tosa.matmul %2, %sharded2, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32> + // CHECK: [[vsharded_12:%.*]] = shard.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32> + %s4 = shard.sharding @grid_1d split_axes = [[], [], [0, 1, 2]] : !shard.sharding + %4 = shard.shard %3 to %s4 : tensor<2x4x8xf32> + // CHECK: return [[vsharded_12]] + return %4 : tensor<2x4x8xf32> +} + +// https://arxiv.org/abs/2211.05102 Figure 2(b) +// The sharding propagation results in unnecessary reshards, +// an optimization pass should be able to remove them. 
+// CHECK-LABEL: func.func @mlp_2d_weight_stationary +// CHECK-SAME: [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32> +func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> { + // CHECK: [[vsharding:%.*]] = shard.sharding @grid_3d split_axes = {{\[\[}}], [], [0, 1, 2]] : !shard.sharding + %s0 = shard.sharding @grid_3d split_axes = [[], [], [0, 1, 2]] : !shard.sharding + // CHECK: [[vsharded:%.*]] = shard.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32> + %arg0_s = shard.shard %arg0 to %s0 : tensor<2x4x8xf32> + // CHECK: [[vsharding_0:%.*]] = shard.sharding @grid_3d split_axes = {{\[\[}}], [0], [1, 2]] : !shard.sharding + %s1 = shard.sharding @grid_3d split_axes = [[], [0], [1, 2]] : !shard.sharding + // CHECK: [[vsharded_1:%.*]] = shard.shard [[varg1]] to [[vsharding_0]] : tensor<2x8x32xf32> + %arg1_s = shard.shard %arg1 to %s1 : tensor<2x8x32xf32> + // CHECK: [[vsharding_2:%.*]] = shard.sharding @grid_3d split_axes = {{\[\[}}]] : !shard.sharding + // CHECK: [[vsharded_3:%.*]] = shard.shard [[vsharded]] to [[vsharding_2]] annotate_for_users : tensor<2x4x8xf32> + // CHECK: [[vsharded_4:%.*]] = shard.shard [[vsharded_1]] to [[vsharding]] annotate_for_users : tensor<2x8x32xf32> + // CHECK: [[vsharded_5:%.*]] = shard.shard [[varg3]] to [[vsharding_2]] annotate_for_users : tensor<1xf32> + // CHECK: [[v0:%.*]] = tosa.matmul + %1 = tosa.matmul %arg0_s, %arg1_s, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32> + // CHECK: [[vsharded_6:%.*]] = shard.shard [[v0]] to [[vsharding]] : tensor<2x4x32xf32> + %2 = shard.shard %1 to %s0 : tensor<2x4x32xf32> + // CHECK: [[vsharded_7:%.*]] = shard.shard [[vsharded_6]] to [[vsharding]] annotate_for_users : tensor<2x4x32xf32> + // CHECK: [[v1:%.*]] = tosa.sigmoid + // CHECK: 
[[vsharded_8:%.*]] = shard.shard [[v1]] to [[vsharding]] : tensor<2x4x32xf32> + %3 = tosa.sigmoid %2 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32> + // CHECK: [[vsharding_9:%.*]] = shard.sharding @grid_3d split_axes = {{\[\[}}], [1, 2], [0]] : !shard.sharding + %s2 = shard.sharding @grid_3d split_axes = [[], [1, 2], [0]] : !shard.sharding + // CHECK: [[vsharded_10:%.*]] = shard.shard [[varg2]] to [[vsharding_9]] : tensor<2x32x8xf32> + %arg2_s = shard.shard %arg2 to %s2 : tensor<2x32x8xf32> + // CHECK: [[vsharded_11:%.*]] = shard.shard [[vsharded_8]] to [[vsharding_2]] annotate_for_users : tensor<2x4x32xf32> + // CHECK: [[vsharded_12:%.*]] = shard.shard [[vsharded_10]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32> + // CHECK: [[v2:%.*]] = tosa.matmul + %4 = tosa.matmul %3, %arg2_s, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32> + // CHECK: [[vsharded_13:%.*]] = shard.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32> + %5 = shard.shard %4 to %s0 : tensor<2x4x8xf32> + // CHECK: [[vsharded_14:%.*]] = shard.shard [[vsharded_13]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32> + %6 = shard.shard %5 to %s0 annotate_for_users : tensor<2x4x8xf32> + // CHECK: return [[vsharded_14]] + return %6 : tensor<2x4x8xf32> +} + +// CHECK-LABEL: func.func @elementwise_duplicated_chain +// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> +func.func @elementwise_duplicated_chain(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding + // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32> + // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]] + %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[V2:.*]] = shard.shard %[[V1]] to %[[S0]] : tensor<8x16xf32> + // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32> + // 
CHECK-NEXT: %[[V4:.*]] = tosa.sigmoid %[[V3]] + %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[V5:.*]] = shard.shard %[[V4]] to %[[S0]] : tensor<8x16xf32> + %s0 = shard.sharding @grid_2d split_axes = [[]] : !shard.sharding + %2 = shard.shard %1 to %s0 : tensor<8x16xf32> + // CHECK-NEXT: return %[[V5]] + return %2 : tensor<8x16xf32> +} diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Shard/simplifications.mlir similarity index 69% rename from mlir/test/Dialect/Mesh/simplifications.mlir rename to mlir/test/Dialect/Shard/simplifications.mlir index e955f4c134259..33cd490be744a 100644 --- a/mlir/test/Dialect/Mesh/simplifications.mlir +++ b/mlir/test/Dialect/Shard/simplifications.mlir @@ -1,7 +1,7 @@ -// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s +// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s -mesh.mesh @mesh0(shape = 4x2) -mesh.mesh @mesh1(shape = 4) +shard.grid @grid0(shape = 4x2) +shard.grid @grid1(shape = 4) // Checks that `all_reduce(x) + all_reduce(y)` gets transformed to // `all_reduce(x + y)`. 
@@ -11,13 +11,13 @@ func.func @all_reduce_arith_addf_endomorphism( %arg0: tensor<5xf32>, // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32> %arg1: tensor<5xf32>) -> tensor<5xf32> { - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf32> - %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] + %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf32> // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]] %2 = arith.addf %0, %1 : tensor<5xf32> - // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] + // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]] // CHECK: return %[[ALL_REDUCE_RES]] return %2 : tensor<5xf32> } @@ -28,13 +28,13 @@ func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_result( %arg0: tensor<5xf32>, // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32> %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) { - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf32> - %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] + %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf32> // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]] %2 = arith.addf %0, %1 : tensor<5xf32> - // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] + // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]] // CHECK: return %[[ALL_REDUCE_RES]], %[[ALL_REDUCE_RES]] return %2, %2 : tensor<5xf32>, tensor<5xf32> } @@ -46,11 +46,11 @@ func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result %arg0: tensor<5xf32>, // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32> %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) { - // CHECK: %[[ALL_REDUCE_0_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] - 
%0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] + // CHECK: %[[ALL_REDUCE_0_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf32> - // CHECK: %[[ALL_REDUCE_1_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] - %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] + // CHECK: %[[ALL_REDUCE_1_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] + %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf32> // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE_0_RES]], %[[ALL_REDUCE_1_RES]] %2 = arith.addf %0, %1 : tensor<5xf32> @@ -58,17 +58,17 @@ func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result return %0, %2 : tensor<5xf32>, tensor<5xf32> } -// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh -func.func @all_reduce_arith_addf_no_endomorphism_different_mesh( +// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_grid +func.func @all_reduce_arith_addf_no_endomorphism_different_grid( // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32> %arg0: tensor<5xf32>, // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32> %arg1: tensor<5xf32>) -> tensor<5xf32> { - // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] + // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0 + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf32> - // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh1 - %1 = mesh.all_reduce %arg1 on @mesh1 mesh_axes = [0] + // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid1 + %1 = shard.all_reduce %arg1 on @grid1 grid_axes = [0] : tensor<5xf32> -> tensor<5xf32> // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]] %2 = arith.addf %0, %1 : 
tensor<5xf32> @@ -76,17 +76,17 @@ func.func @all_reduce_arith_addf_no_endomorphism_different_mesh( return %2 : tensor<5xf32> } -// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes -func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes( +// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_grid_axes +func.func @all_reduce_arith_addf_no_endomorphism_different_grid_axes( // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32> %arg0: tensor<5xf32>, // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32> %arg1: tensor<5xf32>) -> tensor<5xf32> { - // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] + // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0] + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf32> - // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [1] - %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [1] + // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid0 grid_axes = [1] + %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [1] : tensor<5xf32> -> tensor<5xf32> // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]] %2 = arith.addf %0, %1 : tensor<5xf32> @@ -100,11 +100,11 @@ func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind( %arg0: tensor<5xf32>, // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32> %arg1: tensor<5xf32>) -> tensor<5xf32> { - // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] reduction = max - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = max + // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0] reduction = max + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = max : 
tensor<5xf32> -> tensor<5xf32> - // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0] - %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] + // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid0 grid_axes = [0] + %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf32> // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]] %2 = arith.addf %0, %1 : tensor<5xf32> @@ -118,11 +118,11 @@ func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_elemen %arg0: tensor<5xf32>, // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32> %arg1: tensor<5xf32>) -> tensor<5xf64> { - // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] + // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0] + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf64> - // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0] - %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] + // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid0 grid_axes = [0] + %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf64> // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]] %2 = arith.addf %0, %1 : tensor<5xf64> @@ -138,13 +138,13 @@ func.func @all_reduce_arith_minimumf_endomorphism( %arg0: tensor<5xf32>, // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32> %arg1: tensor<5xf32>) -> tensor<5xf32> { - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = min + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = min : tensor<5xf32> -> tensor<5xf32> - %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = min + %1 = shard.all_reduce %arg1 on @grid0 
grid_axes = [0] reduction = min : tensor<5xf32> -> tensor<5xf32> // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minimumf %[[ARG0]], %[[ARG1]] %2 = arith.minimumf %0, %1 : tensor<5xf32> - // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = min + // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]] on @grid0 grid_axes = [0] reduction = min // CHECK: return %[[ALL_REDUCE_RES]] return %2 : tensor<5xf32> } @@ -155,13 +155,13 @@ func.func @all_reduce_arith_minsi_endomorphism( %arg0: tensor<5xi32>, // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xi32> %arg1: tensor<5xi32>) -> tensor<5xi32> { - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = min + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = min : tensor<5xi32> -> tensor<5xi32> - %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = min + %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0] reduction = min : tensor<5xi32> -> tensor<5xi32> // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minsi %[[ARG0]], %[[ARG1]] %2 = arith.minsi %0, %1 : tensor<5xi32> - // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = min + // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]] on @grid0 grid_axes = [0] reduction = min // CHECK: return %[[ALL_REDUCE_RES]] return %2 : tensor<5xi32> } diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir deleted file mode 100644 index 8598d81ff6cfa..0000000000000 --- a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir +++ /dev/null @@ -1,52 +0,0 @@ -// RUN: mlir-opt \ -// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \ -// RUN: %s | FileCheck %s - -mesh.mesh @mesh_1d_4(shape = 4) - -// CHECK-LABEL: func @tensor_empty_static_sharded_dims_offsets -func.func @tensor_empty_static_sharded_dims_offsets() -> () { - 
%b = tensor.empty() : tensor<8x16xf32> - %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding - %sharded= mesh.shard %b to %sharding : tensor<8x16xf32> - // CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding - // CHECK: %[[proc_multi_idx:.*]] = mesh.process_multi_index on @mesh_1d_4 : index - // CHECK: %[[V0:.*]]:2 = mesh.shard_shape dims = [8, 16] sharding = %[[sharding]] device = [%[[proc_multi_idx]] - // CHECK-SAME: ] : index, index - // CHECK: tensor.empty(%[[V0]]#0) : tensor<?x16xf32> - - return -} - -// CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_offsets -// CHECK-SAME: %[[A0:.*]]: index -func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () { - %b = tensor.empty(%arg0) : tensor<8x?xf32> - %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding - %sharded= mesh.shard %b to %sharding : tensor<8x?xf32> - // CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding - // CHECK: %[[proc_multi_idx:.*]] = mesh.process_multi_index on @mesh_1d_4 : index - // CHECK: %[[V0:.*]]:2 = mesh.shard_shape dims = [8, %[[A0]] - // CHECK-SAME: ] sharding = %[[sharding]] device = [%[[proc_multi_idx]] - // CHECK-SAME: ] : index, index - // CHECK: tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32> - - return -} - -// CHECK-LABEL: func @tensor_empty_same_static_dims_sizes -func.func @tensor_empty_same_static_dims_sizes() -> () { - %b = tensor.empty() : tensor<16x16xf32> - %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 4, 8, 12, 16] : !mesh.sharding - %sharded= mesh.shard %b to %sharding : tensor<16x16xf32> - // CHECK-NEXT: tensor.empty() : tensor<4x16xf32> - - return -} - -// CHECK-LABEL: func @tensor_empty_0d -func.func @tensor_empty_0d() -> () { - tensor.empty() : tensor<f32> 
- // CHECK-NEXT: tensor.empty() : tensor<f32> - return -} diff --git a/mlir/test/Dialect/Tensor/shard-partition.mlir b/mlir/test/Dialect/Tensor/shard-partition.mlir new file mode 100644 index 0000000000000..5918ee1eddf57 --- /dev/null +++ b/mlir/test/Dialect/Tensor/shard-partition.mlir @@ -0,0 +1,52 @@ +// RUN: mlir-opt \ +// RUN: --pass-pipeline="builtin.module(func.func(shard-partition,test-single-fold))" \ +// RUN: %s | FileCheck %s + +shard.grid @grid_1d_4(shape = 4) + +// CHECK-LABEL: func @tensor_empty_static_sharded_dims_offsets +func.func @tensor_empty_static_sharded_dims_offsets() -> () { + %b = tensor.empty() : tensor<8x16xf32> + %sharding = shard.sharding @grid_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !shard.sharding + %sharded= shard.shard %b to %sharding : tensor<8x16xf32> + // CHECK: %[[sharding:.*]] = shard.sharding @grid_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !shard.sharding + // CHECK: %[[proc_multi_idx:.*]] = shard.process_multi_index on @grid_1d_4 : index + // CHECK: %[[V0:.*]]:2 = shard.shard_shape dims = [8, 16] sharding = %[[sharding]] device = [%[[proc_multi_idx]] + // CHECK-SAME: ] : index, index + // CHECK: tensor.empty(%[[V0]]#0) : tensor<?x16xf32> + + return +} + +// CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_offsets +// CHECK-SAME: %[[A0:.*]]: index +func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () { + %b = tensor.empty(%arg0) : tensor<8x?xf32> + %sharding = shard.sharding @grid_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !shard.sharding + %sharded= shard.shard %b to %sharding : tensor<8x?xf32> + // CHECK: %[[sharding:.*]] = shard.sharding @grid_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !shard.sharding + // CHECK: %[[proc_multi_idx:.*]] = shard.process_multi_index on @grid_1d_4 : index + // CHECK: %[[V0:.*]]:2 = shard.shard_shape dims = [8, %[[A0]] + // CHECK-SAME: ] sharding = %[[sharding]] device = 
[%[[proc_multi_idx]] + // CHECK-SAME: ] : index, index + // CHECK: tensor.empty(%[[V0]]#0, %[[A0]]) : tensor + + return +} + +// CHECK-LABEL: func @tensor_empty_same_static_dims_sizes +func.func @tensor_empty_same_static_dims_sizes() -> () { + %b = tensor.empty() : tensor<16x16xf32> + %sharding = shard.sharding @grid_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 4, 8, 12, 16] : !shard.sharding + %sharded= shard.shard %b to %sharding : tensor<16x16xf32> + // CHECK-NEXT: tensor.empty() : tensor<4x16xf32> + + return +} + +// CHECK-LABEL: func @tensor_empty_0d +func.func @tensor_empty_0d() -> () { + tensor.empty() : tensor + // CHECK-NEXT: tensor.empty() : tensor + return +} diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt index eb2f74e8aeca1..3b7bd9b9637a8 100644 --- a/mlir/test/lib/Dialect/CMakeLists.txt +++ b/mlir/test/lib/Dialect/CMakeLists.txt @@ -10,7 +10,7 @@ add_subdirectory(Linalg) add_subdirectory(LLVM) add_subdirectory(Math) add_subdirectory(MemRef) -add_subdirectory(Mesh) +add_subdirectory(Shard) add_subdirectory(NVGPU) add_subdirectory(SCF) add_subdirectory(Shape) diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Shard/CMakeLists.txt similarity index 51% rename from mlir/test/lib/Dialect/Mesh/CMakeLists.txt rename to mlir/test/lib/Dialect/Shard/CMakeLists.txt index 7bd0493d11a7e..f91c54721e030 100644 --- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Shard/CMakeLists.txt @@ -1,14 +1,14 @@ # Exclude tests from libMLIR.so -add_mlir_library(MLIRMeshTest +add_mlir_library(MLIRShardTest TestOpLowering.cpp - TestReshardingSpmdization.cpp + TestReshardingPartition.cpp TestSimplifications.cpp EXCLUDE_FROM_LIBMLIR ) -mlir_target_link_libraries(MLIRMeshTest PUBLIC - MLIRMeshDialect - MLIRMeshTransforms +mlir_target_link_libraries(MLIRShardTest PUBLIC + MLIRShardDialect + MLIRShardTransforms MLIRPass MLIRRewrite MLIRTransformUtils diff --git 
a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp b/mlir/test/lib/Dialect/Shard/TestOpLowering.cpp similarity index 80% rename from mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp rename to mlir/test/lib/Dialect/Shard/TestOpLowering.cpp index dbae93b380f2b..43f3b3f239181 100644 --- a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp +++ b/mlir/test/lib/Dialect/Shard/TestOpLowering.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Mesh/Transforms/Transforms.h" +#include "mlir/Dialect/Shard/Transforms/Transforms.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" @@ -24,17 +24,17 @@ struct TestAllSliceOpLoweringPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); SymbolTableCollection symbolTableCollection; - mesh::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection); + shard::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection); LogicalResult status = applyPatternsGreedily(getOperation(), std::move(patterns)); (void)status; assert(succeeded(status) && "applyPatternsGreedily failed."); } void getDependentDialects(DialectRegistry ®istry) const override { - mesh::registerAllSliceOpLoweringDialects(registry); + shard::registerAllSliceOpLoweringDialects(registry); } StringRef getArgument() const final { - return "test-mesh-all-slice-op-lowering"; + return "test-grid-all-slice-op-lowering"; } StringRef getDescription() const final { return "Test lowering of all-slice."; @@ -48,21 +48,21 @@ struct TestMultiIndexOpLoweringPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); SymbolTableCollection symbolTableCollection; - mesh::populateProcessMultiIndexOpLoweringPatterns(patterns, - symbolTableCollection); + shard::populateProcessMultiIndexOpLoweringPatterns(patterns, + symbolTableCollection); LogicalResult status = 
applyPatternsGreedily(getOperation(), std::move(patterns)); (void)status; assert(succeeded(status) && "applyPatternsGreedily failed."); } void getDependentDialects(DialectRegistry ®istry) const override { - mesh::registerProcessMultiIndexOpLoweringDialects(registry); + shard::registerProcessMultiIndexOpLoweringDialects(registry); } StringRef getArgument() const final { - return "test-mesh-process-multi-index-op-lowering"; + return "test-grid-process-multi-index-op-lowering"; } StringRef getDescription() const final { - return "Test lowering of mesh.process_multi_index op."; + return "Test lowering of shard.process_multi_index op."; } }; diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp similarity index 75% rename from mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp rename to mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp index 102e64de4bd1f..ac71ff60fc509 100644 --- a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp +++ b/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp @@ -7,9 +7,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Mesh/IR/MeshOps.h" -#include "mlir/Dialect/Mesh/Transforms/Spmdization.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" +#include "mlir/Dialect/Shard/Transforms/Partition.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypeInterfaces.h" @@ -22,11 +22,11 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; -using namespace mlir::mesh; +using namespace mlir::shard; namespace { -struct TestMeshReshardingRewritePattern : OpRewritePattern { +struct TestReshardingRewritePattern : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ShardOp op, @@ -36,18 +36,18 @@ struct 
TestMeshReshardingRewritePattern : OpRewritePattern { } SymbolTableCollection symbolTable; - mesh::MeshOp mesh = symbolTable.lookupNearestSymbolFrom( - op, cast(op.getSharding().getDefiningOp()).getMeshAttr()); + shard::GridOp grid = symbolTable.lookupNearestSymbolFrom( + op, cast(op.getSharding().getDefiningOp()).getGridAttr()); bool foundUser = false; for (auto user : op->getUsers()) { if (auto targetShardOp = llvm::dyn_cast(user)) { if (targetShardOp.getAnnotateForUsers() && - mesh == symbolTable.lookupNearestSymbolFrom( + grid == symbolTable.lookupNearestSymbolFrom( targetShardOp, cast( targetShardOp.getSharding().getDefiningOp()) - .getMeshAttr())) { + .getGridAttr())) { foundUser = true; break; } @@ -61,22 +61,22 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern { for (auto user : op->getUsers()) { auto targetShardOp = llvm::dyn_cast(user); if (!targetShardOp || !targetShardOp.getAnnotateForUsers() || - symbolTable.lookupNearestSymbolFrom( + symbolTable.lookupNearestSymbolFrom( targetShardOp, cast(targetShardOp.getSharding().getDefiningOp()) - .getMeshAttr()) != mesh) { + .getGridAttr()) != grid) { continue; } ImplicitLocOpBuilder builder(op->getLoc(), rewriter); ShapedType sourceShardShape = - shardShapedType(op.getResult().getType(), mesh, op.getSharding()); + shardShapedType(op.getResult().getType(), grid, op.getSharding()); TypedValue sourceShard = cast>( builder .create(sourceShardShape, op.getSrc()) ->getResult(0)); TypedValue targetShard = - reshard(builder, mesh, op, targetShardOp, sourceShard); + reshard(builder, grid, op, targetShardOp, sourceShard); Value newTargetUnsharded = builder .create( @@ -90,13 +90,13 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern { } }; -struct TestMeshReshardingPass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshReshardingPass) +struct TestReshardingPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReshardingPass) void 
runOnOperation() override { RewritePatternSet patterns(&getContext()); - patterns.insert(&getContext()); + patterns.insert(&getContext()); if (failed(applyPatternsGreedily(getOperation().getOperation(), std::move(patterns)))) { return signalPassFailure(); @@ -107,18 +107,18 @@ struct TestMeshReshardingPass registry.insert(); } StringRef getArgument() const final { - return "test-mesh-resharding-spmdization"; + return "test-grid-resharding-partition"; } StringRef getDescription() const final { - return "Test Mesh dialect resharding spmdization."; + return "Test Shard dialect resharding partition."; } }; } // namespace namespace mlir { namespace test { -void registerTestMeshReshardingSpmdizationPass() { - PassRegistration(); +void registerTestReshardingPartitionPass() { + PassRegistration(); } } // namespace test } // namespace mlir diff --git a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp b/mlir/test/lib/Dialect/Shard/TestSimplifications.cpp similarity index 60% rename from mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp rename to mlir/test/lib/Dialect/Shard/TestSimplifications.cpp index 01e196d29f7a5..28852153f37f6 100644 --- a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp +++ b/mlir/test/lib/Dialect/Shard/TestSimplifications.cpp @@ -7,8 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Mesh/IR/MeshDialect.h" -#include "mlir/Dialect/Mesh/Transforms/Simplifications.h" +#include "mlir/Dialect/Shard/IR/ShardDialect.h" +#include "mlir/Dialect/Shard/Transforms/Simplifications.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -16,23 +16,23 @@ using namespace mlir; namespace { -struct TestMeshSimplificationsPass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshSimplificationsPass) +struct TestShardSimplificationsPass + : public PassWrapper> { + 
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestShardSimplificationsPass) void runOnOperation() override; void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } - StringRef getArgument() const final { return "test-mesh-simplifications"; } - StringRef getDescription() const final { return "Test mesh simplifications"; } + StringRef getArgument() const final { return "test-grid-simplifications"; } + StringRef getDescription() const final { return "Test grid simplifications"; } }; } // namespace -void TestMeshSimplificationsPass::runOnOperation() { +void TestShardSimplificationsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); SymbolTableCollection symbolTableCollection; - mesh::populateSimplificationPatterns(patterns, symbolTableCollection); + shard::populateSimplificationPatterns(patterns, symbolTableCollection); [[maybe_unused]] LogicalResult status = applyPatternsGreedily(getOperation(), std::move(patterns)); assert(succeeded(status) && "Rewrite patters application did not converge."); @@ -40,8 +40,8 @@ void TestMeshSimplificationsPass::runOnOperation() { namespace mlir { namespace test { -void registerTestMeshSimplificationsPass() { - PassRegistration(); +void registerTestShardSimplificationsPass() { + PassRegistration(); } } // namespace test } // namespace mlir diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index 26d7597347a8a..6958fe3001b89 100644 --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -29,7 +29,7 @@ if(MLIR_INCLUDE_TESTS) MLIRTestMathToVCIX MLIRMemRefTestPasses MLIRTestMemRefToLLVMWithTransforms - MLIRMeshTest + MLIRShardTest MLIRNVGPUTestPasses MLIRSCFTestPasses MLIRShapeTestPasses diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 143a5e8e8f8dd..2c0975302e6a5 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -130,8 +130,8 @@ void 
registerTestIrdlTestDialectConversionPass(); void registerTestMemRefDependenceCheck(); void registerTestMemRefStrideCalculation(); void registerTestMemRefToLLVMWithTransforms(); -void registerTestMeshReshardingSpmdizationPass(); -void registerTestMeshSimplificationsPass(); +void registerTestReshardingPartitionPass(); +void registerTestShardSimplificationsPass(); void registerTestMultiBuffering(); void registerTestNextAccessPass(); void registerTestNVGPULowerings(); @@ -276,8 +276,8 @@ void registerTestPasses() { mlir::test::registerTestMemRefDependenceCheck(); mlir::test::registerTestMemRefStrideCalculation(); mlir::test::registerTestMemRefToLLVMWithTransforms(); - mlir::test::registerTestMeshReshardingSpmdizationPass(); - mlir::test::registerTestMeshSimplificationsPass(); + mlir::test::registerTestReshardingPartitionPass(); + mlir::test::registerTestShardSimplificationsPass(); mlir::test::registerTestMultiBuffering(); mlir::test::registerTestNextAccessPass(); mlir::test::registerTestNVGPULowerings(); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 826fb03273e6d..ad5699238b66c 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -3057,14 +3057,14 @@ cc_library( ) ##---------------------------------------------------------------------------## -# Mesh Dialect +# Shard Dialect ##---------------------------------------------------------------------------## td_library( - name = "MeshTdFiles", + name = "ShardTdFiles", srcs = [ - "include/mlir/Dialect/Mesh/IR/MeshBase.td", - "include/mlir/Dialect/Mesh/IR/MeshOps.td", + "include/mlir/Dialect/Shard/IR/ShardBase.td", + "include/mlir/Dialect/Shard/IR/ShardOps.td", ], includes = ["include"], deps = [ @@ -3076,92 +3076,92 @@ td_library( ) gentbl_cc_library( - name = "MeshIncGen", + name = "ShardIncGen", tbl_outs = { - "include/mlir/Dialect/Mesh/IR/MeshOps.h.inc": [ + 
"include/mlir/Dialect/Shard/IR/ShardOps.h.inc": [ "-gen-op-decls", - "-dialect=mesh", + "-dialect=shard", ], - "include/mlir/Dialect/Mesh/IR/MeshOps.cpp.inc": [ + "include/mlir/Dialect/Shard/IR/ShardOps.cpp.inc": [ "-gen-op-defs", - "-dialect=mesh", + "-dialect=shard", ], - "include/mlir/Dialect/Mesh/IR/MeshDialect.h.inc": [ + "include/mlir/Dialect/Shard/IR/ShardDialect.h.inc": [ "-gen-dialect-decls", - "-dialect=mesh", + "-dialect=shard", ], - "include/mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc": [ + "include/mlir/Dialect/Shard/IR/ShardDialect.cpp.inc": [ "-gen-dialect-defs", - "-dialect=mesh", + "-dialect=shard", ], - "include/mlir/Dialect/Mesh/IR/MeshEnums.h.inc": [ + "include/mlir/Dialect/Shard/IR/ShardEnums.h.inc": [ "-gen-enum-decls", - "-dialect=mesh", + "-dialect=shard", ], - "include/mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc": [ + "include/mlir/Dialect/Shard/IR/ShardEnums.cpp.inc": [ "-gen-enum-defs", - "-dialect=mesh", + "-dialect=shard", ], - "include/mlir/Dialect/Mesh/IR/MeshAttributes.h.inc": [ + "include/mlir/Dialect/Shard/IR/ShardAttributes.h.inc": [ "-gen-attrdef-decls", - "-dialect=mesh", + "-dialect=shard", ], - "include/mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc": [ + "include/mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc": [ "-gen-attrdef-defs", - "-dialect=mesh", + "-dialect=shard", ], - "include/mlir/Dialect/Mesh/IR/MeshTypes.h.inc": [ + "include/mlir/Dialect/Shard/IR/ShardTypes.h.inc": [ "-gen-typedef-decls", - "-typedefs-dialect=mesh", + "-typedefs-dialect=shard", ], - "include/mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc": [ + "include/mlir/Dialect/Shard/IR/ShardTypes.cpp.inc": [ "-gen-typedef-defs", - "-typedefs-dialect=mesh", + "-typedefs-dialect=shard", ], }, tblgen = ":mlir-tblgen", - td_file = "include/mlir/Dialect/Mesh/IR/MeshOps.td", + td_file = "include/mlir/Dialect/Shard/IR/ShardOps.td", deps = [ - ":MeshTdFiles", + ":ShardTdFiles", ":ShapeOpsTdFiles", ], ) gentbl_cc_library( - name = "MeshShardingInterfaceIncGen", + name = 
"ShardingInterfaceIncGen", tbl_outs = { - "include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc": ["-gen-op-interface-decls"], - "include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc": ["-gen-op-interface-defs"], + "include/mlir/Dialect/Shard/Interfaces/ShardingInterface.h.inc": ["-gen-op-interface-decls"], + "include/mlir/Dialect/Shard/Interfaces/ShardingInterface.cpp.inc": ["-gen-op-interface-defs"], }, tblgen = ":mlir-tblgen", - td_file = "include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td", + td_file = "include/mlir/Dialect/Shard/Interfaces/ShardingInterface.td", deps = [":OpBaseTdFiles"], ) cc_library( - name = "MeshShardingInterface", - srcs = ["lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp"], + name = "ShardingInterface", + srcs = ["lib/Dialect/Shard/Interfaces/ShardingInterface.cpp"], hdrs = [ - "include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h", - "include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h", + "include/mlir/Dialect/Shard/Interfaces/ShardingInterface.h", + "include/mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h", ], includes = ["include"], deps = [ ":DialectUtils", ":IR", - ":MeshDialect", - ":MeshShardingInterfaceIncGen", + ":ShardDialect", + ":ShardingInterfaceIncGen", ":Support", "//llvm:Support", ], ) cc_library( - name = "MeshDialect", - srcs = ["lib/Dialect/Mesh/IR/MeshOps.cpp"], + name = "ShardDialect", + srcs = ["lib/Dialect/Shard/IR/ShardOps.cpp"], hdrs = [ - "include/mlir/Dialect/Mesh/IR/MeshDialect.h", - "include/mlir/Dialect/Mesh/IR/MeshOps.h", + "include/mlir/Dialect/Shard/IR/ShardDialect.h", + "include/mlir/Dialect/Shard/IR/ShardOps.h", ], includes = ["include"], deps = [ @@ -3172,7 +3172,7 @@ cc_library( ":IR", ":InferTypeOpInterface", ":InliningUtils", - ":MeshIncGen", + ":ShardIncGen", ":SideEffectInterfaces", ":Support", ":ViewLikeInterface", @@ -3181,23 +3181,23 @@ cc_library( ) gentbl_cc_library( - name = "MeshTransformsPassIncGen", - tbl_outs = 
{"include/mlir/Dialect/Mesh/Transforms/Passes.h.inc": [ + name = "ShardTransformsPassIncGen", + tbl_outs = {"include/mlir/Dialect/Shard/Transforms/Passes.h.inc": [ "-gen-pass-decls", - "-name=Mesh", + "-name=Shard", ]}, tblgen = ":mlir-tblgen", - td_file = "include/mlir/Dialect/Mesh/Transforms/Passes.td", + td_file = "include/mlir/Dialect/Shard/Transforms/Passes.td", deps = [":PassBaseTdFiles"], ) cc_library( - name = "MeshTransforms", + name = "ShardTransforms", srcs = glob([ - "lib/Dialect/Mesh/Transforms/*.cpp", - "lib/Dialect/Mesh/Transforms/*.h", + "lib/Dialect/Shard/Transforms/*.cpp", + "lib/Dialect/Shard/Transforms/*.h", ]), - hdrs = glob(["include/mlir/Dialect/Mesh/Transforms/*.h"]), + hdrs = glob(["include/mlir/Dialect/Shard/Transforms/*.h"]), includes = ["include"], deps = [ ":AffineDialect", @@ -3210,9 +3210,9 @@ cc_library( ":FuncDialect", ":FunctionInterfaces", ":IR", - ":MeshDialect", - ":MeshShardingInterface", - ":MeshTransformsPassIncGen", + ":ShardDialect", + ":ShardingInterface", + ":ShardTransformsPassIncGen", ":Pass", ":Support", ":TensorDialect", @@ -3222,11 +3222,11 @@ cc_library( ) cc_library( - name = "MeshToMPIConversion", + name = "ShardToMPIConversion", srcs = glob([ - "lib/Conversion/MeshToMPI/*.cpp", + "lib/Conversion/ShardToMPI/*.cpp", ]), - hdrs = glob(["include/mlir/Conversion/MeshToMPI/*.h"]), + hdrs = glob(["include/mlir/Conversion/ShardToMPI/*.h"]), includes = ["include"], deps = [ ":AffineDialect", @@ -3241,8 +3241,8 @@ cc_library( ":LinalgDialect", ":MPIDialect", ":MemRefDialect", - ":MeshDialect", - ":MeshTransforms", + ":ShardDialect", + ":ShardTransforms", ":Pass", ":SCFDialect", ":Support", @@ -3989,7 +3989,7 @@ cc_library( ":MemRefToEmitC", ":MemRefToLLVM", ":MemRefToSPIRV", - ":MeshToMPIConversion", + ":ShardToMPIConversion", ":NVGPUToNVVM", ":NVVMToLLVM", ":OpenACCToSCF", @@ -4523,7 +4523,7 @@ cc_library( ":FuncDialect", ":IR", ":InliningUtils", - ":MeshShardingInterface", + ":ShardingInterface", ], ) @@ -4622,7 +4622,7 
@@ cc_library( ":MemRefToEmitC", ":MemRefToLLVM", ":MemRefTransformOps", - ":MeshDialect", + ":ShardDialect", ":NVGPUTransformOps", ":NVVMTarget", ":NVVMToLLVM", @@ -7195,7 +7195,7 @@ cc_library( includes = ["include"], deps = [ ":IR", - ":MeshShardingInterface", + ":ShardingInterface", ":TensorDialect", "//llvm:Support", ], @@ -9020,8 +9020,8 @@ cc_library( ":MemRefToSPIRV", ":MemRefTransformOps", ":MemRefTransforms", - ":MeshDialect", - ":MeshTransforms", + ":ShardDialect", + ":ShardTransforms", ":NVGPUDialect", ":NVGPUPassIncGen", ":NVGPUToNVVM", @@ -9121,7 +9121,7 @@ cc_binary( "//mlir/test:TestMath", "//mlir/test:TestMathToVCIX", "//mlir/test:TestMemRef", - "//mlir/test:TestMesh", + "//mlir/test:TestShard", "//mlir/test:TestNVGPU", "//mlir/test:TestPDLL", "//mlir/test:TestPass", @@ -9183,7 +9183,7 @@ cc_binary( "//mlir/test:TestMathToVCIX", "//mlir/test:TestMemRef", "//mlir/test:TestMemRefToLLVMWithTransforms", - "//mlir/test:TestMesh", + "//mlir/test:TestShard", "//mlir/test:TestNVGPU", "//mlir/test:TestPDLL", "//mlir/test:TestPass", @@ -10549,7 +10549,7 @@ cc_library( ":LinalgStructuredOpsIncGen", ":MathDialect", ":MemRefDialect", - ":MeshShardingInterface", + ":ShardingInterface", ":Parser", ":SCFDialect", ":SideEffectInterfaces", @@ -10700,9 +10700,9 @@ cc_library( ":MathDialect", ":MemRefDialect", ":MemRefTransforms", - ":MeshDialect", - ":MeshShardingInterface", - ":MeshTransforms", + ":ShardDialect", + ":ShardingInterface", + ":ShardTransforms", ":Pass", ":RuntimeVerifiableOpInterface", ":SCFDialect", @@ -11199,8 +11199,8 @@ cc_library( ":InferTypeOpInterface", ":InliningUtils", ":LoopLikeInterface", - ":MeshDialect", - ":MeshShardingInterface", + ":ShardDialect", + ":ShardingInterface", ":Pass", ":QuantOps", ":SideEffectInterfaces", @@ -12142,7 +12142,7 @@ cc_library( ":FuncTransforms", ":IR", ":MemRefDialect", - ":MeshShardingInterface", + ":ShardingInterface", ":Pass", ":SideEffectInterfaces", ":TensorDialect", diff --git 
a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel index 95e3ee4df7bc5..e7770fcc9eabd 100644 --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -903,8 +903,8 @@ cc_library( ) cc_library( - name = "TestMesh", - srcs = glob(["lib/Dialect/Mesh/**/*.cpp"]), + name = "TestShard", + srcs = glob(["lib/Dialect/Shard/**/*.cpp"]), includes = ["lib/Dialect/Test"], deps = [ ":TestDialect", @@ -912,8 +912,8 @@ cc_library( "//mlir:DialectUtils", "//mlir:FuncDialect", "//mlir:IR", - "//mlir:MeshDialect", - "//mlir:MeshTransforms", + "//mlir:ShardDialect", + "//mlir:ShardTransforms", "//mlir:Pass", "//mlir:SPIRVDialect", "//mlir:Support",