diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index bff92f8afc672..b18b2d51ccf2c 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -27,9 +27,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
 - fix progress bar console clearing for Rich `14.1+` ([#21016](https://github.com/Lightning-AI/pytorch-lightning/pull/21016))
+
+
 - fix `AdvancedProfiler` to handle nested profiling actions for Python 3.12+ ([#20809](https://github.com/Lightning-AI/pytorch-lightning/pull/20809))
+- Fix support for more dtypes in `ModelSummary` ([#21034](https://github.com/Lightning-AI/pytorch-lightning/pull/21034))
+
+
 - Fixed metrics in `RichProgressBar` being updated according to user provided `refresh_rate` ([#21032](https://github.com/Lightning-AI/pytorch-lightning/pull/21032))
 
 
 ---
diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary.py b/src/lightning/pytorch/utilities/model_summary/model_summary.py
index 98d74ff63ea5f..01b692abdc05f 100644
--- a/src/lightning/pytorch/utilities/model_summary/model_summary.py
+++ b/src/lightning/pytorch/utilities/model_summary/model_summary.py
@@ -26,6 +26,7 @@
 from torch.utils.hooks import RemovableHandle
 
 import lightning.pytorch as pl
+from lightning.fabric.utilities import rank_zero_warn
 from lightning.fabric.utilities.distributed import _is_dtensor
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
 from lightning.pytorch.utilities.model_helpers import _ModuleMode
@@ -227,7 +228,22 @@ def __init__(self, model: "pl.LightningModule", max_depth: int = 1) -> None:
         self._layer_summary = self.summarize()
         # 1 byte -> 8 bits
         # TODO: how do we compute precision_megabytes in case of mixed precision?
-        precision_to_bits = {"64": 64, "32": 32, "16": 16, "bf16": 16}
+        precision_to_bits = {
+            "64": 64,
+            "32": 32,
+            "16": 16,
+            "bf16": 16,
+            "16-true": 16,
+            "bf16-true": 16,
+            "32-true": 32,
+            "64-true": 64,
+        }
+        if self._model._trainer and self._model.trainer.precision not in precision_to_bits:
+            rank_zero_warn(
+                f"Precision {self._model.trainer.precision} is not supported by the model summary. "
+                " Estimated model size in MB will not be accurate. Using 32 bits instead.",
+                category=UserWarning,
+            )
         precision = precision_to_bits.get(self._model.trainer.precision, 32) if self._model._trainer else 32
         self._precision_megabytes = (precision / 8.0) * 1e-6
 
diff --git a/tests/tests_pytorch/utilities/test_model_summary.py b/tests/tests_pytorch/utilities/test_model_summary.py
index 85825b5ea749d..35edf78fa7081 100644
--- a/tests/tests_pytorch/utilities/test_model_summary.py
+++ b/tests/tests_pytorch/utilities/test_model_summary.py
@@ -324,19 +324,33 @@ def test_empty_model_size(max_depth):
         pytest.param("mps", marks=RunIf(mps=True)),
     ],
 )
-def test_model_size_precision(tmp_path, accelerator):
-    """Test model size for half and full precision."""
-    model = PreCalculatedModel()
+@pytest.mark.parametrize("precision", ["16-true", "32-true", "64-true"])
+def test_model_size_precision(tmp_path, accelerator, precision):
+    """Test model size for different precision types."""
+    model = PreCalculatedModel(precision=int(precision.split("-")[0]))
 
     # fit model
     trainer = Trainer(
-        default_root_dir=tmp_path, accelerator=accelerator, devices=1, max_steps=1, max_epochs=1, precision=32
+        default_root_dir=tmp_path, accelerator=accelerator, devices=1, max_steps=1, max_epochs=1, precision=precision
     )
     trainer.fit(model)
     summary = summarize(model)
     assert model.pre_calculated_model_size == summary.model_size
 
 
+def test_model_size_warning_on_unsupported_precision(tmp_path):
+    """Test that a warning is raised when the precision is not supported."""
+    model = PreCalculatedModel(precision=32)  # fallback to 32 bits
+
+    # supported precision by lightning but not by the model summary
+    trainer = Trainer(max_epochs=1, precision="16-mixed", default_root_dir=tmp_path)
+    trainer.fit(model)
+
+    with pytest.warns(UserWarning, match="Precision .* is not supported by the model summary.*"):
+        summary = summarize(model)
+    assert model.pre_calculated_model_size == summary.model_size
+
+
 def test_lazy_model_summary():
     """Test that the model summary can work with lazy layers."""
     lazy_model = LazyModel()
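
Below is a minimal sketch (not part of the patch) of the size estimate this change affects. `TinyModel` is a hypothetical module used only for illustration; the arithmetic mirrors `precision_to_bits` and `_precision_megabytes` above.

# Sketch only -- not part of the patch. TinyModel is a hypothetical module
# used to illustrate how ModelSummary estimates model size from precision.
import torch

from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.model_summary import summarize


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)  # 32 * 2 weights + 2 biases = 66 parameters

    def forward(self, x):
        return self.layer(x)


# With no trainer attached, the summary falls back to 32 bits per parameter:
# 66 * (32 / 8) bytes * 1e-6 = 0.000264 MB.
print(summarize(TinyModel()).model_size)
# After trainer.fit(model) with precision="16-true", the same model would report
# 66 * (16 / 8) bytes * 1e-6 = 0.000132 MB, while an unmapped precision such as
# "16-mixed" now emits the UserWarning added above and falls back to 32 bits.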