
Commit ecfbb57

Merge branch 'master' into fix/20947-checkpoint-manual-opt
2 parents 6e2fb03 + 7323bb8 commit ecfbb57

5 files changed: +129 −26 lines changed

docs/source-pytorch/deploy/production_advanced_2.rst

Lines changed: 30 additions & 17 deletions

@@ -7,15 +7,20 @@ Deploy models into production (advanced)
 
 ----
 
-*********************************
-Compile your model to TorchScript
-*********************************
-`TorchScript <https://pytorch.org/docs/stable/jit.html>`_ allows you to serialize your models in a way that it can be loaded in non-Python environments.
-The ``LightningModule`` has a handy method :meth:`~lightning.pytorch.core.LightningModule.to_torchscript` that returns a scripted module which you
-can save or directly use.
+************************************
+Export your model with torch.export
+************************************
+
+`torch.export <https://pytorch.org/docs/stable/export.html>`_ is the recommended way to capture PyTorch models for
+deployment in production environments. It produces a clean intermediate representation with strong soundness guarantees,
+making models suitable for inference optimization and cross-platform deployment.
+You can export any ``LightningModule`` using the ``torch.export.export()`` API.
 
 .. testcode:: python
 
+    import torch
+    from torch.export import export
+
     class SimpleModel(LightningModule):
         def __init__(self):
             super().__init__()
@@ -25,25 +30,27 @@ can save or directly use.
             return torch.relu(self.l1(x.view(x.size(0), -1)))
 
 
-    # create the model
+    # create the model and example input
     model = SimpleModel()
-    script = model.to_torchscript()
+    example_input = torch.randn(1, 64)
 
-    # save for use in production environment
-    torch.jit.save(script, "model.pt")
+    # export the model
+    exported_program = export(model, (example_input,))
 
-It is recommended that you install the latest supported version of PyTorch to use this feature without limitations.
+    # save for use in production environment
+    torch.export.save(exported_program, "model.pt2")
 
-Once you have the exported model, you can run it in PyTorch or C++ runtime:
+It is recommended that you install the latest supported version of PyTorch to use this feature without
+limitations. Once you have the exported model, you can load and run it:
 
 .. code-block:: python
 
     inp = torch.rand(1, 64)
-    scripted_module = torch.jit.load("model.pt")
-    output = scripted_module(inp)
+    loaded_program = torch.export.load("model.pt2")
+    output = loaded_program.module()(inp)
 
 
-If you want to script a different method, you can decorate the method with :func:`torch.jit.export`:
+For more complex models, you can also export specific methods by creating a wrapper:
 
 .. code-block:: python
@@ -54,7 +61,6 @@ If you want to script a different method, you can decorate the method with :func
             self.dropout = nn.Dropout()
             self.mc_iteration = mc_iteration
 
-        @torch.jit.export
         def predict_step(self, batch, batch_idx):
             # enable Monte Carlo Dropout
             self.dropout.train()
@@ -66,4 +72,11 @@ If you want to script a different method, you can decorate the method with :func
 
 
     model = LitMCdropoutModel(...)
-    script = model.to_torchscript(file_path="model.pt", method="script")
+    example_batch = torch.randn(32, 10)  # example input
+
+    # Export the predict_step method
+    exported_program = torch.export.export(
+        lambda batch, idx: model.predict_step(batch, idx),
+        (example_batch, 0)
+    )
+    torch.export.save(exported_program, "mc_dropout_model.pt2")
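
Note: the documented pieces above compose into one end-to-end script. The following is an illustrative sketch, not part of the diff; it assumes a recent PyTorch where `torch.export`, `torch.export.save`, and `torch.export.load` are available, and the `Linear(64, 4)` shape is an assumption chosen to match the `(1, 64)` example input.

    # Illustrative sketch of the documented torch.export flow (assumptions noted above).
    import torch
    import torch.nn as nn
    from lightning.pytorch import LightningModule
    from torch.export import export


    class SimpleModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = nn.Linear(64, 4)  # assumed shape matching the (1, 64) example input

        def forward(self, x):
            return torch.relu(self.l1(x.view(x.size(0), -1)))


    model = SimpleModel()
    example_input = torch.randn(1, 64)

    # Capture the forward graph with an example input, then serialize it.
    exported_program = export(model, (example_input,))
    torch.export.save(exported_program, "model.pt2")

    # In a serving process: load the program and call its runnable module.
    loaded_program = torch.export.load("model.pt2")
    output = loaded_program.module()(torch.rand(1, 64))
    print(output.shape)  # torch.Size([1, 4])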

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -52,6 +52,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed preventing recursive symlink creation iwhen `save_last='link'` and `save_top_k=-1` ([#21186](https://github.com/Lightning-AI/pytorch-lightning/pull/21186))
 
 
+- Fixed `ModelPruning` sparsity logging bug that caused incorrect sparsity percentages ([#21223](https://github.com/Lightning-AI/pytorch-lightning/pull/21223))
+
+
 - Fixed `LightningCLI` loading of hyperparameters from `ckpt_path` failing for subclass model mode ([#21246](https://github.com/Lightning-AI/pytorch-lightning/pull/21246))
 
 
src/lightning/pytorch/callbacks/pruning.py

Lines changed: 1 addition & 1 deletion

@@ -349,7 +349,7 @@ def apply_pruning(self, amount: Union[int, float]) -> None:
     def _log_sparsity_stats(
         self, prev: list[tuple[int, int]], curr: list[tuple[int, int]], amount: Union[int, float] = 0
     ) -> None:
-        total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters())
+        total_params = sum(total for _, total in curr)
         prev_total_zeros = sum(zeros for zeros, _ in prev)
         curr_total_zeros = sum(zeros for zeros, _ in curr)
         log.info(
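
Note on the fix: the old denominator summed every parameter of every module listed in `self._parameters_to_prune`, so it counted unpruned tensors (e.g. biases) and counted a module's full parameter set once per pruned tensor, while `curr` already carries one `(n_zeros, n_total)` pair per pruned tensor. A minimal sketch of the difference, with illustrative local names and assuming a single `Linear(32, 2)` pruned only on its weight, as in the updated checkpoint test:

    import torch.nn as nn
    import torch.nn.utils.prune as prune

    layer = nn.Linear(32, 2)  # 64 weight elements + 2 bias elements
    parameters_to_prune = [(layer, "weight")]
    prune.l1_unstructured(layer, "weight", amount=0.5)

    # What the callback collects: (n_zeros, n_elements) per *pruned tensor*.
    mask = layer.weight_mask
    curr = [(int((mask == 0).sum()), mask.numel())]

    # Old denominator: every parameter of every touched module -> 66 (bias included).
    old_total = sum(p.numel() for module, _ in parameters_to_prune for p in module.parameters())

    # New denominator: only the tensors that were actually pruned -> 64.
    new_total = sum(total for _, total in curr)
    print(old_total, new_total)  # 66 64

With two entries for the same module, as in the new complex-model test below, the old expression would also have double-counted `layer1`'s 2112 parameters.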

src/lightning/pytorch/trainer/connectors/logger_connector/result.py

Lines changed: 1 addition & 2 deletions

@@ -91,8 +91,7 @@ def _generate_sync_fn(self) -> None:
         """Used to compute the syncing function and cache it."""
         fn = self.no_op if self.fn is None or not self.should or self.rank_zero_only else self.fn
         # save the function as `_fn` as the meta are being re-created and the object references need to match.
-        # ignore typing, bad support for `partial`: mypy/issues/1484
-        self._fn: Callable = partial(fn, reduce_op=self.op, group=self.group)  # type: ignore[unused-ignore]
+        self._fn: Callable = partial(fn, reduce_op=self.op, group=self.group)
 
     @property
     def __call__(self) -> Any:
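
Note: the deleted suppression dates from mypy's historically poor `functools.partial` support (mypy issue 1484); current mypy types this assignment correctly, so the `type: ignore` itself had become unused. A standalone illustration with hypothetical names:

    from functools import partial
    from typing import Callable


    def sync_fn(value: float, *, reduce_op: str, group: int) -> float:
        # stand-in for the cached reduction function
        return value


    # mypy now accepts this assignment without a suppression comment:
    _fn: Callable = partial(sync_fn, reduce_op="mean", group=0)
    print(_fn(1.0))  # 1.0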

tests/tests_pytorch/callbacks/test_pruning.py

Lines changed: 94 additions & 6 deletions

@@ -262,13 +262,13 @@ def test_multiple_pruning_callbacks(tmp_path, caplog, make_pruning_permanent: bo
     actual = [m for m in actual if m.startswith("Applied")]
     percentage = r"\(\d+(?:\.\d+)?%\)"
     expected = [
-        rf"Applied `L1Unstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
+        rf"Applied `L1Unstructured`. Pruned: \d+\/1088 {percentage} -> \d+\/1088 {percentage}",
         rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.5. Pruned: 0 \(0.00%\) -> \d+ {percentage}",  # noqa: E501
         rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.5. Pruned: 0 \(0.00%\) -> \d+ {percentage}",  # noqa: E501
-        rf"Applied `RandomUnstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
+        rf"Applied `RandomUnstructured`. Pruned: \d+\/1088 {percentage} -> \d+\/1088 {percentage}",
         rf"Applied `RandomUnstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.25. Pruned: \d+ {percentage} -> \d+ {percentage}",  # noqa: E501
         rf"Applied `RandomUnstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.25. Pruned: \d+ {percentage} -> \d+ {percentage}",  # noqa: E501
-        rf"Applied `L1Unstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
+        rf"Applied `L1Unstructured`. Pruned: \d+\/1088 {percentage} -> \d+\/1088 {percentage}",
         rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.5. Pruned: \d+ {percentage} -> \d+ {percentage}",  # noqa: E501
         rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.5. Pruned: \d+ {percentage} -> \d+ {percentage}",  # noqa: E501
     ]
@@ -329,9 +329,9 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
     actual = [m for m in actual if m.startswith("Applied")]
     percentage = r"\(\d+(?:\.\d+)?%\)"
     expected = [
-        rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}",
-        rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}",
-        rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}",
+        rf"Applied `RandomUnstructured`. Pruned: \d+\/64 {percentage} -> \d+\/64 {percentage}",
+        rf"Applied `RandomUnstructured`. Pruned: \d+\/64 {percentage} -> \d+\/64 {percentage}",
+        rf"Applied `RandomUnstructured`. Pruned: \d+\/64 {percentage} -> \d+\/64 {percentage}",
     ]
     expected = [re.compile(s) for s in expected]
     assert all(regex.match(s) for s, regex in zip(actual, expected))
@@ -463,3 +463,91 @@ def __init__(self):
         f"Actual weight_orig: {weight_orig}\n"
         f"Max difference: {torch.max(torch.abs(weight_orig - original_weights))}"
     )
+
+
+@pytest.mark.parametrize("pruning_amount", [0.1, 0.2, 0.3, 0.5])
+@pytest.mark.parametrize("model_type", ["simple", "complex"])
+def test_sparsity_calculation(tmp_path, caplog, pruning_amount: float, model_type: str):
+    """Test that the sparsity calculation fix correctly reports percentages."""
+
+    class SimpleModel(BoringModel):
+        """Simple model with 66 parameters (64 weight + 2 bias)."""
+
+        def __init__(self):
+            super().__init__()
+            self.layer = nn.Linear(32, 2)  # 32*2 + 2 = 66 params
+
+    class ComplexModel(BoringModel):
+        """Complex model with multiple layers."""
+
+        def __init__(self):
+            super().__init__()
+            self.layer1 = nn.Linear(32, 64)  # 32*64 + 64 = 2112 params
+            self.layer2 = nn.Linear(64, 2)  # 64*2 + 2 = 130 params
+            # Total: 2112 + 130 = 2242 params (but only layer1 will be pruned)
+            # layer1 params: 2112
+
+        def forward(self, x):
+            x = torch.relu(self.layer1(x))
+            return self.layer2(x)
+
+    if model_type == "simple":
+        model = SimpleModel()
+        expected_total_params = 66
+        parameters_to_prune = None
+    else:
+        model = ComplexModel()
+        expected_total_params = 2112
+        parameters_to_prune = [(model.layer1, "weight"), (model.layer1, "bias")]
+
+    pruning = ModelPruning(
+        pruning_fn="l1_unstructured",
+        parameters_to_prune=parameters_to_prune,
+        amount=pruning_amount,
+        verbose=1,
+        use_global_unstructured=True,
+    )
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        enable_checkpointing=False,
+        logger=False,
+        limit_train_batches=1,
+        max_epochs=1,
+        accelerator="cpu",
+        callbacks=[pruning],
+    )
+
+    with caplog.at_level(INFO):
+        trainer.fit(model)
+
+    sparsity_logs = [msg for msg in caplog.messages if "Applied `L1Unstructured`. Pruned:" in msg]
+    assert len(sparsity_logs) == 1, f"Expected 1 sparsity log, got {len(sparsity_logs)}"
+    sparsity_log = sparsity_logs[0]
+    pattern = r"Applied `L1Unstructured`\. Pruned: \d+/(\d+) \(\d+\.\d+%\) -> (\d+)/(\d+) \((\d+\.\d+)%\)"
+    match = re.search(pattern, sparsity_log)
+    assert match, f"Could not parse sparsity log: {sparsity_log}"
+
+    total_params_before = int(match.group(1))
+    pruned_count = int(match.group(2))
+    total_params_after = int(match.group(3))
+    sparsity_percentage = float(match.group(4))
+    assert total_params_before == expected_total_params, (
+        f"Total parameter count mismatch for {model_type} model. "
+        f"Expected {expected_total_params}, got {total_params_before}"
+    )
+    assert total_params_after == expected_total_params, (
+        f"Total parameter count should be consistent. Before: {total_params_before}, After: {total_params_after}"
+    )
+
+    # Verify sparsity percentage is approximately correct
+    expected_sparsity = pruning_amount * 100
+    tolerance = 5.0
+    assert abs(sparsity_percentage - expected_sparsity) <= tolerance
+
+    # Verify the number of pruned parameters is reasonable
+    expected_pruned_count = int(expected_total_params * pruning_amount)
+    pruned_tolerance = max(1, int(expected_total_params * 0.05))
+    assert abs(pruned_count - expected_pruned_count) <= pruned_tolerance
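
Note on the updated expectations: the new totals count only the tensors registered for pruning, not every parameter of the touched modules. A quick arithmetic check (illustrative, not part of the test file):

    # Two-layer model pruned on both weights: Linear(32, 32) and Linear(32, 2).
    old_total = (32 * 32 + 32) + (32 * 2 + 2)  # all params of both modules: 1056 + 66 = 1122
    new_total = (32 * 32) + (32 * 2)           # pruned weight tensors only: 1024 + 64 = 1088
    assert (old_total, new_total) == (1122, 1088)

    # Single Linear(32, 2) pruned on its weight: 66 (weight + bias) -> 64 (weight only).
    assert (32 * 2 + 2, 32 * 2) == (66, 64)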
