Skip to content

Commit 6ca2bac

Browse files
authored
Merge branch 'master' into deepspeed_mics_init
2 parents 5409bc9 + 030f36b commit 6ca2bac

File tree

11 files changed

+103
-10
lines changed

11 files changed

+103
-10
lines changed

.azure/gpu-benchmarks.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ jobs:
7575
pip list
7676
displayName: "Image info & NVIDIA"
7777
78-
- bash: pip install -e .[dev] --find-links ${TORCH_URL}
78+
- bash: |
79+
pip install -e .[dev] --find-links ${TORCH_URL}
80+
pip install setuptools==75.6.0
7981
env:
8082
FREEZE_REQUIREMENTS: "1"
8183
displayName: "Install package"

.azure/gpu-tests-fabric.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ jobs:
107107
- bash: |
108108
extra=$(python -c "print({'lightning': 'fabric-'}.get('$(PACKAGE_NAME)', ''))")
109109
pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}" --find-links="${TORCHVISION_URL}"
110+
pip install setuptools==75.6.0
110111
displayName: "Install package & dependencies"
111112
112113
- bash: |

.azure/gpu-tests-pytorch.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ jobs:
111111
- bash: |
112112
extra=$(python -c "print({'lightning': 'pytorch-'}.get('$(PACKAGE_NAME)', ''))")
113113
pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}" --find-links="${TORCHVISION_URL}"
114+
pip install setuptools==75.6.0
114115
displayName: "Install package & dependencies"
115116
116117
- bash: pip uninstall -y lightning

dockers/base-cuda/Dockerfile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ RUN \
5959
add-apt-repository ppa:deadsnakes/ppa && \
6060
apt-get install -y \
6161
python${PYTHON_VERSION} \
62-
python3-setuptools \
6362
python${PYTHON_VERSION}-dev \
6463
&& \
6564
update-alternatives --install /usr/bin/python${PYTHON_VERSION%%.*} python${PYTHON_VERSION%%.*} /usr/bin/python${PYTHON_VERSION} 1 && \
@@ -79,6 +78,8 @@ RUN \
7978
curl https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} && \
8079
# Disable cache \
8180
pip config set global.cache-dir false && \
81+
# Install recent setuptools to obtain pkg_resources \
82+
pip install setuptools==75.6.0 && \
8283
# set particular PyTorch version \
8384
pip install -q wget packaging && \
8485
python -m wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/adjust-torch-versions.py && \

dockers/release/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ RUN \
3939
fi && \
4040
# otherwise there is collision with folder name and pkg name on Pypi
4141
cd pytorch-lightning && \
42-
pip install setuptools && \
42+
pip install setuptools==75.6.0 && \
4343
PACKAGE_NAME=lightning pip install '.[extra,loggers,strategies]' --no-cache-dir && \
4444
PACKAGE_NAME=pytorch pip install '.[extra,loggers,strategies]' --no-cache-dir && \
4545
cd .. && \

docs/source-pytorch/common/index.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
../data/data
2424
../model/own_your_loop
2525
../advanced/model_init
26+
../common/tbptt
2627

2728

2829
#############
@@ -202,6 +203,13 @@ How-to Guides
202203
:col_css: col-md-4
203204
:height: 180
204205

206+
.. displayitem::
207+
:header: Truncated Back-Propagation Through Time
208+
:description: Efficiently step through time when training recurrent models
209+
:button_link: ../common/tbptt.html
210+
:col_css: col-md-4
211+
:height: 180
212+
205213
.. raw:: html
206214

207215
</div>
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
##############################################
2+
Truncated Backpropagation Through Time (TBPTT)
3+
##############################################
4+
5+
Truncated Backpropagation Through Time (TBPTT) performs backpropogation every k steps of
6+
a much longer sequence. This is made possible by passing training batches
7+
split along the time-dimensions into splits of size k to the
8+
``training_step``. In order to keep the same forward propagation behavior, all
9+
hidden states should be kept in-between each time-dimension split.
10+
11+
12+
.. code-block:: python
13+
14+
import torch
15+
import torch.optim as optim
16+
import pytorch_lightning as pl
17+
from pytorch_lightning import LightningModule
18+
19+
class LitModel(LightningModule):
20+
21+
def __init__(self):
22+
super().__init__()
23+
24+
# 1. Switch to manual optimization
25+
self.automatic_optimization = False
26+
27+
self.truncated_bptt_steps = 10
28+
self.my_rnn = ParityModuleRNN() # Define RNN model using ParityModuleRNN
29+
30+
# 2. Remove the `hiddens` argument
31+
def training_step(self, batch, batch_idx):
32+
33+
# 3. Split the batch in chunks along the time dimension
34+
split_batches = split_batch(batch, self.truncated_bptt_steps)
35+
36+
batch_size = 10
37+
hidden_dim = 20
38+
hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device)
39+
for split_batch in range(split_batches):
40+
# 4. Perform the optimization in a loop
41+
loss, hiddens = self.my_rnn(split_batch, hiddens)
42+
self.backward(loss)
43+
self.optimizer.step()
44+
self.optimizer.zero_grad()
45+
46+
# 5. "Truncate"
47+
hiddens = hiddens.detach()
48+
49+
# 6. Remove the return of `hiddens`
50+
# Returning loss in manual optimization is not needed
51+
return None
52+
53+
def configure_optimizers(self):
54+
return optim.Adam(self.my_rnn.parameters(), lr=0.001)
55+
56+
if __name__ == "__main__":
57+
model = LitModel()
58+
trainer = pl.Trainer(max_epochs=5)
59+
trainer.fit(model, train_dataloader) # Define your own dataloader

