Skip to content

Commit 6264e8f

Browse files
authored
Merge branch 'master' into batch_size_scaler_newargs
2 parents 074e238 + 7323bb8 commit 6264e8f

File tree

17 files changed

+164
-47
lines changed

17 files changed

+164
-47
lines changed

.github/workflows/_legacy-checkpoints.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ jobs:
5858
- uses: actions/checkout@v5
5959

6060
- name: Install uv and set Python version
61-
uses: astral-sh/setup-uv@v6
61+
uses: astral-sh/setup-uv@v7
6262
with:
6363
python-version: "3.9"
6464
# TODO: Avoid activating environment like this

.github/workflows/ci-tests-fabric.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ jobs:
7171
- uses: actions/checkout@v5
7272

7373
- name: Install uv and set Python version
74-
uses: astral-sh/setup-uv@v6
74+
uses: astral-sh/setup-uv@v7
7575
with:
7676
python-version: ${{ matrix.config.python-version || '3.9' }}
7777
# TODO: Avoid activating environment like this

.github/workflows/ci-tests-pytorch.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ jobs:
7676
- uses: actions/checkout@v5
7777

7878
- name: Install uv and set Python version
79-
uses: astral-sh/setup-uv@v6
79+
uses: astral-sh/setup-uv@v7
8080
with:
8181
python-version: ${{ matrix.config.python-version || '3.9' }}
8282
# TODO: Avoid activating environment like this

.github/workflows/code-checks.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
- uses: actions/checkout@v5
3232

3333
- name: Install uv and set Python version
34-
uses: astral-sh/setup-uv@v6
34+
uses: astral-sh/setup-uv@v7
3535
with:
3636
python-version: "3.11"
3737
# TODO: Avoid activating environment like this

.github/workflows/docs-build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ jobs:
7373
lfs: ${{ matrix.pkg-name == 'pytorch' }}
7474

7575
- name: Install uv and set Python version
76-
uses: astral-sh/setup-uv@v6
76+
uses: astral-sh/setup-uv@v7
7777
with:
7878
python-version: "3.10"
7979
# TODO: Avoid activating environment like this

.pre-commit-config.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ ci:
2323

2424
repos:
2525
- repo: https://github.com/pre-commit/pre-commit-hooks
26-
rev: v5.0.0
26+
rev: v6.0.0
2727
hooks:
2828
- id: end-of-file-fixer
2929
- id: trailing-whitespace
@@ -70,7 +70,7 @@ repos:
7070
- id: sphinx-lint
7171

7272
- repo: https://github.com/astral-sh/ruff-pre-commit
73-
rev: v0.12.2
73+
rev: v0.13.3
7474
hooks:
7575
# try to fix what is possible
7676
- id: ruff
@@ -95,8 +95,8 @@ repos:
9595
README.md
9696
)$
9797
98-
- repo: https://github.com/pre-commit/mirrors-prettier
99-
rev: v3.1.0
98+
- repo: https://github.com/JoC0de/pre-commit-prettier
99+
rev: v3.6.2
100100
hooks:
101101
- id: prettier
102102
# https://prettier.io/docs/en/options.html#print-width

Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@ export SPHINX_MOCK_REQUIREMENTS=1
77
# install only Lightning Trainer packages
88
export PACKAGE_NAME=pytorch
99

10+
11+
# In Lightning Studio, the `lightning` package comes pre-installed.
12+
# Uninstall it first to ensure the editable install works correctly.
1013
setup:
14+
uv pip uninstall lightning pytorch-lightning lightning-fabric || true
1115
uv pip install -r requirements.txt \
1216
-r requirements/pytorch/base.txt \
1317
-r requirements/pytorch/test.txt \

docs/source-pytorch/deploy/production_advanced_2.rst

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,20 @@ Deploy models into production (advanced)
77

88
----
99

