
Commit 24dc9af

SkafteNicki authored and Borda committed
Add support for more dtypes in ModelSummary and warning on non-supported (#21034)
(cherry picked from commit 1d8cf20)
1 parent 2685947 · commit 24dc9af

File tree

3 files changed, +38 -5 lines


src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -35,6 +35,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fix double iteration bug when resumed from a checkpoint. ([#20775](https://github.com/Lightning-AI/pytorch-lightning/pull/20775))
 
 
+- Fix support for more dtypes in `ModelSummary` ([#21034](https://github.com/Lightning-AI/pytorch-lightning/pull/21034))
+
+
 - Fixed metrics in `RichProgressBar` being updated according to user provided `refresh_rate` ([#21032](https://github.com/Lightning-AI/pytorch-lightning/pull/21032))

src/lightning/pytorch/utilities/model_summary/model_summary.py

Lines changed: 17 additions & 1 deletion
@@ -25,6 +25,7 @@
 from torch.utils.hooks import RemovableHandle
 
 import lightning.pytorch as pl
+from lightning.fabric.utilities import rank_zero_warn
 from lightning.fabric.utilities.distributed import _is_dtensor
 from lightning.pytorch.utilities.model_helpers import _ModuleMode
 from lightning.pytorch.utilities.rank_zero import WarningCache
@@ -216,7 +217,22 @@ def __init__(self, model: "pl.LightningModule", max_depth: int = 1) -> None:
         self._layer_summary = self.summarize()
         # 1 byte -> 8 bits
         # TODO: how do we compute precision_megabytes in case of mixed precision?
-        precision_to_bits = {"64": 64, "32": 32, "16": 16, "bf16": 16}
+        precision_to_bits = {
+            "64": 64,
+            "32": 32,
+            "16": 16,
+            "bf16": 16,
+            "16-true": 16,
+            "bf16-true": 16,
+            "32-true": 32,
+            "64-true": 64,
+        }
+        if self._model._trainer and self._model.trainer.precision not in precision_to_bits:
+            rank_zero_warn(
+                f"Precision {self._model.trainer.precision} is not supported by the model summary. "
+                " Estimated model size in MB will not be accurate. Using 32 bits instead.",
+                category=UserWarning,
+            )
         precision = precision_to_bits.get(self._model.trainer.precision, 32) if self._model._trainer else 32
         self._precision_megabytes = (precision / 8.0) * 1e-6
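
The arithmetic behind the estimate is just parameter count times bytes per parameter. Below is a minimal standalone sketch of that computation; the `PRECISION_TO_BITS` table mirrors the mapping added in the diff above, while the helper name `estimate_model_size_mb` is illustrative and not part of the Lightning API:

```python
from torch import nn

# Mirrors the mapping added in model_summary.py; unknown precision
# strings fall back to 32 bits, as in the patched __init__.
PRECISION_TO_BITS = {
    "64": 64, "32": 32, "16": 16, "bf16": 16,
    "16-true": 16, "bf16-true": 16, "32-true": 32, "64-true": 64,
}


def estimate_model_size_mb(model: nn.Module, precision: str) -> float:
    """Estimate model size in MB as num_parameters * bytes_per_parameter."""
    bits = PRECISION_TO_BITS.get(precision, 32)  # 32-bit fallback
    num_params = sum(p.numel() for p in model.parameters())
    return num_params * (bits / 8.0) * 1e-6


model = nn.Linear(1000, 1000)  # 1_001_000 parameters
print(estimate_model_size_mb(model, "16-true"))   # ~2.002 MB at 2 bytes/param
print(estimate_model_size_mb(model, "16-mixed"))  # unsupported -> 32-bit fallback, ~4.004 MB
```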

tests/tests_pytorch/utilities/test_model_summary.py

Lines changed: 18 additions & 4 deletions
@@ -323,19 +323,33 @@ def test_empty_model_size(max_depth):
         pytest.param("mps", marks=RunIf(mps=True)),
     ],
 )
-def test_model_size_precision(tmp_path, accelerator):
-    """Test model size for half and full precision."""
-    model = PreCalculatedModel()
+@pytest.mark.parametrize("precision", ["16-true", "32-true", "64-true"])
+def test_model_size_precision(tmp_path, accelerator, precision):
+    """Test model size for different precision types."""
+    model = PreCalculatedModel(precision=int(precision.split("-")[0]))
 
     # fit model
     trainer = Trainer(
-        default_root_dir=tmp_path, accelerator=accelerator, devices=1, max_steps=1, max_epochs=1, precision=32
+        default_root_dir=tmp_path, accelerator=accelerator, devices=1, max_steps=1, max_epochs=1, precision=precision
     )
     trainer.fit(model)
     summary = summarize(model)
     assert model.pre_calculated_model_size == summary.model_size
 
 
+def test_model_size_warning_on_unsupported_precision(tmp_path):
+    """Test that a warning is raised when the precision is not supported."""
+    model = PreCalculatedModel(precision=32)  # fallback to 32 bits
+
+    # supported precision by lightning but not by the model summary
+    trainer = Trainer(max_epochs=1, precision="16-mixed", default_root_dir=tmp_path)
+    trainer.fit(model)
+
+    with pytest.warns(UserWarning, match="Precision .* is not supported by the model summary.*"):
+        summary = summarize(model)
+    assert model.pre_calculated_model_size == summary.model_size
+
+
 def test_lazy_model_summary():
     """Test that the model summary can work with lazy layers."""
     lazy_model = LazyModel()
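
Outside the test suite, the same warning can be reproduced with Lightning's demo `BoringModel`. A sketch, assuming an environment (accelerator and PyTorch version) where the `16-mixed` Trainer precision actually runs; the expected message is the one added in the diff above:

```python
import warnings

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.utilities.model_summary import summarize

model = BoringModel()
# "16-mixed" is a valid Trainer precision, but it has no entry in the
# precision_to_bits table, so ModelSummary warns and falls back to 32 bits.
trainer = Trainer(max_epochs=1, limit_train_batches=1, precision="16-mixed")
trainer.fit(model)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    summarize(model)

# Expect a UserWarning whose message starts with
# "Precision 16-mixed is not supported by the model summary."
print([str(w.message) for w in caught if issubclass(w.category, UserWarning)])
```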
