5 changes: 5 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -27,9 +27,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed

- Fixed progress bar console clearing for Rich `14.1+` ([#21016](https://github.com/Lightning-AI/pytorch-lightning/pull/21016))


- Fixed `AdvancedProfiler` to handle nested profiling actions for Python 3.12+ ([#20809](https://github.com/Lightning-AI/pytorch-lightning/pull/20809))


- Fixed support for more dtypes in `ModelSummary` ([#21034](https://github.com/Lightning-AI/pytorch-lightning/pull/21034))


- Fixed metrics in `RichProgressBar` being updated according to user provided `refresh_rate` ([#21032](https://github.com/Lightning-AI/pytorch-lightning/pull/21032))

---
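For context on the `ModelSummary` entry above: this PR extends the size estimate to the `*-true` precision strings. A minimal sketch, not taken from this PR, of the kind of module the summary can now size, assuming a plain bfloat16 module with no trainer attached:

```python
import torch
from torch import nn
import lightning.pytorch as pl
from lightning.pytorch.utilities.model_summary import summarize


class TinyModel(pl.LightningModule):
    """Toy module with bfloat16 parameters, for illustration only."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(32, 2, dtype=torch.bfloat16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)


# Prints per-layer parameter counts and the estimated model size in MB.
print(summarize(TinyModel()))
```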
18 changes: 17 additions & 1 deletion src/lightning/pytorch/utilities/model_summary/model_summary.py
@@ -26,6 +26,7 @@
 from torch.utils.hooks import RemovableHandle

 import lightning.pytorch as pl
+from lightning.fabric.utilities import rank_zero_warn
 from lightning.fabric.utilities.distributed import _is_dtensor
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
 from lightning.pytorch.utilities.model_helpers import _ModuleMode
@@ -227,7 +228,22 @@ def __init__(self, model: "pl.LightningModule", max_depth: int = 1) -> None:
         self._layer_summary = self.summarize()
         # 1 byte -> 8 bits
         # TODO: how do we compute precision_megabytes in case of mixed precision?
-        precision_to_bits = {"64": 64, "32": 32, "16": 16, "bf16": 16}
+        precision_to_bits = {
+            "64": 64,
+            "32": 32,
+            "16": 16,
+            "bf16": 16,
+            "16-true": 16,
+            "bf16-true": 16,
+            "32-true": 32,
+            "64-true": 64,
+        }
+        if self._model._trainer and self._model.trainer.precision not in precision_to_bits:
+            rank_zero_warn(
+                f"Precision {self._model.trainer.precision} is not supported by the model summary."
+                " Estimated model size in MB will not be accurate. Using 32 bits instead.",
+                category=UserWarning,
+            )
         precision = precision_to_bits.get(self._model.trainer.precision, 32) if self._model._trainer else 32
         self._precision_megabytes = (precision / 8.0) * 1e-6

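As a back-of-the-envelope check of what this mapping feeds into: the summary converts the bit width to a per-parameter megabyte factor, so the estimated size is roughly parameter count times bytes per parameter. A minimal sketch of that arithmetic (the helper name is illustrative, not part of the API):

```python
def estimated_model_size_mb(num_params: int, precision_bits: int = 32) -> float:
    """Rough model size in MB: parameters x bytes-per-parameter x 1e-6."""
    return num_params * (precision_bits / 8.0) * 1e-6


# A 1M-parameter model: ~4.0 MB at 32-bit precision, ~2.0 MB at 16-bit.
# The 32-bit figure is also the fallback when the warning above fires.
print(estimated_model_size_mb(1_000_000, 32))  # ~4.0
print(estimated_model_size_mb(1_000_000, 16))  # ~2.0
```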
22 changes: 18 additions & 4 deletions tests/tests_pytorch/utilities/test_model_summary.py
@@ -324,19 +324,33 @@ def test_empty_model_size(max_depth):
         pytest.param("mps", marks=RunIf(mps=True)),
     ],
 )
-def test_model_size_precision(tmp_path, accelerator):
-    """Test model size for half and full precision."""
-    model = PreCalculatedModel()
+@pytest.mark.parametrize("precision", ["16-true", "32-true", "64-true"])
+def test_model_size_precision(tmp_path, accelerator, precision):
+    """Test model size for different precision types."""
+    model = PreCalculatedModel(precision=int(precision.split("-")[0]))

     # fit model
     trainer = Trainer(
-        default_root_dir=tmp_path, accelerator=accelerator, devices=1, max_steps=1, max_epochs=1, precision=32
+        default_root_dir=tmp_path, accelerator=accelerator, devices=1, max_steps=1, max_epochs=1, precision=precision
     )
     trainer.fit(model)
     summary = summarize(model)
     assert model.pre_calculated_model_size == summary.model_size

+
+def test_model_size_warning_on_unsupported_precision(tmp_path):
+    """Test that a warning is raised when the precision is not supported."""
+    model = PreCalculatedModel(precision=32)  # falls back to 32 bits
+
+    # a precision supported by Lightning but not by the model summary
+    trainer = Trainer(max_epochs=1, precision="16-mixed", default_root_dir=tmp_path)
+    trainer.fit(model)
+
+    with pytest.warns(UserWarning, match="Precision .* is not supported by the model summary.*"):
+        summary = summarize(model)
+    assert model.pre_calculated_model_size == summary.model_size
+
+
 def test_lazy_model_summary():
     """Test that the model summary can work with lazy layers."""
     lazy_model = LazyModel()
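The tests rely on `PreCalculatedModel`, a helper defined elsewhere in this test module. A plausible minimal version, for illustration only, assuming the helper precomputes its expected size from its own parameter count the same way `ModelSummary` does:

```python
import torch
from torch import nn
import lightning.pytorch as pl


class PreCalculatedModelSketch(pl.LightningModule):
    """Illustrative stand-in, not the repository's actual test helper."""

    def __init__(self, precision: int = 32) -> None:
        super().__init__()
        dtype = {16: torch.float16, 32: torch.float32, 64: torch.float64}[precision]
        self.layer = nn.Linear(32, 2, dtype=dtype)
        num_params = sum(p.numel() for p in self.parameters())
        # Expected size in MB: parameters x bytes-per-parameter x 1e-6,
        # mirroring ModelSummary's precision_megabytes factor.
        self.pre_calculated_model_size = num_params * (precision / 8.0) * 1e-6
```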