10-
*********************************
11-
Compile your model to TorchScript
12-
*********************************
13-
`TorchScript <https://pytorch.org/docs/stable/jit.html>`_ allows you to serialize your models in a way that it can be loaded in non-Python environments.
14-
The ``LightningModule`` has a handy method :meth:`~lightning.pytorch.core.LightningModule.to_torchscript` that returns a scripted module which you
15-
can save or directly use.
10+
************************************
11+
Export your model with torch.export
12+
************************************
13+
14+
`torch.export <https://pytorch.org/docs/stable/export.html>`_ is the recommended way to capture PyTorch models for
15+
deployment in production environments. It produces a clean intermediate representation with strong soundness guarantees,
16+
making models suitable for inference optimization and cross-platform deployment.
17+
You can export any ``LightningModule`` using the ``torch.export.export()`` API.
1618

1719
.. testcode:: python
1820

21+
import torch
22+
from torch.export import export
23+
1924
class SimpleModel(LightningModule):
2025
def __init__(self):
2126
super().__init__()
@@ -25,25 +30,27 @@ can save or directly use.
2530
return torch.relu(self.l1(x.view(x.size(0), -1)))
2631

2732

28-
# create the model
33+
# create the model and example input
2934
model = SimpleModel()
30-
script = model.to_torchscript()
35+
example_input = torch.randn(1, 64)
3136

32-
# save for use in production environment
33-
torch.jit.save(script, "model.pt")
37+
# export the model
38+
exported_program = export(model, (example_input,))
3439

35-
It is recommended that you install the latest supported version of PyTorch to use this feature without limitations.
40+
# save for use in production environment
41+
torch.export.save(exported_program, "model.pt2")
3642

37-
Once you have the exported model, you can run it in PyTorch or C++ runtime:
43+
It is recommended that you install the latest supported version of PyTorch to use this feature without
44+
limitations. Once you have the exported model, you can load and run it:
3845

3946
.. code-block:: python
4047
4148
inp = torch.rand(1, 64)
42-
scripted_module = torch.jit.load("model.pt")
43-
output = scripted_module(inp)
49+
loaded_program = torch.export.load("model.pt2")
50+
output = loaded_program.module()(inp)
4451
4552
46-
If you want to script a different method, you can decorate the method with :func:`torch.jit.export`:
53+
For more complex models, you can also export specific methods by creating a wrapper:
4754

4855
.. code-block:: python
4956
@@ -54,7 +61,6 @@ If you want to script a different method, you can decorate the method with :func
5461
self.dropout = nn.Dropout()
5562
self.mc_iteration = mc_iteration
5663
57-
@torch.jit.export
5864
def predict_step(self, batch, batch_idx):
5965
# enable Monte Carlo Dropout
6066
self.dropout.train()
@@ -66,4 +72,11 @@ If you want to script a different method, you can decorate the method with :func
6672
6773
6874
model = LitMCdropoutModel(...)
69-
script = model.to_torchscript(file_path="model.pt", method="script")
75+
example_batch = torch.randn(32, 10) # example input
76+
77+
# Export the predict_step method
78+
exported_program = torch.export.export(
79+
lambda batch, idx: model.predict_step(batch, idx),
80+
(example_batch, 0)
81+
)
82+
torch.export.save(exported_program, "mc_dropout_model.pt2")

requirements/fabric/test.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
coverage ==7.10.6
1+
coverage ==7.10.7
22
numpy >=1.21.0, <1.27.0
33
pytest ==8.4.2
44
pytest-cov ==6.3.0
55
pytest-timeout ==2.4.0
6-
pytest-rerunfailures ==16.0.1
6+
pytest-rerunfailures ==16.0.1; python_version < "3.10"
7+
pytest-rerunfailures ==16.1; python_version >= "3.10"
78
pytest-random-order ==1.2.0
89
click ==8.1.8; python_version < "3.11"
910
click ==8.3.0; python_version > "3.10"

requirements/pytorch/test.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
coverage ==7.10.6
1+
coverage ==7.10.7
22
pytest ==8.4.2
33
pytest-cov ==6.3.0
44
pytest-timeout ==2.4.0
5-
pytest-rerunfailures ==16.0.1
5+
pytest-rerunfailures ==16.0.1; python_version < "3.10"
6+
pytest-rerunfailures ==16.1; python_version >= "3.10"
67
pytest-random-order ==1.2.0
78

89
# needed in tests

0 commit comments

Comments
 (0)