Commit c08b2c1

Merge branch 'master' into improve-docs-resume-checkpoints
2 parents 30a8f0c + 9983f3a

File tree: 16 files changed, +164 −30 lines

.actions/assistant.py

Lines changed: 15 additions & 0 deletions

@@ -483,6 +483,21 @@ def convert_version2nightly(ver_file: str = "src/version.info") -> None:


 if __name__ == "__main__":
+    import sys
+
     import jsonargparse
+    from jsonargparse import ArgumentParser
+
+    def patch_jsonargparse_python_3_12_8():
+        if sys.version_info < (3, 12, 8):
+            return
+
+        def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]:
+            namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False)  # type: ignore
+            return namespace, args
+
+        setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch)
+
+    patch_jsonargparse_python_3_12_8()  # Required until fix https://github.com/omni-us/jsonargparse/issues/641

     jsonargparse.CLI(AssistantCLI, as_positional=False)
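
The monkey patch above pins intermixed=False when jsonargparse delegates to argparse, working around a behaviour change in CPython 3.12.8. As a rough standalone illustration (assuming only jsonargparse is installed; the --name option is made up), the same workaround could be applied in any small CLI script:

# Standalone sketch of the workaround added above: force jsonargparse's
# ArgumentParser to call argparse's _parse_known_args with intermixed=False
# on CPython >= 3.12.8. The CLI option below is purely illustrative.
import sys
from typing import Any

from jsonargparse import ArgumentParser


def patch_jsonargparse_python_3_12_8() -> None:
    if sys.version_info < (3, 12, 8):
        return  # earlier interpreters are unaffected

    def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]:
        # Delegate to argparse, pinning intermixed=False to restore the pre-3.12.8 behaviour.
        namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False)  # type: ignore
        return namespace, args

    setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch)


if __name__ == "__main__":
    patch_jsonargparse_python_3_12_8()  # apply before any parsing happens
    parser = ArgumentParser()
    parser.add_argument("--name", type=str, default="world")
    print(parser.parse_args())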

.github/workflows/ci-tests-fabric.yml

Lines changed: 9 additions & 9 deletions

@@ -49,16 +49,16 @@ jobs:
   - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
   - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
   - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
-  - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
-  - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
-  - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
-  - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
-  - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
-  - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
+  - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
+  - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
+  - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
+  - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
+  - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
+  - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
   # only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
-  - { os: "macOS-14", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.5.1" }
-  - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.5.1" }
-  - { os: "windows-2022", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.5.1" }
+  - { os: "macOS-14", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.5.1" }
+  - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.5.1" }
+  - { os: "windows-2022", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.5.1" }
   # "oldest" versions tests, only on minimum Python
   - { os: "macOS-14", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" }
   - {

.github/workflows/ci-tests-pytorch.yml

Lines changed: 9 additions & 9 deletions

@@ -53,16 +53,16 @@ jobs:
   - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
   - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
   - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
-  - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
-  - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
-  - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
-  - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
-  - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
-  - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
+  - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
+  - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
+  - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
+  - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
+  - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
+  - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
   # only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues
-  - { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.5.1" }
-  - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.5.1" }
-  - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.5.1" }
+  - { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.5.1" }
+  - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.5.1" }
+  - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.5.1" }
   # "oldest" versions tests, only on minimum Python
   - { os: "macOS-14", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" }
   - {

dockers/base-cuda/Dockerfile

Lines changed: 1 addition & 1 deletion

@@ -59,7 +59,7 @@ RUN \
     add-apt-repository ppa:deadsnakes/ppa && \
     apt-get install -y \
         python${PYTHON_VERSION} \
-        python${PYTHON_VERSION}-distutils \
+        python3-setuptools \
         python${PYTHON_VERSION}-dev \
     && \
     update-alternatives --install /usr/bin/python${PYTHON_VERSION%%.*} python${PYTHON_VERSION%%.*} /usr/bin/python${PYTHON_VERSION} 1 && \

dockers/docs/Dockerfile

Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ RUN \
     dvipng \
     texlive-pictures \
     python3 \
-    python3-distutils \
+    python3-setuptools \
     python3-dev \
     && \
     update-alternatives --install /usr/bin/python python /usr/bin/python3 1 && \

docs/source-pytorch/common/index.rst

Lines changed: 7 additions & 0 deletions

@@ -202,6 +202,13 @@ How-to Guides
    :col_css: col-md-4
    :height: 180

+.. displayitem::
+   :header: Truncated Back-Propagation Through Time
+   :description: Efficiently step through time when training recurrent models
+   :button_link: ../common/tbtt.html
+   :col_css: col-md-4
+   :height: 180
+
 .. raw:: html

    </div>
docs/source-pytorch/common/tbtt.rst (new file)

Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
+##############################################
+Truncated Backpropagation Through Time (TBPTT)
+##############################################
+
+Truncated Backpropagation Through Time (TBPTT) performs backpropagation every k steps of
+a much longer sequence. This is made possible by passing training batches
+split along the time dimension into splits of size k to the
+``training_step``. In order to keep the same forward propagation behavior, all
+hidden states should be kept in-between each time-dimension split.
+
+
+.. code-block:: python
+
+    import torch
+    import torch.optim as optim
+    import pytorch_lightning as pl
+    from pytorch_lightning import LightningModule
+
+    class LitModel(LightningModule):
+
+        def __init__(self):
+            super().__init__()
+
+            # 1. Switch to manual optimization
+            self.automatic_optimization = False
+
+            self.truncated_bptt_steps = 10
+            self.my_rnn = ParityModuleRNN()  # Define RNN model using ParityModuleRNN
+
+        # 2. Remove the `hiddens` argument
+        def training_step(self, batch, batch_idx):
+
+            # 3. Split the batch in chunks along the time dimension
+            split_batches = split_batch(batch, self.truncated_bptt_steps)
+
+            batch_size = 10
+            hidden_dim = 20
+            hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device)
+            for split_batch in range(split_batches):
+                # 4. Perform the optimization in a loop
+                loss, hiddens = self.my_rnn(split_batch, hiddens)
+                self.backward(loss)
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+
+                # 5. "Truncate"
+                hiddens = hiddens.detach()
+
+            # 6. Remove the return of `hiddens`
+            # Returning loss in manual optimization is not needed
+            return None
+
+        def configure_optimizers(self):
+            return optim.Adam(self.my_rnn.parameters(), lr=0.001)
+
+    if __name__ == "__main__":
+        model = LitModel()
+        trainer = pl.Trainer(max_epochs=5)
+        trainer.fit(model, train_dataloader)  # Define your own dataloader
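
The snippet in the new doc page is schematic: ParityModuleRNN, split_batch, and train_dataloader are placeholders, and with manual optimization one would normally step through self.optimizers() and self.manual_backward(). A self-contained sketch of the same manual-optimization TBPTT pattern, assuming a plain GRU and synthetic data, could look like this:

# Hedged sketch only: TBPTT with manual optimization, substituting a GRU and
# random tensors for the placeholders (ParityModuleRNN, split_batch,
# train_dataloader) referenced in the documentation snippet above.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class TBPTTModule(pl.LightningModule):
    def __init__(self, input_dim=8, hidden_dim=20, tbptt_steps=10):
        super().__init__()
        self.automatic_optimization = False  # manual optimization drives the inner loop
        self.tbptt_steps = tbptt_steps
        self.rnn = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch  # x: (batch, time, features), y: (batch, time, 1)
        opt = self.optimizers()
        hiddens = None
        # Split the full sequence into chunks of `tbptt_steps` along the time dimension.
        for x_t, y_t in zip(x.split(self.tbptt_steps, dim=1), y.split(self.tbptt_steps, dim=1)):
            out, hiddens = self.rnn(x_t, hiddens)
            loss = nn.functional.mse_loss(self.head(out), y_t)
            opt.zero_grad()
            self.manual_backward(loss)  # backward per chunk
            opt.step()
            hiddens = hiddens.detach()  # "truncate": no gradient flows across chunk boundaries

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    x = torch.randn(64, 100, 8)  # 64 sequences, 100 time steps, 8 features
    y = torch.randn(64, 100, 1)
    loader = DataLoader(TensorDataset(x, y), batch_size=16)
    trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    trainer.fit(TBPTTModule(), loader)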

docs/source-pytorch/conf.py

Lines changed: 3 additions & 1 deletion

@@ -462,7 +462,9 @@ def _load_py_module(name: str, location: str) -> ModuleType:
     ("py:obj", "lightning.pytorch.utilities.memory.is_out_of_cpu_memory"),
     ("py:func", "lightning.pytorch.utilities.rank_zero.rank_zero_only"),
     ("py:class", "lightning.pytorch.utilities.types.LRSchedulerConfig"),
-    ("py:class", "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig"),
+    ("py:class", "lightning.pytorch.utilities.types.LRSchedulerConfigType"),
+    ("py:class", "lightning.pytorch.utilities.types.OptimizerConfigType"),
+    ("py:class", "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfigType"),
     ("py:class", "lightning_habana.pytorch.plugins.precision.HPUPrecisionPlugin"),
     ("py:class", "lightning_habana.pytorch.strategies.HPUDDPStrategy"),
     ("py:class", "lightning_habana.pytorch.strategies.HPUParallelStrategy"),

examples/fabric/reinforcement_learning/rl/utils.py

Lines changed: 17 additions & 1 deletion

@@ -1,7 +1,6 @@
 import argparse
 import math
 import os
-from distutils.util import strtobool
 from typing import TYPE_CHECKING, Optional, Union

 import gymnasium as gym
@@ -12,6 +11,23 @@
 from rl.agent import PPOAgent, PPOLightningAgent


+def strtobool(val):
+    """Convert a string representation of truth to true (1) or false (0).
+
+    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'.
+    Raises ValueError if 'val' is anything else.
+
+    Note: taken from distutils after its deprecation.
+
+    """
+    val = val.lower()
+    if val in ("y", "yes", "t", "true", "on", "1"):
+        return 1
+    if val in ("n", "no", "f", "false", "off", "0"):
+        return 0
+    raise ValueError(f"invalid truth value {val!r}")
+
+
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--exp-name", type=str, default="default", help="the name of this experiment")
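
A helper like this is typically wired into argparse as a type converter for boolean flags. A hypothetical, self-contained usage sketch (the --capture-video flag name is illustrative, not necessarily the one the example defines):

# Hypothetical usage sketch: parsing a boolean CLI flag with a strtobool helper
# like the one added above (redefined here so the sketch runs on its own).
import argparse


def strtobool(val):
    """Return 1 for 'y', 'yes', 't', 'true', 'on', '1'; 0 for their negative counterparts."""
    val = val.lower()
    if val in ("y", "yes", "t", "true", "on", "1"):
        return 1
    if val in ("n", "no", "f", "false", "off", "0"):
        return 0
    raise ValueError(f"invalid truth value {val!r}")


parser = argparse.ArgumentParser()
parser.add_argument(
    "--capture-video",                    # illustrative flag name
    type=lambda x: bool(strtobool(x)),    # accepts true/false, yes/no, 1/0, on/off
    default=False,
    nargs="?",
    const=True,
    help="whether to record videos of the agent",
)
print(parser.parse_args(["--capture-video", "true"]))  # -> Namespace(capture_video=True)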

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 7 additions & 1 deletion

@@ -18,6 +18,7 @@
 import platform
 from collections.abc import Mapping
 from contextlib import AbstractContextManager, ExitStack
+from datetime import timedelta
 from itertools import chain
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
@@ -29,6 +30,7 @@
 from typing_extensions import override

 from lightning.fabric.accelerators import Accelerator, CUDAAccelerator
+from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
 from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning.fabric.plugins.precision import Precision
 from lightning.fabric.strategies.ddp import DDPStrategy
@@ -97,6 +99,7 @@ def __init__(
         load_full_weights: bool = False,
         precision: Optional[Precision] = None,
         process_group_backend: Optional[str] = None,
+        timeout: Optional[timedelta] = default_pg_timeout,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
@@ -241,6 +244,7 @@ def __init__(
             process_group_backend=process_group_backend,
         )
         self._backward_sync_control = None  # DeepSpeed handles gradient accumulation internally
+        self._timeout: Optional[timedelta] = timeout

         self.config = self._load_config(config)
         if self.config is None:
@@ -648,7 +652,9 @@ def _init_deepspeed_distributed(self) -> None:
             f"MEMBER: {self.global_rank + 1}/{self.world_size}"
         )
         self._process_group_backend = self._get_process_group_backend()
-        deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
+        deepspeed.init_distributed(
+            self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
+        )

     def _set_node_environment_variables(self) -> None:
         assert self.cluster_environment is not None
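
The new timeout argument is forwarded to deepspeed.init_distributed when the process group is set up. A hedged sketch of how it might be set when constructing the strategy through Fabric (the 90-minute value and the device count are arbitrary choices, not recommendations):

# Illustrative sketch: passing the new `timeout` argument to DeepSpeedStrategy
# via Fabric. If omitted, it falls back to `default_pg_timeout` from the torch
# collective plugin, as in the diff above.
from datetime import timedelta

from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy

strategy = DeepSpeedStrategy(
    stage=2,                        # ZeRO stage 2: partition optimizer states and gradients
    timeout=timedelta(minutes=90),  # forwarded to deepspeed.init_distributed(...)
)
fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
# fabric.launch() would then initialize the distributed process group with this timeout.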
