You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
- Replace deprecated launch utility with torchrun (see PyTorch docs: https://pytorch.org/docs/stable/distributed.html#launch-utility)
- Update README to reflect torchrun usage
- Remove main.py (no longer referenced in documentation)
- Update CI to test example.py script instead
Signed-off-by: jafraustro <[email protected]>
# Distributed Data Parallel (DDP) Applications with PyTorch
7
3
8
-
# Prerequisites
4
+
This guide demonstrates how to structure a distributed model training application for convenient multi-node launches using `torchrun`.
9
5
10
-
We assume you are familiar with [PyTorch](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html), the primitives it provides for [writing distributed applications](https://pytorch.org/tutorials/intermediate/dist_tuto.html) as well as training [distributed models](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).
6
+
---
11
7
12
-
The example program in this tutorial uses the
13
-
[`torch.nn.parallel.DistributedDataParallel`](https://pytorch.org/docs/stable/nn.html#distributeddataparallel) class for training models
14
-
in a _data parallel_ fashion: multiple workers train the same global
15
-
model by processing different portions of a large dataset, computing
16
-
local gradients (aka _sub_-gradients) independently and then
17
-
collectively synchronizing gradients using the AllReduce primitive. In
18
-
HPC terminology, this model of execution is called _Single Program
19
-
Multiple Data_ or SPMD since the same application runs on all
20
-
application but each one operates on different portions of the
-[Distributed model training](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html)
15
+
16
+
This tutorial uses the [`torch.nn.parallel.DistributedDataParallel`](https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) (DDP) class for data parallel training: multiple workers train the same global model on different data shards, compute local gradients, and synchronize them using AllReduce. In High-Performance Computing (HCP), this is called _Single Program Multiple Data_ (SPMD).
17
+
18
+
---
19
+
20
+
## Application Process Topologies
24
21
25
22
A Distributed Data Parallel (DDP) application can be executed on
26
23
multiple nodes where each node can consist of multiple GPU
27
24
devices. Each node in turn can run multiple copies of the DDP
28
25
application, each of which processes its models on multiple GPUs.
29
26
30
-
Let_N_ be the number of nodes on which the application is running and
31
-
_G_ be the number of GPUs per node. The total number of application
32
-
processes running across all the nodes at one time is called the
33
-
**World Size**, _W_ and the number of processes running on each node
34
-
is referred to as the **Local World Size**, _L_.
27
+
Let:
28
+
-_N_ = number of nodes
29
+
-_G_ = number of GPUs per node
30
+
-_W_ = **World Size** = total number of processes
31
+
-_L_ = **Local World Size** = processes per node
35
32
36
-
Each application process is assigned two IDs: a _local_ rank in \[0,
37
-
_L_-1\] and a _global_ rank in \[0, _W_-1\].
33
+
Each process has:
34
+
-**Local rank**: in `[0, L-1]`
35
+
-**Global rank**: in `[0, W-1]`
38
36
39
-
To illustrate the terminology defined above, consider the case where a
40
-
DDP application is launched on two nodes, each of which has four
41
-
GPUs. We would then like each process to span two GPUs each. The
42
-
mapping of processes to nodes is shown in the figure below:
37
+
**Example:**
38
+
If you launch a DDP app on 2 nodes, each with 4 GPUs, and want each process to span 2 GPUs, the mapping is as follows:
While there are quite a few ways to map processes to nodes, a good
47
-
rule of thumb is to have one process span a single GPU. This enables
48
-
the DDP application to have as many parallel reader streams as there
49
-
are GPUs and in practice provides a good balance between I/O and
50
-
computational costs. In the rest of this tutorial, we assume that the
51
-
application follows this heuristic.
42
+
While there are quite a few ways to map processes to nodes, a good rule of thumb is to have one process span a single GPU. This enables the DDP application to have as many parallel reader streams as there are GPUs and in practice provides a good balance between I/O and computational costs. In the rest of this tutorial, we assume that the application follows this heuristic.
52
43
53
44
# Preparing and launching a DDP application
54
45
55
-
Independent of how a DDP application is launched, each process needs a
56
-
mechanism to know its global and local ranks. Once this is known, all
57
-
processes create a `ProcessGroup` that enables them to participate in
58
-
collective communication operations such as AllReduce.
59
-
60
-
A convenient way to start multiple DDP processes and initialize all
61
-
values needed to create a `ProcessGroup` is to use the distributed
62
-
`launch.py` script provided with PyTorch. The launcher can be found
63
-
under the `distributed` subdirectory under the local `torch`
64
-
installation directory. Here is a quick way to get the path of
When the DDP application is started via `launch.py`, it passes the world size, global rank, master address and master port via environment variables and the local rank as a command-line parameter to each instance.
78
-
To use the launcher, an application needs to adhere to the following convention:
79
-
80
-
1. It must provide an entry-point function for a _single worker_. For example, it should not launch subprocesses using `torch.multiprocessing.spawn`
81
-
2. It must use environment variables for initializing the process group.
82
-
83
-
For simplicity, the application can assume each process maps to a single GPU but in the next section we also show how a more general process-to-GPU mapping can be performed.
46
+
Independent of how a DDP application is launched, each process needs a mechanism to know its global and local ranks. Once this is known, all processes create a `ProcessGroup` that enables them to participate in collective communication operations such as AllReduce.
84
47
85
-
# Sample application
48
+
A convenient way to start multiple DDP processes and initialize all values needed to create a `ProcessGroup` is to use the [`torchrun`](https://docs.pytorch.org/docs/stable/elastic/run.html) script provided with PyTorch.
86
49
87
-
The sample DDP application in this repo is based on the "Hello, World" [DDP tutorial](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).
50
+
---
88
51
89
-
## Argument passing convention
52
+
## Sample Application
90
53
91
-
The DDP application takes two command-line arguments:
54
+
This example is based on the ["Hello, World" DDP tutorial](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).
92
55
93
-
1.`--local_rank`: This is passed in via `launch.py`
94
-
2.`--local_world_size`: This is passed in explicitly and is typically either $1$ or the number of GPUs per node.
56
+
The application calls the `spmd_main` entrypoint:
95
57
96
-
The application parses these and calls the `spmd_main` entrypoint:
In `spmd_main`, the process group is initialized with just the backend (NCCL or Gloo). The rest of the information needed for rendezvous comes from environment variables set by `launch.py`:
63
+
In `spmd_main`, the process group is initialized using the Accelerator API. The rest of the rendezvous information comes from environment variables set by `torchrun`:
108
64
109
-
```py
110
-
defspmd_main(local_world_size, local_rank):
65
+
```python
66
+
defspmd_main():
111
67
# These are the parameters used to initialize the process group
112
68
env_dict = {
113
69
key: os.environ[key]
114
70
for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")
Given the local rank and world size, the training function, `demo_basic` initializes the `DistributedDataParallel` model across a set of GPUs local to the node via `device_ids`:
89
+
**Key points:**
90
+
- Each process reads its rank and world size from environment variables.
91
+
- The process group is initialized for distributed communication.
130
92
131
-
```py
132
-
defdemo_basic(local_world_size, local_rank):
133
-
134
-
# setup devices for this process. For local_world_size = 2, num_gpus = 8,
135
-
# rank 0 uses GPUs [0, 1, 2, 3] and
136
-
# rank 1 uses GPUs [4, 5, 6, 7].
137
-
n = torch.cuda.device_count() // local_world_size
138
-
device_ids =list(range(local_rank * n, (local_rank +1) * n))
93
+
The training function, `demo_basic`, initializes the DDP model on the appropriate GPU:
139
94
95
+
```python
96
+
defdemo_basic(rank):
140
97
print(
141
-
f"[{os.getpid()}] rank = {dist.get_rank()}, "
142
-
+f"world_size = {dist.get_world_size()}, n = {n}, device_ids = {device_ids}"
As the author of a distributed data parallel application, your code needs to be aware of two types of resources: compute nodes and the GPUs within each node. The process of setting up bookkeeping to track how the set of GPUs is mapped to the processes of your application can be tedious and error-prone. We hope that by structuring your application as shown in this example and using the launcher, the mechanics of setting up distributed training can be significantly simplified.
161
+
As the author of a distributed data parallel application, your code needs to be aware of two types of resources: compute nodes and the GPUs within each node. The process of setting up bookkeeping to track how the set of GPUs is mapped to the processes of your application can be tedious and error-prone. We hope that by structuring your application as shown in this example and using `torchrun`, the mechanics of setting up distributed training can be significantly simplified.
0 commit comments