Skip to content

Commit 5b583b9

Browse files
authored
Merge branch 'master' into dependabot-pip-requirements-pytest-cov-7.0.0
2 parents a4a26f0 + 7fa8fe4 commit 5b583b9

File tree

4 files changed

+46
-24
lines changed

4 files changed

+46
-24
lines changed

Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@ export SPHINX_MOCK_REQUIREMENTS=1
77
# install only Lightning Trainer packages
88
export PACKAGE_NAME=pytorch
99

10+
11+
# In Lightning Studio, the `lightning` package comes pre-installed.
12+
# Uninstall it first to ensure the editable install works correctly.
1013
setup:
14+
uv pip uninstall lightning pytorch-lightning lightning-fabric || true
1115
uv pip install -r requirements.txt \
1216
-r requirements/pytorch/base.txt \
1317
-r requirements/pytorch/test.txt \

docs/source-pytorch/deploy/production_advanced_2.rst

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,20 @@ Deploy models into production (advanced)
77

88
----
99

10-
*********************************
11-
Compile your model to TorchScript
12-
*********************************
13-
`TorchScript <https://pytorch.org/docs/stable/jit.html>`_ allows you to serialize your models in a way that it can be loaded in non-Python environments.
14-
The ``LightningModule`` has a handy method :meth:`~lightning.pytorch.core.LightningModule.to_torchscript` that returns a scripted module which you
15-
can save or directly use.
10+
************************************
11+
Export your model with torch.export
12+
************************************
13+
14+
`torch.export <https://pytorch.org/docs/stable/export.html>`_ is the recommended way to capture PyTorch models for
15+
deployment in production environments. It produces a clean intermediate representation with strong soundness guarantees,
16+
making models suitable for inference optimization and cross-platform deployment.
17+
You can export any ``LightningModule`` using the ``torch.export.export()`` API.
1618

1719
.. testcode:: python
1820

21+
import torch
22+
from torch.export import export
23+
1924
class SimpleModel(LightningModule):
2025
def __init__(self):
2126
super().__init__()
@@ -25,25 +30,27 @@ can save or directly use.
2530
return torch.relu(self.l1(x.view(x.size(0), -1)))
2631

2732

28-
# create the model
33+
# create the model and example input
2934
model = SimpleModel()
30-
script = model.to_torchscript()
35+
example_input = torch.randn(1, 64)
3136

32-
# save for use in production environment
33-
torch.jit.save(script, "model.pt")
37+
# export the model
38+
exported_program = export(model, (example_input,))
3439

35-
It is recommended that you install the latest supported version of PyTorch to use this feature without limitations.
40+
# save for use in production environment
41+
torch.export.save(exported_program, "model.pt2")
3642

37-
Once you have the exported model, you can run it in PyTorch or C++ runtime:
43+
It is recommended that you install the latest supported version of PyTorch to use this feature without
44+
limitations. Once you have the exported model, you can load and run it:
3845

3946
.. code-block:: python
4047
4148
inp = torch.rand(1, 64)
42-
scripted_module = torch.jit.load("model.pt")
43-
output = scripted_module(inp)
49+
loaded_program = torch.export.load("model.pt2")
50+
output = loaded_program.module()(inp)
4451
4552
46-
If you want to script a different method, you can decorate the method with :func:`torch.jit.export`:
53+
For more complex models, you can also export specific methods by creating a wrapper:
4754

4855
.. code-block:: python
4956
@@ -54,7 +61,6 @@ If you want to script a different method, you can decorate the method with :func
5461
self.dropout = nn.Dropout()
5562
self.mc_iteration = mc_iteration
5663
57-
@torch.jit.export
5864
def predict_step(self, batch, batch_idx):
5965
# enable Monte Carlo Dropout
6066
self.dropout.train()
@@ -66,4 +72,11 @@ If you want to script a different method, you can decorate the method with :func
6672
6773
6874
model = LitMCdropoutModel(...)
69-
script = model.to_torchscript(file_path="model.pt", method="script")
75+
example_batch = torch.randn(32, 10) # example input
76+
77+
# Export the predict_step method
78+
exported_program = torch.export.export(
79+
lambda batch, idx: model.predict_step(batch, idx),
80+
(example_batch, 0)
81+
)
82+
torch.export.save(exported_program, "mc_dropout_model.pt2")

src/lightning/pytorch/trainer/connectors/logger_connector/result.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,7 @@ def _generate_sync_fn(self) -> None:
9191
"""Used to compute the syncing function and cache it."""
9292
fn = self.no_op if self.fn is None or not self.should or self.rank_zero_only else self.fn
9393
# save the function as `_fn` as the meta are being re-created and the object references need to match.
94-
# ignore typing, bad support for `partial`: mypy/issues/1484
95-
self._fn: Callable = partial(fn, reduce_op=self.op, group=self.group) # type: ignore[unused-ignore]
94+
self._fn: Callable = partial(fn, reduce_op=self.op, group=self.group)
9695

9796
@property
9897
def __call__(self) -> Any:

src/lightning/pytorch/utilities/model_helpers.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from lightning_utilities.core.imports import RequirementCache
2121
from torch import nn
22-
from typing_extensions import Concatenate, ParamSpec
22+
from typing_extensions import Concatenate, ParamSpec, override
2323

2424
import lightning.pytorch as pl
2525

@@ -104,26 +104,32 @@ def _check_mixed_imports(instance: object) -> None:
104104
_R_co = TypeVar("_R_co", covariant=True) # return type of the decorated method
105105

106106

107-
class _restricted_classmethod_impl(Generic[_T, _R_co, _P]):
107+
class _restricted_classmethod_impl(classmethod, Generic[_T, _P, _R_co]):
108108
"""Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance
109109
instead of a class type."""
110110

111+
method: Callable[Concatenate[type[_T], _P], _R_co]
112+
111113
def __init__(self, method: Callable[Concatenate[type[_T], _P], _R_co]) -> None:
114+
super().__init__(method)
112115
self.method = method
113116

114-
def __get__(self, instance: Optional[_T], cls: type[_T]) -> Callable[_P, _R_co]:
117+
@override
118+
def __get__(self, instance: _T, cls: Optional[type[_T]] = None) -> Callable[_P, _R_co]: # type: ignore[override]
115119
# The wrapper ensures that the method can be inspected, but not called on an instance
116120
@functools.wraps(self.method)
117121
def wrapper(*args: Any, **kwargs: Any) -> _R_co:
118122
# Workaround for https://github.com/pytorch/pytorch/issues/67146
119123
is_scripting = any(os.path.join("torch", "jit") in frameinfo.filename for frameinfo in inspect.stack())
124+
cls_type = cls if cls is not None else type(instance)
120125
if instance is not None and not is_scripting:
121126
raise TypeError(
122-
f"The classmethod `{cls.__name__}.{self.method.__name__}` cannot be called on an instance."
127+
f"The classmethod `{cls_type.__name__}.{self.method.__name__}` cannot be called on an instance."
123128
" Please call it on the class type and make sure the return value is used."
124129
)
125-
return self.method(cls, *args, **kwargs)
130+
return self.method(cls_type, *args, **kwargs)
126131

132+
wrapper.__func__ = self.method
127133
return wrapper
128134

129135

0 commit comments

Comments
 (0)