Skip to content

Commit 676b4bc

Browse files
fschlimbgithub-actions[bot]
authored andcommitted
Automerge: [NFC][mlir][mesh,shard] Fixing misnomers in mesh dialect, renaming 'mesh' dialect to 'shard' (#150177)
Dialect to 'shard' (discourse 87053) - dialect name mesh -> shard - (device) mesh -> (device) grid - spmdize -> partition A lot of diffs, but simple renames only. @tkarna @yaochengji
2 parents 0b16dc8 + b2d4963 commit 676b4bc

File tree

103 files changed

+3617
-3620
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

103 files changed

+3617
-3620
lines changed

mlir/docs/Dialects/Mesh.md

Lines changed: 0 additions & 74 deletions
This file was deleted.

mlir/docs/Dialects/Shard.md

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# 'shard' Dialect
2+
3+
The 'shard' dialect defines a set of attributes, operations, and interfaces for
4+
working with tensor sharding and device communication.
5+
6+
It’s inspired by [GSPMD](*General and Scalable Parallelization for ML Computation Graphs*).
7+
8+
Originally, the dialect was called `mesh`, but it was renamed to better reflect
9+
what it actually does.
10+
11+
[TOC]
12+
13+
## Collective Communication Operations
14+
15+
The 'shard' dialect includes several collective operations that help coordinate
16+
communication between devices arranged in a grid.
17+
18+
If you’re not already familiar with collective operations, [this Wikipedia
19+
article](https://en.wikipedia.org/wiki/Collective_operation) is a good starting
20+
point.
21+
22+
Unlike traditional collectives that are defined in terms of message-passing
23+
between explicit buffers on each process, the collectives in this dialect work
24+
at a higher level. They’re defined in terms of how data moves across the
25+
dimensions of a tensor, and the participating processes are inferred from how
26+
the tensor is sharded - not specified manually.
27+
28+
### Device Groups
29+
30+
Each collective operation runs within a group of devices. You define groups
31+
using the `grid` and `grid_axes` attributes, which describe how to slice the
32+
full device grid into smaller groups.
33+
34+
Devices that have the same coordinates *outside* the listed `grid_axes` belong
35+
to the same group.
36+
37+
Example: Say your device grid is shaped `2×3×4×5`, and you set
38+
`grid_axes = [0, 1]`. This splits the grid into groups by fixing axes 2 and 3. You’d get groups like:
39+
40+
```
41+
{ { (i, j, k, m) | 0 ≤ i < 2, 0 ≤ j < 3 } | 0 ≤ k < 4, 0 ≤ m < 5 }
42+
```
43+
44+
So the groups are identified by the coordinates `(k, m)`, and devices like
45+
`(1, 0, 2, 3)` and `(1, 1, 2, 3)` are in the same group. But `(1, 0, 2, 4)`
46+
is in a different group.
47+
48+
For some collectives (like `all-to-all`), the order of devices in the group
49+
matters. The device order is based on the order of axes in `grid_axes`, from
50+
outermost to innermost.
51+
52+
Example: If `grid_axes = [3, 1]`, then device `(i, 1, k, 0)` comes before
53+
`(i, 0, k, 1)` and `(i, 2, k, 0)`.
54+
55+
### In-group Devices
56+
57+
Some operations (like `broadcast`, `scatter`, and `send`) refer to a specific
58+
device within each group. These in-group devices are identified using their
59+
coordinates over the axes listed in `grid_axes`.
60+
61+
Example: In a 3D grid with `grid_axes = [0, 2]`, an in-group device is specified
62+
as `(i, j)`. If a group is fixed at coordinate `g` on axis 1, then the full
63+
device index would be `(i, g, j)`.
64+
65+
### Purity and Execution Model
66+
67+
Collective operations involve all devices in a group (e.g. `all-gather`,
68+
`all-to-all`) and are considered pure. Operations like `send` and `recv` are not
69+
collective and are not pure.
70+
71+
The execution model assumes SPMD (Single Program, Multiple Data):
72+
73+
* Every process runs the same program.
74+
* At any collective operation, all processes are in sync.
75+
76+
This means compiler optimizations must treat collective ops carefully. For
77+
example, if a collective is removed during optimization, it must be removed from
78+
*every* path and *every* process that would have participated - otherwise, you’ll
79+
get undefined behavior at runtime.
80+
81+
Marking these ops as pure also helps with standard compiler passes like dead
82+
code elimination and common subexpression elimination. It ensures that when the
83+
program is executed, all devices hit the same line of code at the same time
84+
during collectives and so avoid dead-locks.
85+
86+
## Operations
87+
88+
[include "Dialects/ShardOps.md"]
89+
90+
## Attributes
91+
92+
[include "Dialects/ShardAttrs.md"]

mlir/docs/Passes.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,9 @@ This document describes the available MLIR passes and their contracts.
7272

7373
[include "MemRefPasses.md"]
7474

75-
## 'mesh' Dialect Passes
75+
## 'shard' Dialect Passes
7676

77-
[include "MeshPasses.md"]
77+
[include "ShardPasses.md"]
7878

7979
## 'ml\_program' Dialect Passes
8080

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
5353
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
5454
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
55-
#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
5655
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
5756
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
5857
#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
@@ -66,6 +65,7 @@
6665
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h"
6766
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
6867
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
68+
#include "mlir/Conversion/ShardToMPI/ShardToMPI.h"
6969
#include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h"
7070
#include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h"
7171
#include "mlir/Conversion/TosaToArith/TosaToArith.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -903,13 +903,13 @@ def ConvertMemRefToSPIRVPass : Pass<"convert-memref-to-spirv"> {
903903
}
904904

905905
//===----------------------------------------------------------------------===//
906-
// MeshToMPI
906+
// ShardToMPI
907907
//===----------------------------------------------------------------------===//
908908

909-
def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
910-
let summary = "Convert Mesh dialect to MPI dialect.";
909+
def ConvertShardToMPIPass : Pass<"convert-shard-to-mpi"> {
910+
let summary = "Convert Shard dialect to MPI dialect.";
911911
let description = [{
912-
This pass converts communication operations from the Mesh dialect to the
912+
This pass converts communication operations from the Shard dialect to the
913913
MPI dialect.
914914
If it finds the DLTI attribute "MPI:comm_world-rank" on the module it will
915915
use that integer value instead of calling MPI_Comm_rank. This allows
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
1-
//===- MeshToMPI.h - Convert Mesh to MPI dialect ----------------*- C++ -*-===//
1+
//===- ShardToMPI.h - Convert Shard to MPI dialect --------------*- C++ -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#ifndef MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
10-
#define MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
9+
#ifndef MLIR_CONVERSION_SHARDTOMPI_SHARDTOMPI_H
10+
#define MLIR_CONVERSION_SHARDTOMPI_SHARDTOMPI_H
1111

1212
#include "mlir/Pass/Pass.h"
1313
#include "mlir/Support/LLVM.h"
1414

1515
namespace mlir {
1616
class Pass;
1717

18-
#define GEN_PASS_DECL_CONVERTMESHTOMPIPASS
18+
#define GEN_PASS_DECL_CONVERTSHARDTOMPIPASS
1919
#include "mlir/Conversion/Passes.h.inc"
2020

2121
} // namespace mlir
2222

23-
#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
23+
#endif // MLIR_CONVERSION_SHARDTOMPI_SHARDTOMPI_H

mlir/include/mlir/Dialect/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ add_subdirectory(Linalg)
1919
add_subdirectory(LLVMIR)
2020
add_subdirectory(Math)
2121
add_subdirectory(MemRef)
22-
add_subdirectory(Mesh)
22+
add_subdirectory(Shard)
2323
add_subdirectory(MLProgram)
2424
add_subdirectory(MPI)
2525
add_subdirectory(NVGPU)

mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h renamed to mlir/include/mlir/Dialect/Func/Extensions/ShardingExtensions.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- MeshShardingExtensions.h - -----------------------------------------===//
1+
//===- ShardingExtensions.h - -----------------------------------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
1-
//===- MeshShardingInterfaceImpl.h ----------------------------------------===//
1+
//===- ShardingInterfaceImpl.h ----------------------------------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#ifndef MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
10-
#define MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
9+
#ifndef MLIR_DIALECT_LINALG_SHARDSHARDINGINTERFACEIMPL_H
10+
#define MLIR_DIALECT_LINALG_SHARDSHARDINGINTERFACEIMPL_H
1111

1212
namespace mlir {
1313
class DialectRegistry;
1414

1515
namespace linalg {
16-
void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry);
16+
void registerShardingInterfaceExternalModels(DialectRegistry &registry);
1717
} // namespace linalg
1818
} // namespace mlir
1919

20-
#endif // MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
20+
#endif // MLIR_DIALECT_LINALG_SHARDSHARDINGINTERFACEIMPL_H

mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt

Lines changed: 0 additions & 25 deletions
This file was deleted.

0 commit comments

Comments
 (0)