
Commit f038721

docs: new sharding docs (#1370)
* docs: new sharding docs
* docs: more docs
* docs: more docs
* docs: more docs
* fix: code
* ci: fix env
* fix: cast some ops to float
* docs: cleanup
1 parent d91b736 commit f038721

File tree

6 files changed: +347 -23 lines changed


.github/workflows/Documenter.yaml

Lines changed: 15 additions & 22 deletions
@@ -3,19 +3,19 @@ name: Documentation
 on:
   pull_request:
     paths:
-      - '.github/workflows/Documenter.yaml'
-      - 'docs/**'
-      - 'lib/**'
-      - 'src/**'
+      - ".github/workflows/Documenter.yaml"
+      - "docs/**"
+      - "lib/**"
+      - "src/**"
   push:
     branches:
       - main
-    tags: '*'
+    tags: "*"
     paths:
-      - '.github/workflows/Documenter.yaml'
-      - 'docs/**'
-      - 'lib/**'
-      - 'src/**'
+      - ".github/workflows/Documenter.yaml"
+      - "docs/**"
+      - "lib/**"
+      - "src/**"

 concurrency:
   # Same group concurrency as the `PreviewCleanup.yml` workflow, because they both
@@ -37,24 +37,17 @@ jobs:
       - uses: actions/checkout@v4
       - uses: julia-actions/setup-julia@v2
         with:
-          version: '1'
+          version: "1"
       - uses: julia-actions/cache@v2
       - name: Instantiate docs environment
+        shell: julia --color=yes --project=docs {0}
         run: |
-          julia --color=yes --project=docs -e '
-            using Pkg
-            Pkg.instantiate()'
-        env:
-          JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
-      - name: Run doctests
-        run: |
-          julia --color=yes --project=docs -e '
-            using Documenter: DocMeta, doctest
-            using Reactant
-            DocMeta.setdocmeta!(Reactant, :DocTestSetup, :(using Reactant); recursive=true)
-            doctest(Reactant)'
+          using Pkg
+          Pkg.develop(PackageSpec(path=pwd()))
+          Pkg.instantiate()
       - name: Build documentation
         run: julia --color=yes --project=docs docs/make.jl
         env:
+          XLA_FLAGS: --xla_force_host_platform_device_count=8
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
           DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}

docs/Project.toml

Lines changed: 3 additions & 0 deletions
@@ -1,8 +1,11 @@
 [deps]
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterVitepress = "4710194d-e776-4893-9690-8d956a29c365"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
+Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [sources]

docs/src/.vitepress/config.mts

Lines changed: 2 additions & 0 deletions
@@ -88,6 +88,7 @@ export default defineConfig({
         {text: "Distributed", link: "/tutorials/multihost"},
         {text: "Local build", link: "/tutorials/local-build"},
         {text: "Control Flow", link: "/tutorials/control-flow"},
+        {text: "Sharding", link: "/tutorials/sharding"},
       ],
     },
     {
@@ -158,6 +159,7 @@ export default defineConfig({
       { text: "Distributed", link: "/tutorials/multihost" },
      { text: "Local build", link: "/tutorials/local-build" },
      { text: "Control Flow", link: "/tutorials/control-flow" },
+      { text: "Sharding", link: "/tutorials/sharding" },
     ],
   }
 ],

docs/src/tutorials/sharding.md

Lines changed: 304 additions & 0 deletions
# [Automatic Sharding-based Distributed Parallelism](@id sharding)

!!! tip "Use XLA IFRT Runtime"

    While PJRT does support some minimal sharding capabilities on CUDA GPUs, sharding
    support in Reactant is primarily provided via IFRT. Before loading Reactant, set the
    "xla_runtime" preference to "IFRT". This can be done with:

    ```julia
    using Preferences, UUIDs

    Preferences.set_preferences!(
        UUID("3c362404-f566-11ee-1572-e11a4b42c853"),
        "xla_runtime" => "IFRT"
    )
    ```

## Basics

Sharding is one mechanism supported within Reactant that tries to make it easy to program
for multiple devices (including [multiple nodes](@ref distributed)).

```@example sharding_tutorial
using Reactant

@assert length(Reactant.devices()) > 1 # hide
Reactant.devices()
```

Sharding provides Reactant users a
[PGAS (partitioned global address space)](https://en.wikipedia.org/wiki/Partitioned_global_address_space)
programming model. Let's understand what this means through an example.

Suppose we have a function that takes a large input array and computes `sin` for all
elements of the array.

```@example sharding_tutorial
function big_sin(data)
    data .= sin.(data)
    return nothing
end

N = 1600
x = Reactant.to_rarray(reshape(collect(Float32, 1:N), 40, 40))

compiled_big_sin = @compile big_sin(x)

compiled_big_sin(x)
```

This successfully allocates the array `x` on one device and executes the computation on
that same device. However, suppose we want to execute this computation on multiple devices.
Perhaps this is because the size of our inputs (`N`) is too large to fit on a single
device. Or alternatively the function we execute is computationally expensive and we want
to leverage the computing power of multiple devices.

Unlike more explicit communication libraries like MPI, the sharding model used by Reactant
aims to let you execute a program on multiple devices without significant modifications to
the single-device program. In particular, you do not need to write explicit communication
calls (e.g. `MPI.Send` or `MPI.Recv`). Instead you write your program as if it executes on
a very large single node and Reactant will automatically determine how to subdivide the
data, the computation, and the required communication.

When using sharding, the one thing you need to change about your code is how arrays are
allocated. In particular, you need to specify how the array is partitioned amongst the
available devices. For example, suppose you are on a machine with 4 GPUs. In the example
above, we computed `sin` for all elements of a 40x40 grid. One partitioning we could select
is to split the array along the first axis, such that each GPU holds a 10x40 slice. We can
accomplish this as follows. No change is required to the original function. However, the
compiled function is specific to the sharding, so we need to compile a new version for our
sharded array.

```@example sharding_tutorial
N = 1600

x_sharded_first = Reactant.to_rarray(
    reshape(collect(Float32, 1:N), 40, 40),
    sharding=Sharding.NamedSharding(
        Sharding.Mesh(reshape(Reactant.devices()[1:4], 4, 1), (:x, :y)),
        (:x, nothing)
    )
)

compiled_big_sin_sharded_first = @compile big_sin(x_sharded_first)

compiled_big_sin_sharded_first(x_sharded_first)
```

Alternatively, we can partition the data in a different way. In particular, we could
subdivide the data on both axes, so that each GPU holds a 20x20 slice. Again no change is
required to the original function, but we would change the allocation as follows:

```@example sharding_tutorial
N = 1600
x_sharded_both = Reactant.to_rarray(
    reshape(collect(Float32, 1:N), 40, 40),
    sharding=Sharding.NamedSharding(
        Sharding.Mesh(reshape(Reactant.devices()[1:4], 2, 2), (:x, :y)),
        (:x, :y)
    )
)

compiled_big_sin_sharded_both = @compile big_sin(x_sharded_both)

compiled_big_sin_sharded_both(x_sharded_both)
```

Sharding in Reactant requires you to specify how the data is sharded across devices on a
mesh. We start by specifying the mesh, [`Sharding.Mesh`](@ref), which is a collection of
the devices reshaped into an N-D grid. Additionally, we can specify names for each axis of
the mesh that are then referenced when specifying how the data is sharded.

1. `Sharding.Mesh(reshape(Reactant.devices()[1:4], 2, 2), (:x, :y))`: Creates a 2D grid of
   4 devices arranged in a 2x2 grid. The first axis is named `:x` and the second axis is
   named `:y`.
2. `Sharding.Mesh(reshape(Reactant.devices()[1:4], 4, 1), (:x, :y))`: Creates a 2D grid of
   4 devices arranged in a 4x1 grid. The first axis is named `:x` and the second axis is
   named `:y`.

Given the mesh, we then specify how each axis of the data maps onto the mesh axes, as in
the sketch below.
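
As a concrete sketch (using only the constructs already shown above, and assuming at least
4 visible devices), the same 40x40 array can be laid out in two different ways over a
single 2x2 mesh just by changing the axis names passed to `Sharding.NamedSharding`:

```julia
mesh = Sharding.Mesh(reshape(Reactant.devices()[1:4], 2, 2), (:x, :y))

# Rows split across `:x`, columns not partitioned (so they are replicated along `:y`):
# each device holds a 20x40 slice of the 40x40 array.
rows_only = Reactant.to_rarray(
    reshape(collect(Float32, 1:1600), 40, 40);
    sharding=Sharding.NamedSharding(mesh, (:x, nothing)),
)

# Rows split across `:x` and columns split across `:y`:
# each device holds a 20x20 slice.
rows_and_cols = Reactant.to_rarray(
    reshape(collect(Float32, 1:1600), 40, 40);
    sharding=Sharding.NamedSharding(mesh, (:x, :y)),
)
```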

<!--
TODO: describe how arrays are "global data arrays", even though the data itself is only
stored on the relevant devices and the computation is performed only on the devices holding
the required data (effectively showing under the hood how execution occurs).
-->

<!--
TODO: make a simple Conway's game of life, or a heat equation simulation example using
sharding, to show how a "typical MPI" simulation can be written using sharding.
-->

## Simple 1-Dimensional Heat Equation

So far we chose a function which was perfectly parallelizable (i.e. each element of the
array only accesses its own data). Let's consider a more realistic example where an updated
element requires data from its neighbors. In the distributed case, this requires
communicating the data along the boundaries.

In particular, let's implement a one-dimensional
[heat equation](https://en.wikipedia.org/wiki/Heat_equation) simulation. In this code you
initialize the temperature of all points of the simulation and, over time, the code
simulates how heat is transferred across space. In particular, points of high temperature
transfer energy to points of low temperature.

As an example, here is a visualization of a 2-dimensional heat equation:

![Heat Equation Animation](https://upload.wikimedia.org/wikipedia/commons/a/a9/Heat_eqn.gif)

<!-- TODO: we should animate the above -- and even more ideally have one we generate ourselves. -->

To keep things simple, let's implement a 1-dimensional heat equation here. We start off
with an array for the temperature at each point, and compute the next version of the
temperatures according to the update rule
`x[i, t] = 0.5 * x[i, t-1] + 0.25 * x[i-1, t-1] + 0.25 * x[i+1, t-1]`.
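
For instance, here is what a single update step does on a tiny array (a plain-Julia
sketch, independent of both MPI and Reactant):

```julia
x = [1.0, 2.0, 3.0, 4.0, 5.0]
x_next = copy(x)

# Update the interior points only; the boundary points are what the halo exchange
# described below is responsible for filling in.
x_next[2:end-1] .= 0.5 .* x[2:end-1] .+ 0.25 .* x[1:end-2] .+ 0.25 .* x[3:end]

# e.g. x_next[2] == 0.5 * 2.0 + 0.25 * 1.0 + 0.25 * 3.0 == 2.0
```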

Let's consider how this can be implemented with explicit MPI communication. Each node will
contain a subset of the total data. For example, if we simulate with 100 points and have 4
devices, each device will contain 25 data points. We're going to allocate some extra room
at each end of the buffer to store the "halo", i.e. the data at the boundary. Each time
step we take will first copy the data from our neighbors into the halo via explicit MPI
send and recv calls. We'll then compute the updated data for our slice of the data.

With sharding, things are a bit simpler. We can write the code as if we only had one
device. No explicit sends or recvs are necessary, as they will be added automatically by
Reactant when it deduces they are needed. In fact, Reactant will attempt to optimize the
placement of the communications to minimize total runtime. While Reactant tries to do a
good job (which could be faster than an initial implementation -- especially for complex
codebases), an expert may be able to find a better placement of the communication.

The only difference for the sharded code again occurs during allocation. Here we explicitly
specify that we want to subdivide the initial grid of 100 points amongst all devices.
Analogously, if we had 4 devices to work with, each device would hold 25 elements in its
local storage. From the user's standpoint, however, all arrays give access to the entire
dataset.

::: code-group

```julia [MPI Based Parallelism]
using MPI

MPI.Init()

function one_dim_heat_equation_time_step_mpi!(data)
    id = MPI.Comm_rank(MPI.COMM_WORLD)
    last_id = MPI.Comm_size(MPI.COMM_WORLD) - 1

    # Send our rightmost owned element to the neighbor on the right
    if id != last_id
        MPI.Send(@view(data[end-1:end-1]), MPI.COMM_WORLD; dest=id + 1)
    end

    # Receive our left halo from the neighbor on the left
    if id != 0
        MPI.Recv!(@view(data[1:1]), MPI.COMM_WORLD; source=id - 1)
    end

    # (A full halo exchange would also communicate in the opposite direction so that
    # data[end] is filled from the right neighbor; omitted here for brevity.)

    # 1-D heat equation x[i, t] = 0.5 * x[i, t-1] + 0.25 * x[i-1, t-1] + 0.25 * x[i+1, t-1]
    data[2:end-1] .= 0.5 * data[2:end-1] + 0.25 * data[1:end-2] + 0.25 * data[3:end]

    return nothing
end

# Total size of grid we want to simulate
N = 100

# Local size of grid (total size divided by number of MPI ranks)
n_local = N ÷ MPI.Comm_size(MPI.COMM_WORLD)

# We add two extra elements for the left and right padding, necessary for storing the
# boundary data from other nodes
data = rand(n_local + 2)

function simulate(data, time_steps)
    for i in 1:time_steps
        one_dim_heat_equation_time_step_mpi!(data)
    end
end

simulate(data, 100)
```

```julia [Sharded Parallelism]
using Reactant

function one_dim_heat_equation_time_step_sharded!(data)
    # No send/recv's required; Reactant will automatically insert the needed
    # communication when compiling the update below.

    # 1-D heat equation x[i, t] = 0.5 * x[i, t-1] + 0.25 * x[i-1, t-1] + 0.25 * x[i+1, t-1]
    data[2:end-1] .= 0.5 * data[2:end-1] + 0.25 * data[1:end-2] + 0.25 * data[3:end]

    return nothing
end

# Total size of grid we want to simulate
N = 100

# Reactant's sharding handles distributing the data amongst devices, with each device
# getting a corresponding fraction of the data
data = Reactant.to_rarray(
    rand(N + 2);
    sharding=Sharding.NamedSharding(
        Sharding.Mesh(Reactant.devices(), (:x,)),
        (:x,)
    )
)

function simulate(data, time_steps)
    @trace for i in 1:time_steps
        one_dim_heat_equation_time_step_sharded!(data)
    end
end

@jit simulate(data, 100)
```

:::
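
Either way, the driver code sees one logical array. As a rough sketch (this assumes that
`Array(data)` gathers a sharded Reactant array back to the host the same way it copies back
a single-device array):

```julia
# After running the sharded simulation, `data` is still indexed as one global array of
# length N + 2, regardless of how it was partitioned across devices.
final_temperatures = Array(data)
length(final_temperatures)  # == N + 2
```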

When running on CPUs across multiple computers, something like MPI is needed to send the
data between them. When using GPUs on different machines, the data instead needs to be
copied through the network via NCCL rather than a plain `cudaMemcpy`.

All devices from all nodes are available for use by Reactant. Given the topology of the
devices, Reactant will automatically determine the right type of communication primitive to
use to send data between the relevant nodes. For example, between GPUs on the same host
Reactant may use the faster `cudaMemcpy`, whereas for GPUs on different nodes Reactant will
use NCCL.

One nice feature of Reactant's handling of multiple devices is that you don't need to
specify how the data is transferred. Because the communication is not spelled out in the
code, a program written with Reactant can run unchanged on a different topology. For
example, when using multiple GPUs on the same host it might be more efficient to copy data
directly between devices with `cudaMemcpy`.

## Devices

You can query the devices available to Reactant using [`Reactant.devices`](@ref).

```@example sharding_tutorial
Reactant.devices()
```

Not all devices are accessible from each process for [multi-node execution](@ref multihost).
To query the devices accessible from the current process, use
[`Reactant.addressable_devices`](@ref).

```@example sharding_tutorial
Reactant.addressable_devices()
```

You can also inspect the type of a device, as well as its properties, as sketched below.
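
For example (a minimal sketch; the concrete device type and the properties it exposes
depend on the backend and runtime in use):

```julia
dev = first(Reactant.devices())

# The concrete type differs between backends (CPU, CUDA GPU, TPU) and runtimes
# (PJRT vs IFRT), so inspect it rather than assuming a particular struct.
@show typeof(dev)

# The printed representation summarizes the device itself.
@show dev
```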

<!-- TODO: Generating Distributed Data by Concatenating Local-Worker Data -->

<!-- TODO: Handling Replicated Tensors -->

<!-- TODO: Sharding in Neural Networks -->

<!-- TODO: 8-way Batch Parallelism -->

<!-- TODO: 4-way Batch & 2-way Model Parallelism -->

## Related links

1. [Shardy Documentation](https://openxla.org/shardy)
2. [JAX Documentation](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
3. [JAX Scaling Book](https://jax-ml.github.io/scaling-book/sharding/)
4. [HuggingFace Ultra Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook)
