GPU training (FAQ)
==================

***************************************************************
How should I adjust the batch size when using multiple devices?
***************************************************************

Lightning automatically shards your data across multiple GPUs, meaning that each device only sees a unique subset of your
data, but the ``batch_size`` in your DataLoader remains the same. This means that the effective batch size, i.e. the
total number of samples processed in one forward/backward pass, is

.. math::

    \text{Effective Batch Size} = \text{DataLoader Batch Size} \times \text{Number of Devices} \times \text{Number of Nodes}

A couple of examples to illustrate this:

.. code-block:: python

    dataloader = DataLoader(..., batch_size=7)

    # Single GPU: effective batch size = 7
    Trainer(accelerator="gpu", devices=1)

    # Multi-GPU: effective batch size = 7 * 8 = 56
    Trainer(accelerator="gpu", devices=8, strategy=...)

    # Multi-node: effective batch size = 7 * 8 * 10 = 560
    Trainer(accelerator="gpu", devices=8, num_nodes=10, strategy=...)

In general, you can keep the same ``batch_size`` in your DataLoader regardless of the number of devices you are using.

.. note::

    If you want distributed training to behave exactly like single-GPU training, set the ``batch_size`` in your
    DataLoader to ``original_batch_size / num_devices`` to maintain the same effective batch size. However, this can
    lead to poor GPU utilization.
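
As a quick sanity check, here is a minimal sketch (with an assumed batch size of 256 and 8 devices, values chosen purely
for illustration) showing how dividing the DataLoader batch size by the device count preserves the single-GPU effective
batch size:

.. code-block:: python

    # Assumed values for illustration: preserve the effective batch size
    # of a single-GPU run when scaling out to multiple devices.
    original_batch_size = 256  # batch size of the single-GPU run (assumed)
    num_devices = 8            # assumed device count

    # Per-device DataLoader batch size that keeps the effective batch size unchanged
    per_device_batch_size = original_batch_size // num_devices  # 32

    effective_batch_size = per_device_batch_size * num_devices
    assert effective_batch_size == original_batch_size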

----

******************************************************************
How should I adjust the learning rate when using multiple devices?
******************************************************************

Because the effective batch size grows with the number of devices, you typically need to adjust your learning rate
accordingly. The learning rate controls how much the model weights change in response to the estimated error at each
update, so it should be scaled together with the effective batch size.

In general, there are two common scaling rules:

1. **Linear scaling**: Increase the learning rate linearly with the number of devices.

   .. code-block:: python

      # Example: linear scaling
      base_lr = 1e-3
      num_devices = 8
      scaled_lr = base_lr * num_devices  # 8e-3

2. **Square root scaling**: Increase the learning rate by the square root of the number of devices.

   .. code-block:: python

      # Example: square root scaling
      base_lr = 1e-3
      num_devices = 8
      scaled_lr = base_lr * (num_devices**0.5)  # ~2.83e-3
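
The two rules above can be folded into a small helper. ``scale_lr`` is a hypothetical function name used here for
illustration, not a Lightning API:

.. code-block:: python

    def scale_lr(base_lr: float, num_devices: int, rule: str = "linear") -> float:
        """Scale a single-device learning rate for multi-device training.

        ``rule`` selects one of the two common heuristics: "linear" or "sqrt".
        """
        if rule == "linear":
            return base_lr * num_devices
        if rule == "sqrt":
            return base_lr * num_devices**0.5
        raise ValueError(f"unknown scaling rule: {rule!r}")


    scale_lr(1e-3, 8)          # linear: 0.008
    scale_lr(1e-3, 8, "sqrt")  # square root: ~0.00283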

.. note:: Huge batch sizes are actually really bad for convergence. Check out:
   `Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour <https://arxiv.org/abs/1706.02677>`_