Commit 5a0c1f7

add support for more precision + add warning on unsupported
Parent: 25b1343

File tree

1 file changed: +17 −1

src/lightning/pytorch/utilities/model_summary/model_summary.py

Lines changed: 17 additions & 1 deletion
@@ -26,6 +26,7 @@
 from torch.utils.hooks import RemovableHandle
 
 import lightning.pytorch as pl
+from lightning.fabric.utilities import rank_zero_warn
 from lightning.fabric.utilities.distributed import _is_dtensor
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
 from lightning.pytorch.utilities.model_helpers import _ModuleMode
@@ -227,7 +228,22 @@ def __init__(self, model: "pl.LightningModule", max_depth: int = 1) -> None:
         self._layer_summary = self.summarize()
         # 1 byte -> 8 bits
         # TODO: how do we compute precision_megabytes in case of mixed precision?
-        precision_to_bits = {"64": 64, "32": 32, "16": 16, "bf16": 16}
+        precision_to_bits = {
+            "64": 64,
+            "32": 32,
+            "16": 16,
+            "bf16": 16,
+            "16-true": 16,
+            "bf16-true": 16,
+            "32-true": 32,
+            "64-true": 64,
+        }
+        if self._model.trainer is not None and self._model.trainer.precision not in precision_to_bits:
+            rank_zero_warn(
+                f"Precision {self._model.trainer.precision} is not supported by the model summary. "
+                " Estimated model size in MB will not be accurate. Using 32 bits instead.",
+                category=UserWarning,
+            )
         precision = precision_to_bits.get(self._model.trainer.precision, 32) if self._model._trainer else 32
         self._precision_megabytes = (precision / 8.0) * 1e-6
 
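
For reference, the size estimate is simple byte arithmetic: the precision string selects bits per parameter, that is converted to megabytes per parameter (`_precision_megabytes`), and the result is multiplied by the model's total parameter count. The standalone sketch below reproduces that arithmetic with the mapping added in this commit; the `estimated_model_size_mb` helper and the example module are illustrative only, not part of Lightning's API.

from torch import nn

# Same mapping this commit adds; strings outside it (e.g. mixed-precision modes)
# fall back to 32 bits, which is what the new warning points out.
PRECISION_TO_BITS = {
    "64": 64,
    "32": 32,
    "16": 16,
    "bf16": 16,
    "16-true": 16,
    "bf16-true": 16,
    "32-true": 32,
    "64-true": 64,
}


def estimated_model_size_mb(model: nn.Module, precision: str) -> float:
    """Estimate model size in MB from the parameter count and a precision string."""
    bits = PRECISION_TO_BITS.get(precision, 32)    # unknown precision -> assume 32 bits
    precision_megabytes = (bits / 8.0) * 1e-6      # megabytes per parameter
    total_parameters = sum(p.numel() for p in model.parameters())
    return total_parameters * precision_megabytes


model = nn.Linear(1024, 1024)                       # 1,049,600 parameters
print(estimated_model_size_mb(model, "bf16-true"))  # ~2.1 MB (16-bit weights)
print(estimated_model_size_mb(model, "16-mixed"))   # not in the mapping -> ~4.2 MB (32-bit fallback)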
