Skip to content
6 changes: 6 additions & 0 deletions src/lightning/pytorch/callbacks/model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,17 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
model_size = model_summary.model_size
total_training_modes = model_summary.total_training_modes

# TODO: Add `total_flops` to DeepSpeedSummary; until then, fall back to 0 when the attribute is missing.
total_flops = model_summary.total_flops if hasattr(model_summary, "total_flops") else 0

if trainer.is_global_zero:
self.summarize(
summary_data,
total_parameters,
trainable_parameters,
model_size,
total_training_modes,
total_flops=total_flops,
**self._summarize_kwargs,
)

Expand All @@ -92,13 +96,15 @@ def summarize(
trainable_parameters: int,
model_size: float,
total_training_modes: dict[str, int],
total_flops: int,
**summarize_kwargs: Any,
) -> None:
summary_table = _format_summary_table(
total_parameters,
trainable_parameters,
model_size,
total_training_modes,
total_flops,
*summary_data,
)
log.info("\n" + summary_table)
3 changes: 3 additions & 0 deletions src/lightning/pytorch/callbacks/rich_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def summarize(
trainable_parameters: int,
model_size: float,
total_training_modes: dict[str, int],
total_flops: int,
**summarize_kwargs: Any,
) -> None:
from rich import get_console
Expand All @@ -86,6 +87,7 @@ def summarize(
table.add_column("Type")
table.add_column("Params", justify="right")
table.add_column("Mode")
table.add_column("FLOPs", justify="right")

column_names = list(zip(*summary_data))[0]

Expand Down Expand Up @@ -113,5 +115,6 @@ def summarize(
grid.add_row(f"[bold]Total estimated model params size (MB)[/]: {parameters[3]}")
grid.add_row(f"[bold]Modules in train mode[/]: {total_training_modes['train']}")
grid.add_row(f"[bold]Modules in eval mode[/]: {total_training_modes['eval']}")
grid.add_row(f"[bold]Total FLOPs[/]: {get_human_readable_count(total_flops)}")

console.print(grid)
76 changes: 63 additions & 13 deletions src/lightning/pytorch/utilities/model_summary/model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.flop_counter import FlopCounterMode
from torch.utils.hooks import RemovableHandle

import lightning.pytorch as pl
from lightning.fabric.utilities.distributed import _is_dtensor
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.pytorch.utilities.model_helpers import _ModuleMode
from lightning.pytorch.utilities.rank_zero import WarningCache

Expand Down Expand Up @@ -180,29 +182,31 @@ class ModelSummary:
...
>>> model = LitModel()
>>> ModelSummary(model, max_depth=1) # doctest: +NORMALIZE_WHITESPACE
| Name | Type | Params | Mode | In sizes | Out sizes
--------------------------------------------------------------------
0 | net | Sequential | 132 K | train | [10, 256] | [10, 512]
--------------------------------------------------------------------
| Name | Type | Params | Mode | FLOPs | In sizes | Out sizes
----------------------------------------------------------------------------
0 | net | Sequential | 132 K | train | 2.6 M | [10, 256] | [10, 512]
----------------------------------------------------------------------------
132 K Trainable params
0 Non-trainable params
132 K Total params
0.530 Total estimated model params size (MB)
3 Modules in train mode
0 Modules in eval mode
2.6 M Total Flops
>>> ModelSummary(model, max_depth=-1) # doctest: +NORMALIZE_WHITESPACE
| Name | Type | Params | Mode | In sizes | Out sizes
----------------------------------------------------------------------
0 | net | Sequential | 132 K | train | [10, 256] | [10, 512]
1 | net.0 | Linear | 131 K | train | [10, 256] | [10, 512]
2 | net.1 | BatchNorm1d | 1.0 K | train | [10, 512] | [10, 512]
----------------------------------------------------------------------
| Name | Type | Params | Mode | FLOPs | In sizes | Out sizes
------------------------------------------------------------------------------
0 | net | Sequential | 132 K | train | 2.6 M | [10, 256] | [10, 512]
1 | net.0 | Linear | 131 K | train | 2.6 M | [10, 256] | [10, 512]
2 | net.1 | BatchNorm1d | 1.0 K | train | 0 | [10, 512] | [10, 512]
------------------------------------------------------------------------------
132 K Trainable params
0 Non-trainable params
132 K Total params
0.530 Total estimated model params size (MB)
3 Modules in train mode
0 Modules in eval mode
2.6 M Total Flops

"""

Expand All @@ -212,6 +216,13 @@ def __init__(self, model: "pl.LightningModule", max_depth: int = 1) -> None:
if not isinstance(max_depth, int) or max_depth < -1:
raise ValueError(f"`max_depth` can be -1, 0 or > 0, got {max_depth}.")

# The max-depth needs to be plus one because the root module is already counted as depth 0.
self._flop_counter = FlopCounterMode(
mods=None if _TORCH_GREATER_EQUAL_2_4 else self._model,
display=False,
depth=max_depth + 1,
)

self._max_depth = max_depth
self._layer_summary = self.summarize()
# 1 byte -> 8 bits
Expand Down Expand Up @@ -279,6 +290,22 @@ def total_layer_params(self) -> int:
def model_size(self) -> float:
    """Return the estimated size of the model parameters.

    The value is the total parameter count multiplied by the
    per-parameter megabyte factor stored on the summary.
    """
    param_count = self.total_parameters
    return param_count * self._precision_megabytes

@property
def total_flops(self) -> int:
    """Return the total FLOP count reported by this summary's FLOP counter."""
    counter = self._flop_counter
    return counter.get_total_flops()

@property
def flop_counts(self) -> dict[str, dict[Any, int]]:
    """Return the recorded FLOP counts for every summarized layer.

    Each name in ``layer_names`` is looked up in the counter's results
    under the key ``"<RootModelClassName>.<layer_name>"``. Layers with no
    recorded entry map to an empty dict.
    """
    recorded = self._flop_counter.get_flop_counts()
    per_layer: dict[str, dict[Any, int]] = {}
    for layer_name in self.layer_names:
        # The counter's keys are prefixed with the root model's class name.
        lookup_key = f"{type(self._model).__name__}.{layer_name}"
        per_layer[layer_name] = recorded.get(lookup_key, {})
    return per_layer

def summarize(self) -> dict[str, LayerSummary]:
summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
if self._model.example_input_array is not None:
Expand Down Expand Up @@ -307,8 +334,18 @@ def _forward_example_input(self) -> None:
mode.capture(model)
model.eval()

# FlopCounterMode does not support ScriptModules before torch 2.4.0, so we use a null context
flop_context = (
contextlib.nullcontext()
if (
not _TORCH_GREATER_EQUAL_2_4
and any(isinstance(m, torch.jit.ScriptModule) for m in self._model.modules())
)
else self._flop_counter
)

forward_context = contextlib.nullcontext() if trainer is None else trainer.precision_plugin.forward_context()
with torch.no_grad(), forward_context:
with torch.no_grad(), forward_context, flop_context:
# let the model hooks collect the input- and output shapes
if isinstance(input_, (list, tuple)):
model(*input_)
Expand All @@ -330,6 +367,7 @@ def _get_summary_data(self) -> list[tuple[str, list[str]]]:
("Type", self.layer_types),
("Params", list(map(get_human_readable_count, self.param_nums))),
("Mode", ["train" if mode else "eval" for mode in self.training_modes]),
("FLOPs", list(map(get_human_readable_count, (sum(x.values()) for x in self.flop_counts.values())))),
]
if self._model.example_input_array is not None:
arrays.append(("In sizes", [str(x) for x in self.in_sizes]))
Expand All @@ -349,6 +387,7 @@ def _add_leftover_params_to_summary(self, arrays: list[tuple[str, list[str]]], t
layer_summaries["Type"].append(NOT_APPLICABLE)
layer_summaries["Params"].append(get_human_readable_count(total_leftover_params))
layer_summaries["Mode"].append(NOT_APPLICABLE)
layer_summaries["FLOPs"].append(NOT_APPLICABLE)
if "In sizes" in layer_summaries:
layer_summaries["In sizes"].append(NOT_APPLICABLE)
if "Out sizes" in layer_summaries:
Expand All @@ -361,8 +400,16 @@ def __str__(self) -> str:
trainable_parameters = self.trainable_parameters
model_size = self.model_size
total_training_modes = self.total_training_modes

return _format_summary_table(total_parameters, trainable_parameters, model_size, total_training_modes, *arrays)
total_flops = self.total_flops

return _format_summary_table(
total_parameters,
trainable_parameters,
model_size,
total_training_modes,
total_flops,
*arrays,
)

def __repr__(self) -> str:
    """Return the same text as ``str(self)``."""
    return str(self)
Expand All @@ -383,6 +430,7 @@ def _format_summary_table(
trainable_parameters: int,
model_size: float,
total_training_modes: dict[str, int],
total_flops: int,
*cols: tuple[str, list[str]],
) -> str:
"""Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one big
Expand Down Expand Up @@ -423,6 +471,8 @@ def _format_summary_table(
summary += "Modules in train mode"
summary += "\n" + s.format(total_training_modes["eval"], 10)
summary += "Modules in eval mode"
summary += "\n" + s.format(get_human_readable_count(total_flops), 10)
summary += "Total Flops"

return summary

Expand Down
3 changes: 3 additions & 0 deletions tests/tests_pytorch/callbacks/test_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def summarize(
assert summary_data[4][0] == "Mode"
assert summary_data[4][1][0] == "train"

assert summary_data[5][0] == "FLOPs"
assert all(isinstance(x, str) for x in summary_data[5][1])

assert total_training_modes == {"train": 1, "eval": 0}

model = BoringModel()
Expand Down
3 changes: 2 additions & 1 deletion tests/tests_pytorch/callbacks/test_rich_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,11 @@ def example_input_array(self) -> Any:
trainable_parameters=1,
model_size=1,
total_training_modes=summary.total_training_modes,
total_flops=1,
)

# ensure that summary was logged + the breakdown of model parameters
assert mock_console.call_count == 2
# assert that the input summary data was converted correctly
args, _ = mock_table_add_row.call_args_list[0]
assert args[1:] == ("0", "layer", "Linear", "66 ", "train", "[4, 32]", "[4, 2]")
assert args[1:] == ("0", "layer", "Linear", "66 ", "train", "512 ", "[4, 32]", "[4, 2]")
1 change: 1 addition & 0 deletions tests/tests_pytorch/utilities/test_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def test_empty_model_summary_shapes(max_depth):
assert summary.in_sizes == []
assert summary.out_sizes == []
assert summary.param_nums == []
assert summary.total_flops == 0


@pytest.mark.parametrize("max_depth", [-1, 1])
Expand Down
Loading