
Commit 83d215b

melissawm authored and copybara-github committed
PR #263: Data flow edge ops documentation
Imported from GitHub PR #263

Here's a follow-up to #175. This PR fixes some broken links and typos/grammar/styling from the first PR, and also adds a section about data flow edge ops to the Propagation page. I didn't feel it would be much use to repeat what is already in the autogenerated docs, so I moved one example from the reference to the explanation page to keep the reference docs succinct.

Copybara import of the project:

-- d6aa27b by Melissa Weber Mendonça <[email protected]>: Data flow edge ops documentation

-- b7fb8e1 by Melissa Weber Mendonça <[email protected]>: Updates to sharding representation doc

-- 796f7ca by Melissa Weber Mendonça <[email protected]>: Fix broken links

-- a2e6f37 by Melissa Weber Mendonça <[email protected]>: Address review comments

-- 5430d58 by Melissa Weber Mendonça <[email protected]>: Add example back to ops.td

-- 6725dbc by Melissa Weber Mendonça <[email protected]>: Add suggestions from review

Merging this change closes #263

COPYBARA_INTEGRATE_REVIEW=#263 from melissawm:data-flow-docs 6725dbc
PiperOrigin-RevId: 707180112
1 parent 1651459 commit 83d215b

File tree

4 files changed (+99, -36 lines)


docs/getting_started_jax.ipynb

Lines changed: 1 addition & 1 deletion
@@ -609,7 +609,7 @@
      "\n",
      "#### What are split Axes in Shardy, aka \"x\":(2)2?\n",
      "\n",
-     "Refer to \"Axis splitting and sub-axes\" in [Axis splitting and sub-axes](https://github.com/openxla/shardy/tree/main/docs/sharding_representation.md#axis-splitting-and-sub-axes)\n"
+     "Refer to [Axis splitting and sub-axes](sharding_representation.md#axis_splitting_and_sub-axes).\n"
     ]
    },
    {

docs/propagation.md

Lines changed: 80 additions & 16 deletions
@@ -38,7 +38,7 @@ We compose multiple conflict resolution strategies in a hierarchy:
     and ignore all others. We also make sure that propagation won't override
     user defined shardings with lower priority (`>i`), even if they are ignored
     during previous iterations.
- 2. **Operation based priorities**. We propagate shardings, based on the
+ 2. **Operation based priorities**. We propagate shardings based on the
     operation type. The "pass-through" operations (e.g., element-wise operations
     and reshape) have the highest priority, while operations with shape
     transformation (e.g., dot and reduce) have lower priority.
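As a rough illustration of this tier of the hierarchy, a priority function along these lines could order ops for propagation (the op names and the fallback tier below are illustrative, not Shardy's actual code):

```c++
// Hypothetical sketch of operation based priorities: "pass-through" ops get
// the highest priority (lowest number), shape-transforming ops a lower one.
#include <string>

int OpPriority(const std::string& op_name) {
  // Pass-through ops, e.g. element-wise ops and reshape: highest priority.
  if (op_name == "add" || op_name == "reshape") return 0;
  // Ops with shape transformation, e.g. dot and reduce: lower priority.
  if (op_name == "dot" || op_name == "reduce") return 1;
  // Everything else (illustrative fallback tier).
  return 2;
}
```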
@@ -61,18 +61,18 @@ user priority, a full op-priority propagation is applied.

  The sharding rule introduces an abstraction of every operation that provides the
  actual propagation algorithm with the information it needs to propagate
- shardings from operands to results or across operands, etc., without having to
- reason about specific operation types and their attributes. This is essentially
+ shardings from operands to results or across operands without having to reason
+ about specific operation types and their attributes. This is essentially
  factoring out the op-specific logic and providing a shared representation (data
  structure) for all ops for the purpose of propagation only. In its simplest
  form, it just provides this function:

- ```c
+ ```c++
  GetOpShardingRule(Operation *) -> OpShardingRuleAttr
  ```

  The rule allows us to write the propagation algorithm only once in a generic way
- that is based on this data structure (OpShardingRule), instead of replicating
+ that is based on this data structure (`OpShardingRule`), instead of replicating
  similar pieces of code across many ops, vastly reducing the possibility for bugs
  or inconsistent behavior across ops.
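To make the shape of this abstraction concrete, here is a deliberately simplified sketch (the structs below are hypothetical and much smaller than the real `OpShardingRuleAttr`): the generic algorithm only ever sees factor sizes, dimension-to-factor mappings, and per-factor shardings, never the op itself.

```c++
// Hypothetical, stripped-down picture of what an op sharding rule carries.
// The real OpShardingRuleAttr is an MLIR attribute; this is only a sketch.
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// For each tensor (operand or result), every dimension lists the factors that
// compose it, ordered from major to minor.
struct TensorFactorMapping {
  std::vector<std::vector<int>> dim_to_factors;
};

struct OpShardingRuleSketch {
  std::vector<int64_t> factor_sizes;        // e.g. {2, 4, 32} for (i, j, k)
  std::vector<TensorFactorMapping> operands;
  std::vector<TensorFactorMapping> results;
};

// FactorSharding: mesh axes assigned to each factor index.
using FactorSharding = std::map<int, std::vector<std::string>>;

// The generic step: copy axes from a factor that is sharded on one side to the
// same factor on the other side, if that factor isn't already constrained.
void PropagateAlongFactors(const FactorSharding& from, FactorSharding& to) {
  for (const auto& [factor, axes] : from) {
    if (!axes.empty() && to.find(factor) == to.end()) to[factor] = axes;
  }
}
```

Supporting a new op then amounts to writing its `GetOpShardingRule`, not new propagation logic.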
@@ -101,11 +101,11 @@ factor. However, it is not enough for reshapes.

  The following reshape merges two dimensions into one:

- ```
+ ```mlir
  %out = mhlo.reshape(%in) : (tensor<2x4x32xf32>) -> tensor<8x32xf32>
  ```

- Here both dimensions 0 and 1 of the input correspond to dimension 0 of the
+ Here, both dimensions 0 and 1 of the input correspond to dimension 0 of the
  output. Say we start by giving factors to the input:

  ```
@@ -121,18 +121,25 @@ need a single dimension to reference multiple factors:

  The same can be done if the reshape were to split a dimension:

+ ```mlir
+ %out = mhlo.reshape(%in) : (tensor<8x32xf32>) -> tensor<2x4x32xf32>
+ ```
+
+ Here,
+
  ```
- %out = mhlo.reshape(%in) : (tensor<8x32xf32>) -> tensor<2x4x32xf32> ((ij), k) -> (i,j,k) : i=2, j=4, k=32
+ ((ij), k) -> (i,j,k) : i=2, j=4, k=32
  ```

  The dimension of size 8 here is essentially composed of the factors 2 and 4,
- which is why we are calling the factors (i,j,k) factors.
+ which is why we are calling the factors `(i,j,k)` factors.

  These factors can also work with cases where there is no full dimension that
  corresponds to one of the factors:

- ```
- %out = mhlo.reshape(%in) : (tensor<8x4xf32>) -> tensor<2x16xf32> ((ij), k) -> (i,(jk)) : i=2, j=4, k=4
+ ```mlir
+ %out = mhlo.reshape(%in) : (tensor<8x4xf32>) -> tensor<2x16xf32>
+ // ((ij), k) -> (i,(jk)) : i=2, j=4, k=4
  ```

  This example also emphasizes why we need to store the factor sizes - since we
@@ -146,16 +153,16 @@ In Shardy, we have the hierarchy of tensors, dimensions, and factors. They
  represent data at different levels. A factor is a sub-dimension. It is an
  internal hierarchy used in sharding propagation. Each dimension may correspond
  to one or more factors. The mapping between dimension and factor is defined by
- OpShardingRule.
+ `OpShardingRule`.

  ![Schema showing the Shardy propagation algorithm.](images/propagation_algorithm.png)

  **Shardy propagates sharding axes along factors instead of dimensions**. To do
- that, we have three steps as shown in the figure below
+ that, we have three steps as shown in the figure below:

- 1. Project DimSharding to FactorSharding
- 2. Propagate sharding axes in the space of FactorSharding
- 3. Project the updated FactorSharding to get the updated DimSharding
+ 1. Project `DimSharding` to `FactorSharding`
+ 2. Propagate sharding axes in the space of `FactorSharding`
+ 3. Project the updated `FactorSharding` to get the updated `DimSharding`

  ![Schema showing sharding propagation across FactorSharding and DimSharding.](images/projected_sharding.png)
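Walking the doc's own reshape example through these three steps helps fix the idea. The following is a hand-worked, happy-path illustration only (real propagation also has to resolve conflicts): for `(tensor<2x4x32xf32>) -> tensor<8x32xf32>` with factors `i=2, j=4, k=32` and an input sharded `[{"x"}, {"y"}, {}]`:

```c++
// Hand-worked illustration (not Shardy code) of the three projection steps.
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Step 1: project DimSharding to FactorSharding. Input dims map one-to-one
  // to factors (i), (j), (k), so the factor shardings are simply:
  std::map<char, std::vector<std::string>> factor_sharding = {
      {'i', {"x"}}, {'j', {"y"}}, {'k', {}}};

  // Step 2: propagate in factor space. The result uses the same factors, so
  // the factor shardings carry over unchanged (no conflicts in this example).

  // Step 3: project FactorSharding back to the result's DimSharding. Result
  // dim 0 is composed of factors (i, j) major to minor, so it collects their
  // axes in that order; result dim 1 is just factor (k).
  std::vector<std::vector<std::string>> result_dims(2);
  for (char f : {'i', 'j'})
    for (const std::string& axis : factor_sharding[f])
      result_dims[0].push_back(axis);
  result_dims[1] = factor_sharding['k'];

  // Prints "dim 0: x y" and an empty dim 1, i.e. result sharding [{"x","y"}, {}].
  std::cout << "dim 0:";
  for (const auto& a : result_dims[0]) std::cout << " " << a;
  std::cout << "\ndim 1:";
  for (const auto& a : result_dims[1]) std::cout << " " << a;
  std::cout << "\n";
  return 0;
}
```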

@@ -210,3 +217,60 @@ along F0, propagate `["c"]` along F1, and propagate nothing along F2.
  T0 | "a", **"b"** | **"c"** | "f" |
  T1 | "a", "b" | "c", "d" | "g" |
  T2 | **"a", "b"** | "c", "e" | |
+
+ ### Data flow ops
+
+ The above propagation step description applies to most ops. However, there are
+ cases where a sharding rule is not appropriate. For those cases, Shardy defines
+ *data flow* ops.
+
+ A data flow edge of some op X defines a bridge between a set of *sources* and a
+ set of *targets*, such that all sources and targets should be sharded in the
+ same way. Examples of such ops are `stablehlo::OptimizationBarrierOp`,
+ `stablehlo::WhileOp`, `stablehlo::CaseOp` and also
+ [`sdy::ManualComputationOp`](./sdy_dialect#sdymanual_computation_sdymanualcomputationop).
+ Ultimately, any op that implements
+ [ShardableDataFlowOpInterface](sdy_op_interfaces#shardabledataflowopinterface_shardabledataflowopinterface)
+ is considered a data flow op.
+
+ An op can have multiple data flow edges that are orthogonal to one another. For
+ example:
+
+ ```mlir
+ y_0, ..., y_n = while (x_0, ..., x_n)
+     ((pred_arg_0,... , pred_arg_n) { ... })
+     ((body_arg_0,..., body_arg_n) {
+       ...
+       return return_value_0, ..., return_value_n
+     })
+ ```
+
+ This while op has `n` data flow edges: the i-th data flow edges is between
+ sources `x_i`, `return_value_i` and targets `y_i`, `pred_arg_i`, `body_arg_i`.
+
+ Shardy will propagate shardings between all sources and targets of a data flow
+ edge as if it was a regular op with the sources as operands and targets as
+ results, and an identity `sdy.op_sharding_rule`. That means that forward
+ propagation is from sources to targets and backwards propagation is from targets
+ to sources.
+
+ Several methods must be implemented by the user describing how to get the
+ sources and targets of each data flow edge through their *owner*, and also how
+ to get and set the shardings of edge *owners*. An owner is a user-specified
+ target of the data flow edge used by Shardy's propagation. The user can choose
+ it arbitrarily but it needs to be static.
+
+ For example, given the `custom_op` defined below:
+
+ ```c
+ y_1, ..., y_n = custom_op (x_1, ..., x_n)
+     ((body_arg_1,..., body_arg_n) {
+       ...
+       return return_value_1, ..., return_value_n
+     })
+ ```
+
+ This custom_op has two types for data flow edges: `n` edges each between
+ `return_value_i` (sources) and `y_i` (targets) and `n` edges between `x_i`
+ (sources) and `body_arg_i` (targets). In this case, the edge owners are the same
+ as the targets.
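As a rough, hypothetical sketch of the per-edge information this implies for the `custom_op` above (the real `ShardableDataFlowOpInterface` methods and types differ; everything below is illustrative stand-in code):

```c++
// Illustrative only: a stand-in for the per-edge data a data flow op exposes.
#include <string>
#include <vector>

struct ValueRef { std::string name; };        // stand-in for an MLIR Value
struct TensorSharding { std::string attr; };  // stand-in for an sdy.sharding

struct DataFlowEdge {
  ValueRef owner;                  // static, user-chosen target of the edge
  std::vector<ValueRef> sources;   // e.g. {return_value_i} or {x_i}
  std::vector<ValueRef> targets;   // e.g. {y_i} or {body_arg_i}
  TensorSharding owner_sharding;   // what propagation reads and updates
};

// Edge family 1 for custom_op: return_value_i -> y_i, owner == target.
DataFlowEdge ResultEdge(int i) {
  std::string idx = std::to_string(i);
  return {{"y_" + idx}, {{"return_value_" + idx}}, {{"y_" + idx}}, {}};
}

// Edge family 2 for custom_op: x_i -> body_arg_i, owner == target.
DataFlowEdge BlockArgEdge(int i) {
  std::string idx = std::to_string(i);
  return {{"body_arg_" + idx}, {{"x_" + idx}}, {{"body_arg_" + idx}}, {}};
}
```

Propagation then treats each edge as if it were an op with the edge's sources as operands, its targets as results, and an identity `sdy.op_sharding_rule`, reading and writing the sharding through the owner.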

docs/sharding_representation.md

Lines changed: 17 additions & 19 deletions
@@ -20,7 +20,7 @@ names and sizes.

  The proposed sharding representation is bound to a specific logical mesh by its
  name, and can only reference axis names from that mesh. The sharding of a tensor
- specifies along which axes (of a specific logical mesh), each dimension of the
+ specifies along which axes (of a specific logical mesh) each dimension of the
  tensor is sharded, ordered from major to minor. The tensor is replicated along
  all other axes of the mesh.
@@ -47,7 +47,7 @@ We can then shard the following rank 2 tensor `[[a, b], [c, d]]` as follows:
     that are not used to shard a dimension are implicitly replicated, but the
     sharding can specify axes that are explicitly replicated and therefore
     cannot be used to shard a dimension later on.
- * [**Axis splitting and sub-axes**](#axis-splitting-and-sub-axes) - a (full)
+ * [**Axis splitting and sub-axes**](#axis_splitting_and_sub-axes) - a (full)
     mesh axis can be split into multiple sub-axes that can be individually used
     to shard a dimension or be explicitly replicated.
  * [**Multiple logical meshes**](#multiple-logical-meshes) - different
@@ -68,7 +68,7 @@ We expand the basic structure and each key component in this section.
  ### Basic structure

  The dimension shardings tell us for each dimension of the tensor, along which
- axes (or [sub-axes](#axis-splitting-and-sub-axes)) it is sharded from major to
+ axes (or [sub-axes](#axis_splitting_and_sub-axes)) it is sharded from major to
  minor. All other axes that don't shard a dimension are implicitly replicated (or
  [explicitly replicated](#explicitly-replicated-axes)).
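As a small, hand-worked illustration of what a dimension sharding means for the per-device (local) shape — an assumption-laden sketch using a made-up mesh `["x"=2, "y"=2]`, not anything from the diff — each dimension's local size is its global size divided by the product of the sizes of the axes that shard it:

```c++
// Illustrative only: local (per-device) shape of tensor<4x8xf32> sharded as
// [{"x"}, {"y"}] on a hypothetical mesh ["x"=2, "y"=2] -> tensor<2x4xf32>.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  std::map<std::string, int64_t> mesh = {{"x", 2}, {"y", 2}};
  std::vector<int64_t> global_shape = {4, 8};
  // Axes sharding each dimension, ordered major to minor.
  std::vector<std::vector<std::string>> dim_shardings = {{"x"}, {"y"}};

  std::vector<int64_t> local_shape;
  for (size_t d = 0; d < global_shape.size(); ++d) {
    int64_t devices = 1;
    for (const std::string& axis : dim_shardings[d]) devices *= mesh[axis];
    local_shape.push_back(global_shape[d] / devices);
  }
  for (int64_t s : local_shape) std::cout << s << " ";  // prints: 2 4
  std::cout << "\n";
  return 0;
}
```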

@@ -100,9 +100,7 @@ Each dimension of a tensor can either be open or closed.
  An open dimension is open for propagation to further shard it along additional
  axes, i.e. the specified dimension sharding doesn't have to be the final
  sharding of that dimension. This is similar (but not exactly the same as) to
-
- * [`jax.sharding.PartitionSpec.UNCONSTRAINED`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.PartitionSpec)
- * GSPMD's `unspecified_dims`
+ GSPMD's `unspecified_dims`.

  If a dimension is open we add a `?` following the axes that the dimension is
  already sharded on (see example below).
@@ -161,7 +159,7 @@ We can extend our example from above to have an explicitly replicated axis.
  // Since "y" is explicitly replicated, it can't be used to shard the 2nd
  // dimension that is open. However, "z" is implicitly replicated so it can be
  // used to shard that dimension. The local shape of this tensor (i.e. the
- // shape on a single device), would // be tensor<2x8xf32>.
+ // shape on a single device), would be tensor<2x8xf32>.
  sharding<@mesh_xyz, [{"x"}, {?}], replicated={"y"}> : tensor<4x8xf32>
  ```
@@ -213,13 +211,14 @@ We have a few options for dealing with such cases:
  * Disallow, and all-gather sub-axes that shard the input/output.

  Currently we allow sub-axes on the inputs/outputs in the propagation pipeline.
- Let us know if you want a way to disable this.
+ [Let us know](https://github.com/openxla/shardy/issues) if you want a way to
+ disable this.

  #### Representation

  In the same way that we can reference specific full axes from the mesh by their
  name, we can reference specific sub-axes by their size and the product of all
- sub-axis (of the same axis name) sizes to their left (that are major to them) .
+ sub-axis (of the same axis name) sizes to their left (that are major to them).

  To extract a specific sub-axis of size `k` from a full axis `"x"` of size `n`,
  we effectively reshape the size `n` (in the mesh) into `[m, k, n/(m*k)]` and use
@@ -301,7 +300,7 @@ sharding<@mesh_xyz, [{"x"}, {"y":(2)2}], replicated={"y":(1)2}> : tensor<4x8xf32
  Replicated sub-axis of the same full axis should be ordered in increasing order
  by their pre-size, for example:

- ```c
+ ```c++
  replicated={"y":(4)2, "x", "y":(1)2} ~> replicated={"x", "y":(1)2, "y":(4)2}
  ```
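A small illustrative sketch of the `"x":(m)k` notation and this ordering rule (the `SubAxis` struct and the axis size of 8 are ad hoc assumptions, not Shardy's representation):

```c++
// Illustrative only: a sub-axis of a full axis of size n is the middle piece
// of a conceptual reshape of n into [pre_size, size, n / (pre_size * size)].
// Replicated sub-axes of one full axis are listed by increasing pre-size.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

struct SubAxis {
  int64_t pre_size;  // product of the sizes of sub-axes that are major to it
  int64_t size;
};

int main() {
  const int64_t n = 8;  // assume a full axis "y" of size 8
  std::vector<SubAxis> replicated = {{4, 2}, {1, 2}};  // "y":(4)2, "y":(1)2

  // Each sub-axis must tile the full axis.
  for (const SubAxis& s : replicated) assert(n % (s.pre_size * s.size) == 0);

  // Canonical order: increasing pre-size, i.e. "y":(1)2 before "y":(4)2.
  std::sort(replicated.begin(), replicated.end(),
            [](const SubAxis& a, const SubAxis& b) {
              return a.pre_size < b.pre_size;
            });
  for (const SubAxis& s : replicated)
    std::cout << "\"y\":(" << s.pre_size << ")" << s.size << " ";
  std::cout << "\n";  // prints: "y":(1)2 "y":(4)2
  return 0;
}
```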
@@ -342,11 +341,9 @@ assigned to a different mesh, by naively resharding the tensor to match the
  destination mesh. In GSPMD this is what is usually done to resolve conflicting
  meshes.

- We provide two examples below:
-
  Users can specify multiple meshes with different named axes (e.g. via
- `jax.sharding.NamedSharding`), that have the same order of devices. In this
- example, `<@mesh_0, "b">` is identical to `<@mesh_1, "z">.`
+ `jax.sharding.NamedSharding`), that have the same order of devices. Consider
+ this example, `<@mesh_0, "b">` is identical to `<@mesh_1, "z">`:

  ```c++
  @mesh_0 = {<["a"=4, "b"=2]>, device_ids=[0, 1, 2, 3, 4, 5, 6, 7]}
@@ -358,8 +355,8 @@ moment (different being different axis names/sizes and `device_ids`).

  ### Priorities

- Priority is a way to prioritize certain partitioning+propagation decisions over
- others, and allows for incremental partitioning of a program.
+ Priority is a way to prioritize certain partitioning and propagation decisions
+ over others, and allows for incremental partitioning of a program.

  Priorities are values attached to some or all dimensions of a sharding
  representation (replicated axes don't have priorities).
@@ -374,9 +371,10 @@ For example:
  ```

  Priorities give users more fine grained control over propagation, e.g., batch
- parallelism first, then megatron, and finally ZeRO sharding. This allows for
- strong guarantees about what's partitioned and allows for better debuggability
- by having more fine grained sharding strategies (can see how the program looks
+ parallelism first, then [megatron](arxiv.org/abs/1909.08053), and finally
+ [ZeRO](https://arxiv.org/abs/1910.02054) sharding. This allows for strong
+ guarantees about what's partitioned and allows for better debuggability by
+ having more fine grained sharding strategies (can see how the program looks
  after just megatron in isolation).

  We allow attaching a priority to each dimension sharding (0 by default), which

shardy/dialect/sdy/ir/ops.td

Lines changed: 1 addition & 0 deletions
@@ -268,6 +268,7 @@ def Sdy_DataFlowEdgeOp : Sdy_Op<"data_flow_edge",

  For example:

+
  ```mlir
  y_0, ..., y_n = while (x_0, ..., x_n)
      ((pred_arg_0,... , pred_arg_n) { ... })

0 commit comments