docs/source-pytorch/conf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,9 @@ def _load_py_module(name: str, location: str) -> ModuleType:
462462
("py:obj", "lightning.pytorch.utilities.memory.is_out_of_cpu_memory"),
463463
("py:func", "lightning.pytorch.utilities.rank_zero.rank_zero_only"),
464464
("py:class", "lightning.pytorch.utilities.types.LRSchedulerConfig"),
465-
("py:class", "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig"),
465+
("py:class", "lightning.pytorch.utilities.types.LRSchedulerConfigType"),
466+
("py:class", "lightning.pytorch.utilities.types.OptimizerConfigType"),
467+
("py:class", "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfigType"),
466468
("py:class", "lightning_habana.pytorch.plugins.precision.HPUPrecisionPlugin"),
467469
("py:class", "lightning_habana.pytorch.strategies.HPUDDPStrategy"),
468470
("py:class", "lightning_habana.pytorch.strategies.HPUParallelStrategy"),

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import platform
1919
from collections.abc import Mapping
2020
from contextlib import AbstractContextManager, ExitStack
21+
from datetime import timedelta
2122
from itertools import chain
2223
from pathlib import Path
2324
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
@@ -29,6 +30,7 @@
2930
from typing_extensions import override
3031

3132
from lightning.fabric.accelerators import Accelerator, CUDAAccelerator
33+
from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
3234
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
3335
from lightning.fabric.plugins.precision import Precision
3436
from lightning.fabric.strategies.ddp import DDPStrategy
@@ -97,6 +99,7 @@ def __init__(
9799
load_full_weights: bool = False,
98100
precision: Optional[Precision] = None,
99101
process_group_backend: Optional[str] = None,
102+
timeout: Optional[timedelta] = default_pg_timeout,
100103
) -> None:
101104
"""Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
102105
billion parameter models. `For more information: https://pytorch-
@@ -241,6 +244,7 @@ def __init__(
241244
process_group_backend=process_group_backend,
242245
)
243246
self._backward_sync_control = None # DeepSpeed handles gradient accumulation internally
247+
self._timeout: Optional[timedelta] = timeout
244248

245249
self.config = self._load_config(config)
246250
if self.config is None:
@@ -662,7 +666,9 @@ def _init_deepspeed_distributed(self) -> None:
662666
f"MEMBER: {self.global_rank + 1}/{self.world_size}"
663667
)
664668
self._process_group_backend = self._get_process_group_backend()
665-
deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
669+
deepspeed.init_distributed(
670+
self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
671+
)
666672

667673
def _set_node_environment_variables(self) -> None:
668674
assert self.cluster_environment is not None

src/lightning/pytorch/strategies/deepspeed.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from collections import OrderedDict
2020
from collections.abc import Generator, Mapping
2121
from contextlib import contextmanager
22+
from datetime import timedelta
2223
from pathlib import Path
2324
from typing import TYPE_CHECKING, Any, Optional, Union
2425

@@ -30,6 +31,7 @@
3031

3132
import lightning.pytorch as pl
3233
from lightning.fabric.plugins import ClusterEnvironment
34+
from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
3335
from lightning.fabric.strategies import _StrategyRegistry
3436
from lightning.fabric.strategies.deepspeed import (
3537
_DEEPSPEED_AVAILABLE,
@@ -119,6 +121,7 @@ def __init__(
119121
load_full_weights: bool = False,
120122
precision_plugin: Optional[Precision] = None,
121123
process_group_backend: Optional[str] = None,
124+
timeout: Optional[timedelta] = default_pg_timeout,
122125
) -> None:
123126
"""Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
124127
billion parameter models. `For more information: https://pytorch-
@@ -264,6 +267,7 @@ def __init__(
264267
precision_plugin=precision_plugin,
265268
process_group_backend=process_group_backend,
266269
)
270+
self._timeout: Optional[timedelta] = timeout
267271

268272
self.config = self._load_config(config)
269273
if self.config is None:
@@ -364,7 +368,9 @@ def _init_deepspeed_distributed(self) -> None:
364368
f"MEMBER: {self.global_rank + 1}/{self.world_size}"
365369
)
366370
self._process_group_backend = self._get_process_group_backend()
367-
deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
371+
deepspeed.init_distributed(
372+
self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
373+
)
368374

369375
def _set_node_environment_variables(self) -> None:
370376
assert self.cluster_environment is not None

0 commit comments

Comments
 (0)