
Commit 7eff003

update docs (#9903)
1 parent f2b0db6 commit 7eff003

File tree

4 files changed: 10 additions & 5 deletions

docs/source/advanced/multi_gpu.rst

Lines changed: 8 additions & 2 deletions
@@ -611,28 +611,34 @@ Let's say you have a batch size of 7 in your dataloader.
     def train_dataloader(self):
         return Dataset(..., batch_size=7)
 
-In DDP or Horovod your effective batch size will be 7 * gpus * num_nodes.
+In DDP, DDP_SPAWN, Deepspeed, DDP_SHARDED, or Horovod, your effective batch size will be 7 * gpus * num_nodes.
 
 .. code-block:: python
 
     # effective batch size = 7 * 8
     Trainer(gpus=8, accelerator="ddp")
+    Trainer(gpus=8, accelerator="ddp_spawn")
+    Trainer(gpus=8, accelerator="ddp_sharded")
     Trainer(gpus=8, accelerator="horovod")
 
     # effective batch size = 7 * 8 * 10
     Trainer(gpus=8, num_nodes=10, accelerator="ddp")
+    Trainer(gpus=8, num_nodes=10, accelerator="ddp_spawn")
+    Trainer(gpus=8, num_nodes=10, accelerator="ddp_sharded")
     Trainer(gpus=8, num_nodes=10, accelerator="horovod")
 
-In DDP2, your effective batch size will be 7 * num_nodes.
+In DDP2 or DP, your effective batch size will be 7 * num_nodes.
 The reason is that the full batch is visible to all GPUs on the node when using DDP2.
 
 .. code-block:: python
 
     # effective batch size = 7
     Trainer(gpus=8, accelerator="ddp2")
+    Trainer(gpus=8, accelerator="dp")
 
     # effective batch size = 7 * 10
     Trainer(gpus=8, num_nodes=10, accelerator="ddp2")
+    Trainer(gpus=8, num_nodes=10, accelerator="dp")
 
 
 .. note:: Huge batch sizes are actually really bad for convergence. Check out:
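
To make the arithmetic above concrete, here is a small self-contained sketch (not part of this commit); effective_batch_size is a hypothetical helper that simply restates the rules from the updated docs text:

    # Hypothetical helper, not part of the commit: it restates the
    # effective-batch-size rules from the docs text above.
    def effective_batch_size(batch_size: int, gpus: int, num_nodes: int, accelerator: str) -> int:
        per_process = {"ddp", "ddp_spawn", "ddp_sharded", "deepspeed", "horovod"}
        per_node = {"ddp2", "dp"}
        if accelerator in per_process:
            # every GPU on every node loads its own batch of `batch_size`
            return batch_size * gpus * num_nodes
        if accelerator in per_node:
            # the full batch is split across the GPUs of a single node
            return batch_size * num_nodes
        raise ValueError(f"unknown accelerator: {accelerator}")

    # the numbers from the example: batch_size=7, gpus=8, num_nodes=10
    assert effective_batch_size(7, gpus=8, num_nodes=1, accelerator="ddp") == 7 * 8
    assert effective_batch_size(7, gpus=8, num_nodes=10, accelerator="ddp_sharded") == 7 * 8 * 10
    assert effective_batch_size(7, gpus=8, num_nodes=10, accelerator="ddp2") == 7 * 10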

pytorch_lightning/callbacks/base.py

Lines changed: 1 addition & 1 deletion
@@ -327,5 +327,5 @@ def on_before_optimizer_step(
         pass
 
     def on_before_zero_grad(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer) -> None:
-        """Called after ``optimizer.step()`` and before ``optimizer.zero_grad()``."""
+        """Called before ``optimizer.zero_grad()``."""
        pass
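
For context on where this hook sits in the loop, here is a minimal sketch (not part of this commit) of a user callback overriding it; the class name and the gradient-norm printout are illustrative:

    import pytorch_lightning as pl
    from torch.optim import Optimizer

    class GradNormBeforeZeroGrad(pl.Callback):
        # Illustrative callback: the hook name and signature come from the
        # diff above; the printout is made up for the example.
        def on_before_zero_grad(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer) -> None:
            # Gradients from the last backward pass are still populated here,
            # because this hook fires before ``optimizer.zero_grad()``.
            total = sum(p.grad.norm().item() for p in pl_module.parameters() if p.grad is not None)
            print(f"grad norm before zero_grad: {total:.4f}")

    # trainer = pl.Trainer(callbacks=[GradNormBeforeZeroGrad()])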

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 0 additions & 1 deletion
@@ -130,7 +130,6 @@ class ModelCheckpoint(Callback):
 
                 Use ``every_n_epochs`` instead.
 
-
     Note:
         For extra customization, ModelCheckpoint includes the following attributes:
 
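
As a quick usage illustration (not part of this commit), the ``every_n_epochs`` argument that the cleaned-up note refers to can be configured like this; the checkpoint directory and monitored metric are placeholders:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # Write a checkpoint every 5 epochs, keeping the 3 best by validation loss.
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints/",   # placeholder path
        monitor="val_loss",       # assumes the LightningModule logs ``val_loss``
        save_top_k=3,
        every_n_epochs=5,
    )
    trainer = Trainer(callbacks=[checkpoint_callback])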

pytorch_lightning/loggers/wandb.py

Lines changed: 1 addition & 1 deletion
@@ -282,7 +282,7 @@ def __init__(
             rank_zero_warn(
                 f"Providing log_model={log_model} requires wandb version >= 0.10.22"
                 " for logging associated model metadata.\n"
-                "Hint: Upgrade with `pip install --ugrade wandb`."
+                "Hint: Upgrade with `pip install --upgrade wandb`."
             )
 
         super().__init__()
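
For context, a short sketch (not part of this commit) of the logger setting that triggers this warning on older wandb versions; the project name is a placeholder:

    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import WandbLogger

    # log_model=True uploads checkpoints as W&B artifacts; with wandb < 0.10.22
    # the constructor emits the warning whose hint text this commit fixes.
    wandb_logger = WandbLogger(project="my-project", log_model=True)  # placeholder project name
    trainer = Trainer(logger=wandb_logger)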
