
Commit d3996ad

SkafteNicki and Borda authored
Impr docs on batch sizes and limits in distributed (#21070)
* improve gpu_faq
* add per device text
* header
* Apply suggestions from code review

---------

Co-authored-by: Jirka Borovec <[email protected]>
1 parent 5071a04 commit d3996ad

File tree

6 files changed: +66 -24 lines changed

docs/source-pytorch/accelerators/gpu_faq.rst
docs/source-pytorch/common/trainer.rst
docs/source-pytorch/expertise_levels.rst
docs/source-pytorch/levels/intermediate.rst
docs/source-pytorch/levels/intermediate_level_7.rst
src/lightning/pytorch/trainer/trainer.py

docs/source-pytorch/accelerators/gpu_faq.rst

Lines changed: 54 additions & 14 deletions

@@ -5,31 +5,71 @@
 GPU training (FAQ)
 ==================

-******************************************************************
-How should I adjust the learning rate when using multiple devices?
-******************************************************************
+***************************************************************
+How should I adjust the batch size when using multiple devices?
+***************************************************************

-When using distributed training make sure to modify your learning rate according to your effective
-batch size.
+Lightning automatically shards your data across multiple GPUs, meaning that each device only sees a unique subset of your
+data, but the `batch_size` in your DataLoader remains the same. This means that the effective batch size, i.e. the
+total number of samples processed in one forward/backward pass, is

-Let's say you have a batch size of 7 in your dataloader.
+.. math::

-.. testcode::
+    \text{Effective Batch Size} = \text{DataLoader Batch Size} \times \text{Number of Devices} \times \text{Number of Nodes}

-    class LitModel(LightningModule):
-        def train_dataloader(self):
-            return Dataset(..., batch_size=7)
-
-Whenever you use multiple devices and/or nodes, your effective batch size will be 7 * devices * num_nodes.
+A couple of examples to illustrate this:

 .. code-block:: python

-    # effective batch size = 7 * 8
+    dataloader = DataLoader(..., batch_size=7)
+
+    # Single GPU: effective batch size = 7
+    Trainer(accelerator="gpu", devices=1)
+
+    # Multi-GPU: effective batch size = 7 * 8 = 56
     Trainer(accelerator="gpu", devices=8, strategy=...)

-    # effective batch size = 7 * 8 * 10
+    # Multi-node: effective batch size = 7 * 8 * 10 = 560
     Trainer(accelerator="gpu", devices=8, num_nodes=10, strategy=...)

+In general you should be able to use the same `batch_size` in your DataLoader regardless of the number of devices you
+are using.
+
+.. note::
+
+    If you want distributed training to work exactly the same as single-GPU training, you need to set the `batch_size`
+    in your DataLoader to `original_batch_size / num_devices` to maintain the same effective batch size. However, this
+    can lead to poor GPU utilization.
+
+----
+
+******************************************************************
+How should I adjust the learning rate when using multiple devices?
+******************************************************************
+
+Because the effective batch size is larger when using multiple devices, you need to adjust your learning rate
+accordingly. Since the learning rate controls how much the model weights change in response to the estimated error at
+each update, it should be scaled with the effective batch size.
+
+In general, there are two common scaling rules:
+
+1. **Linear scaling**: Increase the learning rate linearly with the number of devices.
+
+   .. code-block:: python
+
+       # Example: Linear scaling
+       base_lr = 1e-3
+       num_devices = 8
+       scaled_lr = base_lr * num_devices  # 8e-3
+
+2. **Square root scaling**: Increase the learning rate by the square root of the number of devices.
+
+   .. code-block:: python
+
+       # Example: Square root scaling
+       base_lr = 1e-3
+       num_devices = 8
+       scaled_lr = base_lr * (num_devices ** 0.5)  # 2.83e-3

 .. note:: Huge batch sizes are actually really bad for convergence. Check out:
     `Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour <https://arxiv.org/abs/1706.02677>`_
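To make the arithmetic in this FAQ entry concrete, the following standalone sketch emulates the per-device sharding described above. The explicit `num_replicas`/`rank` arguments and the toy `TensorDataset` are illustrative assumptions; in a real run Lightning injects the `DistributedSampler` for you.

.. code-block:: python

    # Illustrative sketch: each rank batches its own shard with the same per-device
    # batch_size, so the effective batch size is batch_size * world_size.
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.arange(1120).float())  # assumed toy dataset
    batch_size, devices, num_nodes = 7, 8, 10
    world_size = devices * num_nodes  # 80 ranks, matching the multi-node example

    # Emulate rank 0's view; Lightning attaches the sampler automatically in practice.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=0, shuffle=False)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    print(len(sampler))                 # 14 -> 1120 samples / 80 ranks
    print(next(iter(loader))[0].shape)  # torch.Size([7]) -> still 7 samples per step
    print(batch_size * world_size)      # 560 -> effective batch size

    # Per the note above: to keep the effective batch size of a single-GPU run with
    # batch_size=56 when moving to 8 devices, each device would use 56 // 8 = 7.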

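The two scaling rules can be applied wherever the optimizer is created, for example in ``configure_optimizers``. The sketch below is illustrative only: the ``DemoModel`` module, the SGD optimizer and the choice of the linear rule are assumptions, and it relies on ``self.trainer.world_size`` (devices × num_nodes) being available once the trainer is attached.

.. code-block:: python

    # Sketch: scale the base learning rate by the world size when building the optimizer.
    import torch
    from torch import nn
    import lightning.pytorch as pl


    class DemoModel(pl.LightningModule):
        def __init__(self, base_lr: float = 1e-3):
            super().__init__()
            self.save_hyperparameters()
            self.layer = nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return nn.functional.cross_entropy(self.layer(x), y)

        def configure_optimizers(self):
            # world_size = devices * num_nodes; the trainer is attached by the time
            # configure_optimizers runs, so this reflects the actual launch config.
            scaled_lr = self.hparams.base_lr * self.trainer.world_size           # linear rule
            # scaled_lr = self.hparams.base_lr * self.trainer.world_size ** 0.5  # sqrt rule
            return torch.optim.SGD(self.parameters(), lr=scaled_lr)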
docs/source-pytorch/common/trainer.rst

Lines changed: 3 additions & 1 deletion

@@ -510,6 +510,7 @@ limit_train_batches

 How much of training dataset to check.
 Useful when debugging or testing something that happens at the end of an epoch.
+Value is per device.

 .. testcode::

@@ -535,7 +536,7 @@ limit_test_batches
     :width: 400
     :muted:

-How much of test dataset to check.
+How much of test dataset to check. Value is per device.

 .. testcode::

@@ -560,6 +561,7 @@ limit_val_batches

 How much of validation dataset to check.
 Useful when debugging or testing something that happens at the end of an epoch.
+Value is per device.

 .. testcode::
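Since these limits are applied per device, the total amount of data touched in a distributed run grows with the world size. A short illustrative sketch of the arithmetic (the concrete numbers are assumptions, not taken from the docs):

.. code-block:: python

    # limit_train_batches=100 (int) limits *each* rank to 100 training batches per epoch.
    devices, num_nodes, limit_train_batches = 8, 2, 100
    total_batches_per_epoch = limit_train_batches * devices * num_nodes
    print(total_batches_per_epoch)  # 1600 batches across the run, each of per-device size

    # A float is a fraction of each rank's own (already sharded) dataloader instead,
    # e.g. limit_val_batches=0.25 runs 25% of every rank's validation batches.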

docs/source-pytorch/expertise_levels.rst

Lines changed: 1 addition & 1 deletion

@@ -84,7 +84,7 @@ Learn to scale up your models and enable collaborative model development at acad
 .. Add callout items below this line

 .. displayitem::
-   :header: Level 7: Interactive cloud development
+   :header: Level 7: Hardware acceleration
    :description: Learn how to access GPUs and TPUs on the cloud.
    :button_link: levels/intermediate_level_7.html
    :col_css: col-md-6
docs/source-pytorch/levels/intermediate.rst

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ Learn to scale up your models and enable collaborative model development at acad
 .. Add callout items below this line

 .. displayitem::
-   :header: Level 7: Interactive cloud development
+   :header: Level 7: Hardware acceleration
    :description: Learn how to access GPUs and TPUs on the cloud.
    :button_link: intermediate_level_7.html
    :col_css: col-md-6

docs/source-pytorch/levels/intermediate_level_7.rst

Lines changed: 3 additions & 3 deletions

@@ -1,8 +1,8 @@
 :orphan:

-######################################
-Level 7: Interactive cloud development
-######################################
+##############################
+Level 7: Hardware acceleration
+##############################

 Learn to develop models on cloud GPUs and TPUs.

src/lightning/pytorch/trainer/trainer.py

Lines changed: 4 additions & 4 deletions

@@ -185,16 +185,16 @@ def __init__(
                :class:`datetime.timedelta`.

            limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches).
-                Default: ``1.0``.
+                Value is per device. Default: ``1.0``.

            limit_val_batches: How much of validation dataset to check (float = fraction, int = num_batches).
-                Default: ``1.0``.
+                Value is per device. Default: ``1.0``.

            limit_test_batches: How much of test dataset to check (float = fraction, int = num_batches).
-                Default: ``1.0``.
+                Value is per device. Default: ``1.0``.

            limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches).
-                Default: ``1.0``.
+                Value is per device. Default: ``1.0``.

            overfit_batches: Overfit a fraction of training/validation data (float) or a set number of batches (int).
                Default: ``0.0``.
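Combining the ``float = fraction, int = num_batches`` distinction in these docstrings with the per-device note added here, a fraction is taken of each rank's own dataloader. A small worked example with assumed sizes:

.. code-block:: python

    # Assumed: every rank's sharded validation dataloader yields 250 batches per epoch.
    batches_per_rank, world_size = 250, 4

    # limit_val_batches=0.1 (float): 10% of each rank's batches
    per_rank_float = int(batches_per_rank * 0.1)   # 25 on every rank
    total_float = per_rank_float * world_size      # 100 across the run

    # limit_val_batches=10 (int): an absolute 10 batches on every rank, 40 in total
    per_rank_int, total_int = 10, 10 * world_size
    print(per_rank_float, total_float, per_rank_int, total_int)  # 25 100 10 40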
