# [Automatic Sharding-based Distributed Parallelism](@id sharding)

!!! tip "Use XLA IFRT Runtime"

    While PJRT does support some minimal sharding capabilities on CUDA GPUs, sharding
    support in Reactant is primarily provided via IFRT. Before loading Reactant, set the
    "xla_runtime" preference to "IFRT". This can be done with:

    ```julia
    using Preferences, UUIDs

    Preferences.set_preferences!(
        UUID("3c362404-f566-11ee-1572-e11a4b42c853"),
        "xla_runtime" => "IFRT"
    )
    ```
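
    For reference, `Preferences.set_preferences!` records this choice in your project's
    `LocalPreferences.toml`, and Reactant reads it the next time it is loaded. As an
    optional sanity check (a sketch using the same UUID as above), you can read the
    preference back before loading Reactant:

    ```julia
    using Preferences, UUIDs

    # Should return "IFRT" if the preference was stored successfully.
    Preferences.load_preference(
        UUID("3c362404-f566-11ee-1572-e11a4b42c853"),
        "xla_runtime"
    )
    ```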

## Basics

Sharding is one mechanism supported within Reactant that tries to make it easy to program
for multiple devices (including [multiple nodes](@ref distributed)).

```@example sharding_tutorial
using Reactant

@assert length(Reactant.devices()) > 1 # hide
Reactant.devices()
```

Sharding provides Reactant users a
[PGAS (partitioned global address space)](https://en.wikipedia.org/wiki/Partitioned_global_address_space)
programming model. Let's understand what this means through an example.

Suppose we have a function that takes a large input array and computes `sin` for every
element of the array.

```@example sharding_tutorial
function big_sin(data)
    data .= sin.(data)
    return nothing
end

N = 1600
x = Reactant.to_rarray(reshape(collect(Float32, 1:N), 40, 40))

compiled_big_sin = @compile big_sin(x)

compiled_big_sin(x)
```

This allocates the array `x` on a single device and executes the computation on that same
device. However, suppose we want to execute this computation on multiple devices. Perhaps
this is because the size of our inputs (`N`) is too large to fit on a single device, or
because the function we execute is computationally expensive and we want to leverage the
computing power of multiple devices.

Unlike more explicit communication libraries like MPI, the sharding model used by Reactant
aims to let you execute a program on multiple devices without significant modifications to
the single-device program. In particular, you do not need to write explicit communication
calls (e.g. `MPI.Send` or `MPI.Recv`). Instead, you write your program as if it were
executing on one very large node and Reactant automatically determines how to subdivide the
data, the computation, and the required communication.

When using sharding, the one thing you need to change about your code is how arrays are
allocated. In particular, you need to specify how the array is partitioned amongst the
available devices. For example, suppose you are on a machine with 4 GPUs. In the example
above, we computed `sin` for all elements of a 40x40 grid. One partitioning we could select
is to partition the array along the first axis, such that each GPU has a slice of 10x40
elements. We could accomplish this as follows. No change is required to the original
function. However, the compiled function is specific to the sharding, so we need to compile
a new version for our sharded array.

```@example sharding_tutorial
N = 1600

x_sharded_first = Reactant.to_rarray(
    reshape(collect(Float32, 1:N), 40, 40),
    sharding=Sharding.NamedSharding(
        Sharding.Mesh(reshape(Reactant.devices()[1:4], 4, 1), (:x, :y)),
        (:x, nothing)
    )
)

compiled_big_sin_sharded_first = @compile big_sin(x_sharded_first)

compiled_big_sin_sharded_first(x_sharded_first)
```

Alternatively, we can partition the data in a different form. In particular, we could
subdivide the data on both axes, so that each GPU has a slice of 20x20 elements. Again, no
change is required to the original function, but we would change the allocation as follows:

```@example sharding_tutorial
N = 1600
x_sharded_both = Reactant.to_rarray(
    reshape(collect(Float32, 1:N), 40, 40),
    sharding=Sharding.NamedSharding(
        Sharding.Mesh(reshape(Reactant.devices()[1:4], 2, 2), (:x, :y)),
        (:x, :y)
    )
)

compiled_big_sin_sharded_both = @compile big_sin(x_sharded_both)

compiled_big_sin_sharded_both(x_sharded_both)
```

Sharding in Reactant requires you to specify how the data is sharded across the devices of a
mesh. We start by specifying the mesh, [`Sharding.Mesh`](@ref), which is a collection of
devices reshaped into an N-D grid. Additionally, we can specify names for each axis of the
mesh, which are then referenced when specifying how the data is sharded:

1. `Sharding.Mesh(reshape(Reactant.devices()[1:4], 2, 2), (:x, :y))`: creates a 2D grid of 4
   devices arranged in a 2x2 grid. The first axis is named `:x` and the second axis is
   named `:y`.
2. `Sharding.Mesh(reshape(Reactant.devices()[1:4], 4, 1), (:x, :y))`: creates a 2D grid of 4
   devices arranged in a 4x1 grid. The first axis is named `:x` and the second axis is
   named `:y`.

Given the mesh, we then specify how the data is sharded across its devices.
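
For example, the two shardings used above can be written out explicitly. This is a small
sketch assuming at least 4 devices are available; `nothing` in the partition specification
leaves that dimension unsharded, so every shard holds the full extent of that dimension:

```julia
devs = Reactant.devices()[1:4]

# 4x1 mesh: shard the first array dimension along :x, leave the second unsharded.
mesh_4x1 = Sharding.Mesh(reshape(devs, 4, 1), (:x, :y))
row_sharding = Sharding.NamedSharding(mesh_4x1, (:x, nothing))  # 10x40 per device

# 2x2 mesh: shard the first dimension along :x and the second along :y.
mesh_2x2 = Sharding.Mesh(reshape(devs, 2, 2), (:x, :y))
tile_sharding = Sharding.NamedSharding(mesh_2x2, (:x, :y))      # 20x20 per device

x_tiles = Reactant.to_rarray(zeros(Float32, 40, 40); sharding=tile_sharding)
```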

<!--
TODO describe how arrays are the "global" data arrays, even though data is itself only stored
on the relevant device and computation is performed only on devices with the required data
(effectively showing under the hood how execution occurs)
-->

<!--
TODO make a simple Conway's game of life, or heat equation using sharding simulation example
to show how a "typical MPI" simulation can be written using sharding.
-->

## Simple 1-Dimensional Heat Equation

So far we chose a function which was perfectly parallelizable (i.e. each element of the
array only accesses its own data). Let's consider a more realistic example where an updated
element requires data from its neighbors. In the distributed case, this requires
communicating the data along the boundaries.

In particular, let's implement a one-dimensional
[heat equation](https://en.wikipedia.org/wiki/Heat_equation) simulation. In this code we
initialize the temperature at every point of the simulation, and over time the code
simulates how heat is transferred across space; in particular, points of high temperature
transfer energy to points of lower temperature.

As an example, here is a visualization of a 2-dimensional heat equation:


<!-- TODO we should animate the above -- and even more ideally have one we generate ourselves. -->

To keep things simple, let's implement a 1-dimensional heat equation here. We start off with
an array for the temperature at each point, and will compute the next version of the
temperatures according to the update rule
`x[i, t] = 0.5 * x[i, t-1] + 0.25 * x[i-1, t-1] + 0.25 * x[i+1, t-1]`.
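
On a single device this update is just a stencil over a padded array. The following sketch
(with a hypothetical helper name, and one padding cell assumed at each end of the array)
shows the update rule in plain Julia:

```julia
# Single-device reference: `data` has one padding ("halo") cell at each end,
# so the stencil below never reads out of bounds.
function one_dim_heat_equation_time_step!(data)
    data[2:end-1] .= 0.5 .* data[2:end-1] .+ 0.25 .* data[1:end-2] .+ 0.25 .* data[3:end]
    return nothing
end

data = rand(102)  # 100 interior points + 2 padding cells
one_dim_heat_equation_time_step!(data)
```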

Let's consider how this can be implemented with explicit MPI communication. Each node will
contain a subset of the total data. For example, if we simulate with 100 points and have 4
devices, each device will contain 25 data points. We're going to allocate some extra room at
each end of the buffer to store the "halo", i.e. the data at the boundary. Each time step
will first copy the data from our neighbors into the halo via explicit MPI send and recv
calls, and then compute the updated data for our slice of the data.

With sharding, things are a bit simpler. We can write the code as if we only had one device.
No explicit sends or recvs are necessary; they will be added automatically by Reactant when
it deduces they are needed. In fact, Reactant will attempt to optimize the placement of the
communications to minimize total runtime. While Reactant tries to do a good job (which could
be faster than an initial implementation -- especially for complex codebases), an expert may
be able to find a better placement of the communication.

The only difference for the sharded code again occurs during allocation. Here we explicitly
specify that we want to subdivide the initial grid of 100 points amongst all devices.
Analogously, if we had 4 devices to work with, each device would have 25 elements in its
local storage. From the user's standpoint, however, all arrays give access to the entire
dataset.

::: code-group

```julia [MPI Based Parallelism]
using MPI

MPI.Init()

function one_dim_heat_equation_time_step_mpi!(data)
    id = MPI.Comm_rank(MPI.COMM_WORLD)   # MPI ranks are 0-based
    n_ranks = MPI.Comm_size(MPI.COMM_WORLD)

    # Send our rightmost interior element to the right neighbor
    if id != n_ranks - 1
        MPI.Send(@view(data[end-1:end-1]), MPI.COMM_WORLD; dest=id + 1)
    end

    # Receive it into our left halo cell from the left neighbor
    if id != 0
        MPI.Recv!(@view(data[1:1]), MPI.COMM_WORLD; source=id - 1)
    end

    # (The symmetric exchange that fills the right halo cell is analogous and omitted here.)

    # 1-D heat equation: x[i, t] = 0.5 * x[i, t-1] + 0.25 * x[i-1, t-1] + 0.25 * x[i+1, t-1]
    data[2:end-1] .= 0.5 .* data[2:end-1] .+ 0.25 .* data[1:end-2] .+ 0.25 .* data[3:end]

    return nothing
end

# Total size of the grid we want to simulate
N = 100

# Local size of the grid (total size divided by the number of MPI ranks)
n_local = N ÷ MPI.Comm_size(MPI.COMM_WORLD)

# We add two cells for the left and right halo padding, which stores the boundary
# values received from neighboring ranks
data = rand(n_local + 2)

function simulate(data, time_steps)
    for i in 1:time_steps
        one_dim_heat_equation_time_step_mpi!(data)
    end
end

simulate(data, 100)
```

```julia [Sharded Parallelism]
function one_dim_heat_equation_time_step_sharded!(data)
    # No send/recv calls are required here; Reactant will automatically insert the
    # necessary communication between devices.

    # 1-D heat equation: x[i, t] = 0.5 * x[i, t-1] + 0.25 * x[i-1, t-1] + 0.25 * x[i+1, t-1]
    data[2:end-1] .= 0.5 .* data[2:end-1] .+ 0.25 .* data[1:end-2] .+ 0.25 .* data[3:end]

    return nothing
end

# Total size of the grid we want to simulate
N = 100

# Reactant's sharding handles distributing the data amongst devices, with each device
# getting a corresponding fraction of the data
data = Reactant.to_rarray(
    rand(N + 2);
    sharding=Sharding.NamedSharding(
        Sharding.Mesh(Reactant.devices(), (:x,)),
        (:x,)
    )
)

function simulate(data, time_steps)
    @trace for i in 1:time_steps
        one_dim_heat_equation_time_step_sharded!(data)
    end
end

@jit simulate(data, 100)
```

:::

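Consistent with the PGAS model described above, the sharded `data` array still presents
itself as a single global array. As a rough sketch (assuming the sharded example above has
been run, and that converting to a plain `Array` gathers the shards back to the host):

```julia
size(data)          # (102,) -- the global shape, not the per-device shape
host = Array(data)  # assumed to gather the distributed shards into a regular Julia Array
host[1:5]
```
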
All devices from all nodes are available for use by Reactant. Given the topology of the
devices, Reactant will automatically determine the right type of communication primitive to
use to send data between the relevant nodes. For example, between GPUs on the same host
Reactant may use the faster `cudaMemcpy`, whereas for GPUs on different nodes it will copy
the data over the network via NCCL.

One nice feature of Reactant's handling of multiple devices is that you don't need to
specify how the data is transferred. Because the communication is not spelled out in your
program, code written with Reactant can run on a different device topology without changes:
for example, when using multiple GPUs on the same host it may be more efficient to transfer
data between devices directly with a `cudaMemcpy`, and Reactant will make that choice for
you.

## Devices

You can query the devices that Reactant can access using [`Reactant.devices`](@ref):

```@example sharding_tutorial
Reactant.devices()
```

Not all devices are accessible from each process for [multi-node execution](@ref multihost).
To query the devices accessible from the current process, use
[`Reactant.addressable_devices`](@ref):

```@example sharding_tutorial
Reactant.addressable_devices()
```

You can inspect the type of each device, as well as its properties.
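
For example, a minimal sketch using generic Julia introspection (no Reactant-specific
inspection API is assumed here):

```julia
dev = first(Reactant.devices())

typeof(dev)          # the concrete device type (e.g. a CPU, GPU, or TPU device)
propertynames(dev)   # the fields/properties carried by that device object
string(dev)          # a human-readable description of the device
```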

<!-- TODO: Generating Distributed Data by Concatenating Local-Worker Data -->

<!-- TODO: Handling Replicated Tensors -->

<!-- TODO: Sharding in Neural Networks -->

<!-- TODO: 8-way Batch Parallelism -->

<!-- TODO: 4-way Batch & 2-way Model Parallelism -->

## Related links

1. [Shardy Documentation](https://openxla.org/shardy)
2. [JAX Documentation](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
3. [JAX Scaling Book](https://jax-ml.github.io/scaling-book/sharding/)
4. [HuggingFace Ultra Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook)