Commit 6496a71

Merge branch 'master' into docs_finetuning_callback_example
2 parents 98c58b0 + c943c05

File tree: 11 files changed, +164 -30 lines

docs/source-pytorch/deploy/production_advanced_2.rst

Lines changed: 30 additions & 17 deletions
@@ -7,15 +7,20 @@ Deploy models into production (advanced)
 
 ----
 
-*********************************
-Compile your model to TorchScript
-*********************************
-`TorchScript <https://pytorch.org/docs/stable/jit.html>`_ allows you to serialize your models in a way that it can be loaded in non-Python environments.
-The ``LightningModule`` has a handy method :meth:`~lightning.pytorch.core.LightningModule.to_torchscript` that returns a scripted module which you
-can save or directly use.
+************************************
+Export your model with torch.export
+************************************
+
+`torch.export <https://pytorch.org/docs/stable/export.html>`_ is the recommended way to capture PyTorch models for
+deployment in production environments. It produces a clean intermediate representation with strong soundness guarantees,
+making models suitable for inference optimization and cross-platform deployment.
+You can export any ``LightningModule`` using the ``torch.export.export()`` API.
 
 .. testcode:: python
 
+    import torch
+    from torch.export import export
+
     class SimpleModel(LightningModule):
         def __init__(self):
             super().__init__()
@@ -25,25 +30,27 @@ can save or directly use.
             return torch.relu(self.l1(x.view(x.size(0), -1)))
 
 
-    # create the model
+    # create the model and example input
     model = SimpleModel()
-    script = model.to_torchscript()
+    example_input = torch.randn(1, 64)
 
-    # save for use in production environment
-    torch.jit.save(script, "model.pt")
+    # export the model
+    exported_program = export(model, (example_input,))
 
-It is recommended that you install the latest supported version of PyTorch to use this feature without limitations.
+    # save for use in production environment
+    torch.export.save(exported_program, "model.pt2")
 
-Once you have the exported model, you can run it in PyTorch or C++ runtime:
+It is recommended that you install the latest supported version of PyTorch to use this feature without
+limitations. Once you have the exported model, you can load and run it:
 
 .. code-block:: python
 
     inp = torch.rand(1, 64)
-    scripted_module = torch.jit.load("model.pt")
-    output = scripted_module(inp)
+    loaded_program = torch.export.load("model.pt2")
+    output = loaded_program.module()(inp)
 
 
-If you want to script a different method, you can decorate the method with :func:`torch.jit.export`:
+For more complex models, you can also export specific methods by creating a wrapper:
 
 .. code-block:: python
 
@@ -54,7 +61,6 @@ If you want to script a different method, you can decorate the method with :func:`torch.jit.export`:
             self.dropout = nn.Dropout()
             self.mc_iteration = mc_iteration
 
-        @torch.jit.export
         def predict_step(self, batch, batch_idx):
             # enable Monte Carlo Dropout
             self.dropout.train()
@@ -66,4 +72,11 @@ If you want to script a different method, you can decorate the method with :func:`torch.jit.export`:
 
 
     model = LitMCdropoutModel(...)
-    script = model.to_torchscript(file_path="model.pt", method="script")
+    example_batch = torch.randn(32, 10)  # example input
+
+    # Export the predict_step method
+    exported_program = torch.export.export(
+        lambda batch, idx: model.predict_step(batch, idx),
+        (example_batch, 0)
+    )
+    torch.export.save(exported_program, "mc_dropout_model.pt2")

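As context for the docs change above, here is a minimal self-contained sketch of the torch.export round trip the new text describes. It uses a plain nn.Module so it runs without Lightning installed; a LightningModule behaves the same way since it subclasses nn.Module. The shapes and file name are illustrative only.

import torch
from torch import nn
from torch.export import export


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(64, 4)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))


model = SimpleModel()
example_input = torch.randn(1, 64)

# capture the model into an ExportedProgram and serialize it
exported_program = export(model, (example_input,))
torch.export.save(exported_program, "model.pt2")

# later, e.g. in a production service: load and call the captured module
loaded_program = torch.export.load("model.pt2")
output = loaded_program.module()(torch.rand(1, 64))
print(output.shape)  # torch.Size([1, 4])
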
requirements/fabric/strategies.txt

Lines changed: 1 addition & 1 deletion
@@ -5,5 +5,5 @@
 
 # note: is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods`
 # shall be resolved by https://github.com/microsoft/DeepSpeed/issues/4372
-deepspeed >=0.14.1,<=0.15.0; platform_system != "Windows" and platform_system != "Darwin" # strict
+deepspeed >=0.15.0,<0.17.0; platform_system != "Windows" and platform_system != "Darwin" # strict
 bitsandbytes >=0.45.2,<0.47.0; platform_system != "Darwin"

requirements/pytorch/strategies.txt

Lines changed: 1 addition & 1 deletion
@@ -3,4 +3,4 @@
 
 # note: is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods`
 # shall be resolved by https://github.com/microsoft/DeepSpeed/issues/4372
-deepspeed >=0.14.1,<=0.15.0; platform_system != "Windows" and platform_system != "Darwin" # strict
+deepspeed >=0.15.0,<0.17.0; platform_system != "Windows" and platform_system != "Darwin" # strict

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 15 additions & 0 deletions
@@ -37,6 +37,7 @@
 from lightning.fabric.strategies.registry import _StrategyRegistry
 from lightning.fabric.strategies.strategy import _Sharded
 from lightning.fabric.utilities.distributed import log
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
 from lightning.fabric.utilities.load import _move_state_into
 from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.fabric.utilities.seed import reset_seed
@@ -47,6 +48,7 @@
     from torch.optim.lr_scheduler import _LRScheduler
 
 _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
+_DEEPSPEED_GREATER_EQUAL_0_16 = RequirementCache("deepspeed>=0.16.0")
 
 
 # TODO(fabric): Links in the docstrings to PL-specific deepspeed user docs need to be replaced.
@@ -239,6 +241,19 @@ def __init__(
                 " Install it by running `pip install -U deepspeed`."
             )
 
+        if _TORCH_GREATER_EQUAL_2_6 and not _DEEPSPEED_GREATER_EQUAL_0_16:
+            # Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` when loading full checkpoints.
+            # DeepSpeed added support for this behavior in version 0.16.0.
+            import deepspeed
+
+            deepspeed_version = deepspeed.__version__
+
+            raise ImportError(
+                f"PyTorch >= 2.6 requires DeepSpeed >= 0.16.0. "
+                f"Detected DeepSpeed version: {deepspeed_version}. "
+                "Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`."
+            )
+
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,

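For context, here is a standalone sketch of the compatibility-guard pattern this hunk adds, assuming the `RequirementCache` and `compare_version` helpers from lightning_utilities (the library Fabric's own imports module builds on); the function name is hypothetical:

import operator

from lightning_utilities.core.imports import RequirementCache, compare_version

# compare_version checks an installed package's version against a bound;
# RequirementCache lazily evaluates a pip-style specifier and is truthy
# only when the installed distribution satisfies it.
_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0")
_DEEPSPEED_GREATER_EQUAL_0_16 = RequirementCache("deepspeed>=0.16.0")


def _check_deepspeed_compat() -> None:
    # PyTorch 2.6 made `torch.load` default to `weights_only=True`;
    # DeepSpeed handles that only from 0.16.0, so this pairing fails fast.
    if _TORCH_GREATER_EQUAL_2_6 and not _DEEPSPEED_GREATER_EQUAL_0_16:
        raise ImportError(
            "PyTorch >= 2.6 requires DeepSpeed >= 0.16.0. "
            "Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`."
        )
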
src/lightning/fabric/utilities/imports.py

Lines changed: 1 addition & 0 deletions
@@ -36,5 +36,6 @@
 _TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1")
 _TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0")
 _TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0")
+_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0")
 _TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0")
 _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -52,6 +52,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed preventing recursive symlink creation iwhen `save_last='link'` and `save_top_k=-1` ([#21186](https://github.com/Lightning-AI/pytorch-lightning/pull/21186))
 
 
+- Fixed `ModelPruning` sparsity logging bug that caused incorrect sparsity percentages ([#21223](https://github.com/Lightning-AI/pytorch-lightning/pull/21223))
+
+
 - Fixed `LightningCLI` loading of hyperparameters from `ckpt_path` failing for subclass model mode ([#21246](https://github.com/Lightning-AI/pytorch-lightning/pull/21246))
 
 

src/lightning/pytorch/callbacks/pruning.py

Lines changed: 1 addition & 1 deletion
@@ -349,7 +349,7 @@ def apply_pruning(self, amount: Union[int, float]) -> None:
     def _log_sparsity_stats(
         self, prev: list[tuple[int, int]], curr: list[tuple[int, int]], amount: Union[int, float] = 0
     ) -> None:
-        total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters())
+        total_params = sum(total for _, total in curr)
         prev_total_zeros = sum(zeros for zeros, _ in prev)
         curr_total_zeros = sum(zeros for zeros, _ in curr)
         log.info(

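To illustrate the fix, a small sketch with hypothetical numbers. Each entry in `curr` (and `prev`) is a `(num_zeros, num_elements)` pair for one pruned tensor:

# two pruned weight tensors: 500/1000 and 250/500 elements zeroed
curr = [(500, 1000), (250, 500)]

# new denominator: only the elements of the tensors actually being pruned
total_params = sum(total for _, total in curr)      # 1500
curr_total_zeros = sum(zeros for zeros, _ in curr)  # 750
print(f"sparsity: {curr_total_zeros / total_params:.2%}")  # 50.00%

The old denominator summed `p.numel()` over every parameter of each layer in `_parameters_to_prune` (biases included, and a layer listed for multiple parameters counted repeatedly), so the denominator was inflated and the logged sparsity percentages came out too low.
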
src/lightning/pytorch/strategies/deepspeed.py

Lines changed: 15 additions & 0 deletions
@@ -35,10 +35,12 @@
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.strategies.deepspeed import (
     _DEEPSPEED_AVAILABLE,
+    _DEEPSPEED_GREATER_EQUAL_0_16,
     _format_precision_config,
     _validate_checkpoint_directory,
     _validate_device_index_selection,
 )
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import _PATH
@@ -262,6 +264,19 @@ def __init__(
                 " Install it by running `pip install -U deepspeed`."
             )
 
+        if _TORCH_GREATER_EQUAL_2_6 and not _DEEPSPEED_GREATER_EQUAL_0_16:
+            # Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` when loading full checkpoints.
+            # DeepSpeed added support for this behavior in version 0.16.0.
+            import deepspeed
+
+            deepspeed_version = deepspeed.__version__
+
+            raise ImportError(
+                f"PyTorch >= 2.6 requires DeepSpeed >= 0.16.0. "
+                f"Detected DeepSpeed version: {deepspeed_version}. "
+                "Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`."
+            )
+
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,

src/lightning/pytorch/trainer/connectors/logger_connector/result.py

Lines changed: 1 addition & 2 deletions
@@ -91,8 +91,7 @@ def _generate_sync_fn(self) -> None:
         """Used to compute the syncing function and cache it."""
         fn = self.no_op if self.fn is None or not self.should or self.rank_zero_only else self.fn
         # save the function as `_fn` as the meta are being re-created and the object references need to match.
-        # ignore typing, bad support for `partial`: mypy/issues/1484
-        self._fn: Callable = partial(fn, reduce_op=self.op, group=self.group)  # type: ignore[unused-ignore]
+        self._fn: Callable = partial(fn, reduce_op=self.op, group=self.group)
 
     @property
     def __call__(self) -> Any:

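As background, a minimal sketch of the caching pattern in `_generate_sync_fn` (names simplified, not the real class): the reduce settings are bound once with functools.partial and stored, so repeated calls reuse the same bound object instead of rebuilding it. Modern mypy types `partial` correctly, which is why the stale `# type: ignore` could be dropped.

from functools import partial
from typing import Any, Callable, Optional


def no_op(value: Any, reduce_op: Optional[str] = None, group: Any = None) -> Any:
    # stand-in for the distributed reduction function
    return value


class Sync:
    def __init__(self, op: str = "mean", group: Any = None) -> None:
        self.op, self.group = op, group
        # bind the settings once and cache the result as `_fn`
        self._fn: Callable = partial(no_op, reduce_op=self.op, group=self.group)


print(Sync()._fn(3.0))  # 3.0
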
src/lightning/pytorch/utilities/deepspeed.py

Lines changed: 2 additions & 2 deletions
@@ -93,10 +93,10 @@ def convert_zero_checkpoint_to_fp32_state_dict(
     ]
     checkpoint_dir = ds_checkpoint_dir(checkpoint_dir)
     optim_files = get_optim_files(checkpoint_dir)
-    optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE)
+    optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE, weights_only=False)
     zero_stage = optim_state["optimizer_state_dict"]["zero_stage"]
     model_file = get_model_state_file(checkpoint_dir, zero_stage)
-    client_state = torch.load(model_file, map_location=CPU_DEVICE)
+    client_state = torch.load(model_file, map_location=CPU_DEVICE, weights_only=False)
     client_state = {key: value for key, value in client_state.items() if key not in deepspeed_states}
     # State dict keys will include reference to wrapper _LightningModuleWrapperBase in old checkpoints created in
     # Lightning version < 2.1. Delete the `_forward_module` prefix before saving.

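The change above is needed because PyTorch 2.6 flipped the default of `torch.load` to `weights_only=True`, and DeepSpeed checkpoint files carry pickled non-tensor state that the restricted unpickler rejects. A sketch of the behavior with a hypothetical file name (the first load only raises on PyTorch >= 2.6; older versions default to full unpickling):

import torch

# a payload containing an arbitrary pickled object, as DeepSpeed states do
state = {"zero_stage": 3, "opaque": object()}
torch.save(state, "ckpt.pt")

try:
    torch.load("ckpt.pt")  # weights_only=True by default on >= 2.6
except Exception as err:
    print(f"rejected: {type(err).__name__}")

# explicit opt-in restores full unpickling; use only on trusted checkpoints
state = torch.load("ckpt.pt", map_location="cpu", weights_only=False)
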
0 commit comments