Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- fix `AdvancedProfiler` to handle nested profiling actions for Python 3.12+ ([#20809](https://github.com/Lightning-AI/pytorch-lightning/pull/20809))


- Fix support for more dtypes in `ModelSummary` ([#21034](https://github.com/Lightning-AI/pytorch-lightning/pull/21034))

---

## [2.5.2] - 2025-06-20
Expand Down
18 changes: 17 additions & 1 deletion src/lightning/pytorch/utilities/model_summary/model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from torch.utils.hooks import RemovableHandle

import lightning.pytorch as pl
from lightning.fabric.utilities import rank_zero_warn
from lightning.fabric.utilities.distributed import _is_dtensor
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.pytorch.utilities.model_helpers import _ModuleMode
Expand Down Expand Up @@ -227,7 +228,22 @@ def __init__(self, model: "pl.LightningModule", max_depth: int = 1) -> None:
self._layer_summary = self.summarize()
# 1 byte -> 8 bits
# TODO: how do we compute precision_megabytes in case of mixed precision?
precision_to_bits = {"64": 64, "32": 32, "16": 16, "bf16": 16}
precision_to_bits = {
"64": 64,
"32": 32,
"16": 16,
"bf16": 16,
"16-true": 16,
"bf16-true": 16,
"32-true": 32,
"64-true": 64,
}
if self._model._trainer and self._model.trainer.precision not in precision_to_bits:
rank_zero_warn(
f"Precision {self._model.trainer.precision} is not supported by the model summary. "
" Estimated model size in MB will not be accurate. Using 32 bits instead.",
category=UserWarning,
)
precision = precision_to_bits.get(self._model.trainer.precision, 32) if self._model._trainer else 32
self._precision_megabytes = (precision / 8.0) * 1e-6

Expand Down
22 changes: 18 additions & 4 deletions tests/tests_pytorch/utilities/test_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,19 +324,33 @@ def test_empty_model_size(max_depth):
pytest.param("mps", marks=RunIf(mps=True)),
],
)
def test_model_size_precision(tmp_path, accelerator):
"""Test model size for half and full precision."""
model = PreCalculatedModel()
@pytest.mark.parametrize("precision", ["16-true", "32-true", "64-true"])
def test_model_size_precision(tmp_path, accelerator, precision):
"""Test model size for different precision types."""
model = PreCalculatedModel(precision=int(precision.split("-")[0]))

# fit model
trainer = Trainer(
default_root_dir=tmp_path, accelerator=accelerator, devices=1, max_steps=1, max_epochs=1, precision=32
default_root_dir=tmp_path, accelerator=accelerator, devices=1, max_steps=1, max_epochs=1, precision=precision
)
trainer.fit(model)
summary = summarize(model)
assert model.pre_calculated_model_size == summary.model_size


def test_model_size_warning_on_unsupported_precision(tmp_path):
"""Test that a warning is raised when the precision is not supported."""
model = PreCalculatedModel(precision=32) # fallback to 32 bits

# supported precision by lightning but not by the model summary
trainer = Trainer(max_epochs=1, precision="16-mixed", default_root_dir=tmp_path)
trainer.fit(model)

with pytest.warns(UserWarning, match="Precision .* is not supported by the model summary.*"):
summary = summarize(model)
assert model.pre_calculated_model_size == summary.model_size


def test_lazy_model_summary():
"""Test that the model summary can work with lazy layers."""
lazy_model = LazyModel()
Expand Down
Loading