
Commit 1d8cf20

Add support for more dtypes in ModelSummary and warning on non-supported (#21034)
1 parent 20960ec commit 1d8cf20

File tree: 3 files changed, +40 -5 lines


src/lightning/pytorch/CHANGELOG.md (5 additions, 0 deletions)

```diff
@@ -27,9 +27,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed
 
 - fix progress bar console clearing for Rich `14.1+` ([#21016](https://github.com/Lightning-AI/pytorch-lightning/pull/21016))
+
+
 - fix `AdvancedProfiler` to handle nested profiling actions for Python 3.12+ ([#20809](https://github.com/Lightning-AI/pytorch-lightning/pull/20809))
 
 
+- Fix support for more dtypes in `ModelSummary` ([#21034](https://github.com/Lightning-AI/pytorch-lightning/pull/21034))
+
 - Fixed metrics in `RichProgressBar` being updated according to user provided `refresh_rate` ([#21032](https://github.com/Lightning-AI/pytorch-lightning/pull/21032))
+
 
 ---
```

src/lightning/pytorch/utilities/model_summary/model_summary.py (17 additions, 1 deletion)

```diff
@@ -26,6 +26,7 @@
 from torch.utils.hooks import RemovableHandle
 
 import lightning.pytorch as pl
+from lightning.fabric.utilities import rank_zero_warn
 from lightning.fabric.utilities.distributed import _is_dtensor
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
 from lightning.pytorch.utilities.model_helpers import _ModuleMode
@@ -227,7 +228,22 @@ def __init__(self, model: "pl.LightningModule", max_depth: int = 1) -> None:
         self._layer_summary = self.summarize()
         # 1 byte -> 8 bits
         # TODO: how do we compute precision_megabytes in case of mixed precision?
-        precision_to_bits = {"64": 64, "32": 32, "16": 16, "bf16": 16}
+        precision_to_bits = {
+            "64": 64,
+            "32": 32,
+            "16": 16,
+            "bf16": 16,
+            "16-true": 16,
+            "bf16-true": 16,
+            "32-true": 32,
+            "64-true": 64,
+        }
+        if self._model._trainer and self._model.trainer.precision not in precision_to_bits:
+            rank_zero_warn(
+                f"Precision {self._model.trainer.precision} is not supported by the model summary. "
+                " Estimated model size in MB will not be accurate. Using 32 bits instead.",
+                category=UserWarning,
+            )
         precision = precision_to_bits.get(self._model.trainer.precision, 32) if self._model._trainer else 32
         self._precision_megabytes = (precision / 8.0) * 1e-6
```

tests/tests_pytorch/utilities/test_model_summary.py (18 additions, 4 deletions)

```diff
@@ -324,19 +324,33 @@ def test_empty_model_size(max_depth):
         pytest.param("mps", marks=RunIf(mps=True)),
     ],
 )
-def test_model_size_precision(tmp_path, accelerator):
-    """Test model size for half and full precision."""
-    model = PreCalculatedModel()
+@pytest.mark.parametrize("precision", ["16-true", "32-true", "64-true"])
+def test_model_size_precision(tmp_path, accelerator, precision):
+    """Test model size for different precision types."""
+    model = PreCalculatedModel(precision=int(precision.split("-")[0]))
 
     # fit model
     trainer = Trainer(
-        default_root_dir=tmp_path, accelerator=accelerator, devices=1, max_steps=1, max_epochs=1, precision=32
+        default_root_dir=tmp_path, accelerator=accelerator, devices=1, max_steps=1, max_epochs=1, precision=precision
     )
     trainer.fit(model)
     summary = summarize(model)
     assert model.pre_calculated_model_size == summary.model_size
 
 
+def test_model_size_warning_on_unsupported_precision(tmp_path):
+    """Test that a warning is raised when the precision is not supported."""
+    model = PreCalculatedModel(precision=32)  # fallback to 32 bits
+
+    # supported precision by lightning but not by the model summary
+    trainer = Trainer(max_epochs=1, precision="16-mixed", default_root_dir=tmp_path)
+    trainer.fit(model)
+
+    with pytest.warns(UserWarning, match="Precision .* is not supported by the model summary.*"):
+        summary = summarize(model)
+        assert model.pre_calculated_model_size == summary.model_size
+
+
 def test_lazy_model_summary():
     """Test that the model summary can work with lazy layers."""
     lazy_model = LazyModel()
```
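From the user side, the change means `summarize()` warns and falls back to a 32-bit estimate instead of producing a silently wrong size for unmapped precisions. A hedged usage sketch (the `TinyModule` below is my toy example, not from the repo; `"bf16-mixed"` stands in for any precision the Trainer accepts but the summary mapping does not cover, like the `"16-mixed"` used in the test above):

```python
import torch
from torch.utils.data import DataLoader

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.utilities.model_summary import summarize


class TinyModule(LightningModule):
    """Toy module for illustration: a single linear layer."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


model = TinyModule()
trainer = Trainer(
    max_epochs=1,
    limit_train_batches=1,
    accelerator="cpu",
    precision="bf16-mixed",  # valid for the Trainer, but not in precision_to_bits
    logger=False,
    enable_checkpointing=False,
)
trainer.fit(model, DataLoader(torch.randn(8, 32), batch_size=4))

# Emits "Precision bf16-mixed is not supported by the model summary ..." (UserWarning)
# and reports the size using the 32-bit fallback instead of failing.
print(summarize(model).model_size)
```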
