1 | 1 | # 'shard' Dialect |
2 | 2 |
3 | | -This dialect contains a set of attributes, operations and interfaces that |
4 | | -are useful for representing sharding of tensors and communication between |
5 | | -devices. |
| 3 | +The 'shard' dialect defines a set of attributes, operations, and interfaces for |
| 4 | +working with tensor sharding and device communication. |
6 | 5 |
7 | | -The Shard dialect was inspired by GSPMD (GSPMD: General and Scalable |
8 | | -Parallelization for ML Computation Graphs). |
| 6 | +It’s inspired by GSPMD (*General and Scalable Parallelization for ML Computation Graphs*).
9 | 7 |
10 | | -It was originally introduced under the name 'mesh' but was later renamed |
11 | | -to better reflect its purpose. |
| 8 | +Originally, the dialect was called `mesh`, but it was renamed to better reflect |
| 9 | +what it actually does. |
12 | 10 |
13 | 11 | [TOC] |
14 | 12 |
15 | 13 | ## Collective Communication Operations |
16 | | -There are a number of operations in the Shard dialect to facilitate |
17 | | -communication between devices in a grid. |
18 | | -It is assumed that the user is familiar with collective operations. |
19 | | -[Wikipedia](https://en.wikipedia.org/wiki/Collective_operation) has a good |
20 | | -explanation. |
21 | | -The main addition is that the collectives in this dialect have grid |
22 | | -semantics. |
23 | | - |
24 | | -### Device groups |
25 | | -The operation attributes `grid` and `grid_axes` specifies a list of device grid |
26 | | -axes that partition the devices into disjoint groups. |
27 | | -The collective operation is performed between devices in the same group. |
28 | | -Devices that have the same coordinates outside of axes `grid_axes` are in the |
29 | | -same group. |
30 | | -A group is described by its multi-index along the axes outside of `grid_axes`. |
31 | | -For example if we have a device grid of size `2x3x4x5` and the partition grid |
32 | | -axes list is `[0, 1]` then devices are partitioned into the groups |
33 | | -`{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`. |
34 | | -The device groups would be `{ (k, m) | 0<=k<4, 0<=m<5 }`. |
35 | | -Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group. |
36 | | -Device (1, 0, 2, 4) will be in another group. |
37 | | -Some collective operations like all-to-all and all-gather care about the |
38 | | -order of devices. |
39 | | -The order of device in a device group is induced by the order of axes in |
40 | | -`grid_axes`. |
41 | | -The axes are ordered from outer to inner. |
42 | | -If we have an axis list `[3, 1]` then device `(i, 1, k, 0)` will precede |
43 | | -both devices `(i, 0, k, 1)` and `(i, 2, k, 0)`. |
44 | | - |
45 | | -### In-group Device |
46 | | -Some operations like `broadcast`, `scatter` and `send` specify devices in each |
47 | | -device-group. |
48 | | -These devices are represented with their multi-index over the grid axes that |
49 | | -are not constant within a device group. |
50 | | -These are the axes specified by `grid_axes` attribute. |
51 | | - |
52 | | -For Example on a 3D grid an operation with `grid_axes = [0, 2]` would specify |
53 | | -an in-group device with `(i, j)`. Then for each group with index `g` on the |
54 | | -second axis, the in-group device would be `(i, g, j)`. |
55 | | -### Purity |
56 | | -Collectives that involve the whole device group to perform a single operation |
57 | | -are pure. The exceptions are `send` and `recv`. |
58 | | - |
59 | | -There is an assumption that the execution is SPMD. |
60 | | -Not only that each process runs the same program, but that at the point of |
61 | | -execution of a collective operation, all processes are in a coherent state. |
62 | | -All compiler transformations must be consistent. |
63 | | -Collective operations in the IR that may correspond to the same runtime |
64 | | -collective operation must be transformed in a consistent manner. |
65 | | -For example if a collective operation is optimized out, than it must also |
66 | | -not appear in any path of execution on any process. |
67 | | - |
68 | | -Having the operations as `Pure` implies that if an interpreter is to execute |
69 | | -the IR containing the `grid` collectives, all processes would execute the same |
70 | | -line when they reach a pure collective operation. |
71 | | -This requirement stems from the need to be compatible with general optimization |
72 | | -passes like dead code and common sub-expression elimination. |
| 14 | + |
| 15 | +The 'shard' dialect includes several collective operations that help coordinate |
| 16 | +communication between devices arranged in a grid. |
| 17 | + |
| 18 | +If you’re not already familiar with collective operations, [this Wikipedia |
| 19 | +article](https://en.wikipedia.org/wiki/Collective_operation) is a good starting |
| 20 | +point. |
| 21 | + |
| 22 | +Unlike traditional collectives that are defined in terms of message-passing |
| 23 | +between explicit buffers on each process, the collectives in this dialect work |
| 24 | +at a higher level. They’re defined in terms of how data moves across the |
| 25 | +dimensions of a tensor, and the participating processes are inferred from how |
| 26 | +the tensor is sharded - not specified manually. |
| 27 | + |
| 28 | +### Device Groups |
| 29 | + |
| 30 | +Each collective operation runs within a group of devices. You define groups |
| 31 | +using the `grid` and `grid_axes` attributes, which describe how to slice the |
| 32 | +full device grid into smaller groups. |
| 33 | + |
| 34 | +Devices that have the same coordinates *outside* the listed `grid_axes` belong |
| 35 | +to the same group. |
| 36 | + |
| 37 | +Example: Say your device grid is shaped `2×3×4×5`, and you set |
| 38 | +`grid_axes = [0, 1]`. This splits the grid into groups by fixing axes 2 and 3. You’d get groups like: |
| 39 | + |
| 40 | +``` |
| 41 | +{ { (i, j, k, m) | 0 ≤ i < 2, 0 ≤ j < 3 } | 0 ≤ k < 4, 0 ≤ m < 5 } |
| 42 | +``` |
| 43 | + |
| 44 | +So the groups are identified by the coordinates `(k, m)`, and devices like |
| 45 | +`(1, 0, 2, 3)` and `(1, 1, 2, 3)` are in the same group. But `(1, 0, 2, 4)` |
| 46 | +is in a different group. |
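
To make the grouping rule above concrete, here is a minimal Python sketch (illustration only, not dialect code; the `device_groups` helper and its signature are invented for this example). It partitions the `2×3×4×5` grid with `grid_axes = [0, 1]` and reproduces the groups described above:

```python
from itertools import product

def device_groups(grid_shape, grid_axes):
    """Group devices of a grid: devices that agree on every axis
    *not* in grid_axes end up in the same group (sketch only)."""
    group_axes = [a for a in range(len(grid_shape)) if a not in grid_axes]
    groups = {}
    for device in product(*(range(s) for s in grid_shape)):
        key = tuple(device[a] for a in group_axes)   # the (k, m) above
        groups.setdefault(key, []).append(device)
    return groups

groups = device_groups((2, 3, 4, 5), grid_axes=[0, 1])
assert (1, 0, 2, 3) in groups[(2, 3)] and (1, 1, 2, 3) in groups[(2, 3)]
assert (1, 0, 2, 4) in groups[(2, 4)]   # a different group
```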
| 47 | + |
| 48 | +For some collectives (like `all-to-all`), the order of devices in the group |
| 49 | +matters. The device order is based on the order of axes in `grid_axes`, from |
| 50 | +outermost to innermost. |
| 51 | + |
| 52 | +Example: If `grid_axes = [3, 1]`, then device `(i, 1, k, 0)` comes before |
| 53 | +`(i, 0, k, 1)` and `(i, 2, k, 0)`. |
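
The ordering rule can be sketched the same way (again an illustration, not part of the dialect; `in_group_order` is a made-up helper): in-group coordinates are enumerated lexicographically over `grid_axes`, outermost listed axis first. Assuming the `2×3×4×5` grid from the earlier example, this reproduces the `[3, 1]` ordering above:

```python
from itertools import product

def in_group_order(grid_shape, grid_axes):
    """Enumerate in-group coordinates; the first axis in grid_axes is
    the outermost (most significant) one (sketch only)."""
    return list(product(*(range(grid_shape[a]) for a in grid_axes)))

# On a 2x3x4x5 grid with grid_axes = [3, 1], a device (i, j, k, m) has
# in-group coordinate (m, j); (0, 1) orders before (1, 0) and (0, 2).
order = in_group_order((2, 3, 4, 5), grid_axes=[3, 1])
assert order.index((0, 1)) < order.index((1, 0))
assert order.index((0, 1)) < order.index((0, 2))
```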
| 54 | + |
| 55 | +### In-group Devices |
| 56 | + |
| 57 | +Some operations (like `broadcast`, `scatter`, and `send`) refer to a specific |
| 58 | +device within each group. These in-group devices are identified using their |
| 59 | +coordinates over the axes listed in `grid_axes`. |
| 60 | + |
| 61 | +Example: In a 3D grid with `grid_axes = [0, 2]`, an in-group device is specified |
| 62 | +as `(i, j)`. If a group is fixed at coordinate `g` on axis 1, then the full |
| 63 | +device index would be `(i, g, j)`. |
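
For a sketch of how the two index spaces combine (the `full_device_index` helper below is hypothetical, not a dialect API): the in-group coordinates fill the positions named by `grid_axes`, and the group coordinates fill the remaining axes:

```python
def full_device_index(in_group_index, group_index, grid_axes, rank):
    """Merge an in-group index (over grid_axes) with a group index
    (over the remaining axes) into a full device index (sketch only)."""
    group_axes = [a for a in range(rank) if a not in grid_axes]
    full = [0] * rank
    for axis, coord in zip(grid_axes, in_group_index):
        full[axis] = coord
    for axis, coord in zip(group_axes, group_index):
        full[axis] = coord
    return tuple(full)

# 3D grid, grid_axes = [0, 2]: the in-group device (i, j) of the group
# at coordinate g on axis 1 is the full device (i, g, j).
i, g, j = 1, 0, 2
assert full_device_index((i, j), (g,), grid_axes=[0, 2], rank=3) == (i, g, j)
```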
| 64 | + |
| 65 | +### Purity and Execution Model |
| 66 | + |
| 67 | +Collective operations involve all devices in a group (e.g. `all-gather`, |
| 68 | +`all-to-all`) and are considered pure. Operations like `send` and `recv` are not |
| 69 | +collective and are not pure. |
| 70 | + |
| 71 | +The execution model assumes SPMD (Single Program, Multiple Data): |
| 72 | + |
| 73 | +* Every process runs the same program. |
| 74 | +* At any collective operation, all processes are in sync. |
| 75 | + |
| 76 | +This means compiler optimizations must treat collective ops carefully. For |
| 77 | +example, if a collective is removed during optimization, it must be removed from |
| 78 | +*every* path and *every* process that would have participated - otherwise, you’ll |
| 79 | +get undefined behavior at runtime. |
| 80 | + |
| 81 | +Marking these ops as pure also helps with standard compiler passes like dead
| 82 | +code elimination and common subexpression elimination. It ensures that when the
| 83 | +program is executed, every process reaches the same pure collective at the same
| 84 | +point in the program, which avoids deadlocks.
73 | 85 |
74 | 86 | ## Operations |
75 | 87 |