diff --git a/src/lightning/pytorch/callbacks/model_summary.py b/src/lightning/pytorch/callbacks/model_summary.py
index 03f50d65bf1e9..ee9ff2f3bd902 100644
--- a/src/lightning/pytorch/callbacks/model_summary.py
+++ b/src/lightning/pytorch/callbacks/model_summary.py
@@ -68,6 +68,9 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         model_size = model_summary.model_size
         total_training_modes = model_summary.total_training_modes
 
+        # TODO: Add `total_flops` to `DeepSpeedSummary`.
+        total_flops = model_summary.total_flops if hasattr(model_summary, "total_flops") else 0
+
         if trainer.is_global_zero:
             self.summarize(
                 summary_data,
@@ -75,6 +78,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
                 trainable_parameters,
                 model_size,
                 total_training_modes,
+                total_flops=total_flops,
                 **self._summarize_kwargs,
             )
 
@@ -92,6 +96,7 @@ def summarize(
         trainable_parameters: int,
         model_size: float,
         total_training_modes: dict[str, int],
+        total_flops: int,
         **summarize_kwargs: Any,
     ) -> None:
         summary_table = _format_summary_table(
@@ -99,6 +104,7 @@ def summarize(
             trainable_parameters,
             model_size,
             total_training_modes,
+            total_flops,
             *summary_data,
         )
         log.info("\n" + summary_table)
diff --git a/src/lightning/pytorch/callbacks/rich_model_summary.py b/src/lightning/pytorch/callbacks/rich_model_summary.py
index e4027f0dedcb1..817aeeb655a7a 100644
--- a/src/lightning/pytorch/callbacks/rich_model_summary.py
+++ b/src/lightning/pytorch/callbacks/rich_model_summary.py
@@ -72,6 +72,7 @@ def summarize(
         trainable_parameters: int,
         model_size: float,
         total_training_modes: dict[str, int],
+        total_flops: int,
         **summarize_kwargs: Any,
     ) -> None:
         from rich import get_console
@@ -86,6 +87,7 @@ def summarize(
         table.add_column("Type")
         table.add_column("Params", justify="right")
         table.add_column("Mode")
+        table.add_column("FLOPs", justify="right")
 
         column_names = list(zip(*summary_data))[0]
 
@@ -113,5 +115,6 @@ def summarize(
         grid.add_row(f"[bold]Total estimated model params size (MB)[/]: {parameters[3]}")
         grid.add_row(f"[bold]Modules in train mode[/]: {total_training_modes['train']}")
         grid.add_row(f"[bold]Modules in eval mode[/]: {total_training_modes['eval']}")
+        grid.add_row(f"[bold]Total FLOPs[/]: {get_human_readable_count(total_flops)}")
 
         console.print(grid)
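Note: the FLOPs figures come from `torch.utils.flop_counter.FlopCounterMode`, which the utilities diff below wires into `ModelSummary`. A minimal standalone sketch (not part of this patch) showing where the numbers in the doctest tables come from, using the same layer and input shapes as the docstring example:

import torch
from torch.utils.flop_counter import FlopCounterMode

# Same shapes as the ModelSummary doctest: Linear(256, 512) on a [10, 256] input.
layer = torch.nn.Linear(256, 512)
x = torch.randn(10, 256)

counter = FlopCounterMode(display=False, depth=2)
with counter, torch.no_grad():
    layer(x)

# A Linear forward is one addmm: 2 * batch * in_features * out_features
# = 2 * 10 * 256 * 512 = 2,621,440 FLOPs, rendered as "2.6 M" by get_human_readable_count.
print(counter.get_total_flops())  # 2621440
print(counter.get_flop_counts())  # per-module breakdown keyed by module name, e.g. {"Global": {...}}
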
diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary.py b/src/lightning/pytorch/utilities/model_summary/model_summary.py
index 6a5baf2c1e04a..98d74ff63ea5f 100644
--- a/src/lightning/pytorch/utilities/model_summary/model_summary.py
+++ b/src/lightning/pytorch/utilities/model_summary/model_summary.py
@@ -22,10 +22,12 @@
 import torch
 import torch.nn as nn
 from torch import Tensor
+from torch.utils.flop_counter import FlopCounterMode
 from torch.utils.hooks import RemovableHandle
 
 import lightning.pytorch as pl
 from lightning.fabric.utilities.distributed import _is_dtensor
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
 from lightning.pytorch.utilities.model_helpers import _ModuleMode
 from lightning.pytorch.utilities.rank_zero import WarningCache
 
@@ -180,29 +182,31 @@ class ModelSummary:
         ...
         >>> model = LitModel()
         >>> ModelSummary(model, max_depth=1)  # doctest: +NORMALIZE_WHITESPACE
-          | Name | Type       | Params | Mode  | In sizes  | Out sizes
-        --------------------------------------------------------------------
-        0 | net  | Sequential | 132 K  | train | [10, 256] | [10, 512]
-        --------------------------------------------------------------------
+          | Name | Type       | Params | Mode  | FLOPs | In sizes  | Out sizes
+        ----------------------------------------------------------------------------
+        0 | net  | Sequential | 132 K  | train | 2.6 M | [10, 256] | [10, 512]
+        ----------------------------------------------------------------------------
         132 K     Trainable params
         0         Non-trainable params
         132 K     Total params
         0.530     Total estimated model params size (MB)
         3         Modules in train mode
         0         Modules in eval mode
+        2.6 M     Total FLOPs
         >>> ModelSummary(model, max_depth=-1)  # doctest: +NORMALIZE_WHITESPACE
-          | Name  | Type        | Params | Mode  | In sizes  | Out sizes
-        ----------------------------------------------------------------------
-        0 | net   | Sequential  | 132 K  | train | [10, 256] | [10, 512]
-        1 | net.0 | Linear      | 131 K  | train | [10, 256] | [10, 512]
-        2 | net.1 | BatchNorm1d | 1.0 K  | train | [10, 512] | [10, 512]
-        ----------------------------------------------------------------------
+          | Name  | Type        | Params | Mode  | FLOPs | In sizes  | Out sizes
+        ------------------------------------------------------------------------------
+        0 | net   | Sequential  | 132 K  | train | 2.6 M | [10, 256] | [10, 512]
+        1 | net.0 | Linear      | 131 K  | train | 2.6 M | [10, 256] | [10, 512]
+        2 | net.1 | BatchNorm1d | 1.0 K  | train | 0     | [10, 512] | [10, 512]
+        ------------------------------------------------------------------------------
         132 K     Trainable params
         0         Non-trainable params
         132 K     Total params
         0.530     Total estimated model params size (MB)
         3         Modules in train mode
         0         Modules in eval mode
+        2.6 M     Total FLOPs
 
     """
 
@@ -212,6 +216,13 @@ def __init__(self, model: "pl.LightningModule", max_depth: int = 1) -> None:
         if not isinstance(max_depth, int) or max_depth < -1:
             raise ValueError(f"`max_depth` can be -1, 0 or > 0, got {max_depth}.")
 
+        # Use max_depth + 1 because FlopCounterMode already counts the root module as depth 0.
+        self._flop_counter = FlopCounterMode(
+            mods=None if _TORCH_GREATER_EQUAL_2_4 else self._model,
+            display=False,
+            depth=max_depth + 1,
+        )
+
         self._max_depth = max_depth
         self._layer_summary = self.summarize()
         # 1 byte -> 8 bits
@@ -279,6 +290,22 @@ def total_layer_params(self) -> int:
     def model_size(self) -> float:
         return self.total_parameters * self._precision_megabytes
 
+    @property
+    def total_flops(self) -> int:
+        return self._flop_counter.get_total_flops()
+
+    @property
+    def flop_counts(self) -> dict[str, dict[Any, int]]:
+        flop_counts = self._flop_counter.get_flop_counts()
+        ret = {
+            name: flop_counts.get(
+                f"{type(self._model).__name__}.{name}",
+                {},
+            )
+            for name in self.layer_names
+        }
+        return ret
+
     def summarize(self) -> dict[str, LayerSummary]:
         summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
         if self._model.example_input_array is not None:
@@ -307,8 +334,18 @@ def _forward_example_input(self) -> None:
         mode.capture(model)
         model.eval()
 
+        # FlopCounterMode does not support ScriptModules before torch 2.4.0, so we use a null context
+        flop_context = (
+            contextlib.nullcontext()
+            if (
+                not _TORCH_GREATER_EQUAL_2_4
+                and any(isinstance(m, torch.jit.ScriptModule) for m in self._model.modules())
+            )
+            else self._flop_counter
+        )
+
         forward_context = contextlib.nullcontext() if trainer is None else trainer.precision_plugin.forward_context()
-        with torch.no_grad(), forward_context:
+        with torch.no_grad(), forward_context, flop_context:
             # let the model hooks collect the input- and output shapes
             if isinstance(input_, (list, tuple)):
                 model(*input_)
@@ -330,6 +367,7 @@ def _get_summary_data(self) -> list[tuple[str, list[str]]]:
             ("Type", self.layer_types),
             ("Params", list(map(get_human_readable_count, self.param_nums))),
             ("Mode", ["train" if mode else "eval" for mode in self.training_modes]),
+            ("FLOPs", list(map(get_human_readable_count, (sum(x.values()) for x in self.flop_counts.values())))),
         ]
         if self._model.example_input_array is not None:
             arrays.append(("In sizes", [str(x) for x in self.in_sizes]))
@@ -349,6 +387,7 @@ def _add_leftover_params_to_summary(self, arrays: list[tuple[str, list[str]]], total_leftover_params: int) -> None:
         layer_summaries["Type"].append(NOT_APPLICABLE)
         layer_summaries["Params"].append(get_human_readable_count(total_leftover_params))
         layer_summaries["Mode"].append(NOT_APPLICABLE)
+        layer_summaries["FLOPs"].append(NOT_APPLICABLE)
         if "In sizes" in layer_summaries:
             layer_summaries["In sizes"].append(NOT_APPLICABLE)
         if "Out sizes" in layer_summaries:
@@ -361,8 +400,16 @@ def __str__(self) -> str:
         trainable_parameters = self.trainable_parameters
         model_size = self.model_size
         total_training_modes = self.total_training_modes
-
-        return _format_summary_table(total_parameters, trainable_parameters, model_size, total_training_modes, *arrays)
+        total_flops = self.total_flops
+
+        return _format_summary_table(
+            total_parameters,
+            trainable_parameters,
+            model_size,
+            total_training_modes,
+            total_flops,
+            *arrays,
+        )
 
     def __repr__(self) -> str:
         return str(self)
@@ -383,6 +430,7 @@ def _format_summary_table(
     trainable_parameters: int,
     model_size: float,
     total_training_modes: dict[str, int],
+    total_flops: int,
    *cols: tuple[str, list[str]],
 ) -> str:
     """Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one big
@@ -423,6 +471,8 @@ def _format_summary_table(
     summary += "Modules in train mode"
     summary += "\n" + s.format(total_training_modes["eval"], 10)
     summary += "Modules in eval mode"
+    summary += "\n" + s.format(get_human_readable_count(total_flops), 10)
+    summary += "Total FLOPs"
 
     return summary
 
diff --git a/tests/tests_pytorch/callbacks/test_model_summary.py b/tests/tests_pytorch/callbacks/test_model_summary.py
index 215176ee2376b..07676801fcfb2 100644
--- a/tests/tests_pytorch/callbacks/test_model_summary.py
+++ b/tests/tests_pytorch/callbacks/test_model_summary.py
@@ -65,6 +65,9 @@ def summarize(
         assert summary_data[4][0] == "Mode"
         assert summary_data[4][1][0] == "train"
 
+        assert summary_data[5][0] == "FLOPs"
+        assert all(isinstance(x, str) for x in summary_data[5][1])
+
         assert total_training_modes == {"train": 1, "eval": 0}
 
     model = BoringModel()
diff --git a/tests/tests_pytorch/callbacks/test_rich_model_summary.py b/tests/tests_pytorch/callbacks/test_rich_model_summary.py
index 7534c23d5679c..af385bb1a9b39 100644
--- a/tests/tests_pytorch/callbacks/test_rich_model_summary.py
+++ b/tests/tests_pytorch/callbacks/test_rich_model_summary.py
@@ -62,10 +62,11 @@ def example_input_array(self) -> Any:
         trainable_parameters=1,
         model_size=1,
         total_training_modes=summary.total_training_modes,
+        total_flops=1,
     )
 
     # ensure that summary was logged + the breakdown of model parameters
     assert mock_console.call_count == 2
     # assert that the input summary data was converted correctly
     args, _ = mock_table_add_row.call_args_list[0]
-    assert args[1:] == ("0", "layer", "Linear", "66 ", "train", "[4, 32]", "[4, 2]")
+    assert args[1:] == ("0", "layer", "Linear", "66 ", "train", "512 ", "[4, 32]", "[4, 2]")
diff --git a/tests/tests_pytorch/utilities/test_model_summary.py b/tests/tests_pytorch/utilities/test_model_summary.py
index 54c5572d01767..85825b5ea749d 100644
--- a/tests/tests_pytorch/utilities/test_model_summary.py
+++ b/tests/tests_pytorch/utilities/test_model_summary.py
@@ -173,6 +173,7 @@ def test_empty_model_summary_shapes(max_depth):
     assert summary.in_sizes == []
     assert summary.out_sizes == []
     assert summary.param_nums == []
+    assert summary.total_flops == 0
 
 
 @pytest.mark.parametrize("max_depth", [-1, 1])
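
For completeness, a sketch of how the new properties surface on `ModelSummary` (mirroring the docstring example above; not part of the patch). Note that FLOPs are only counted when `example_input_array` is set, since counting happens inside `_forward_example_input()`; otherwise `total_flops` stays 0, which is exactly what the updated `test_empty_model_summary_shapes` asserts:

import torch
import torch.nn as nn
import lightning.pytorch as pl
from lightning.pytorch.utilities.model_summary import ModelSummary

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(256, 512), nn.BatchNorm1d(512))
        # Without this, no forward pass runs during summarization and total_flops is 0.
        self.example_input_array = torch.zeros(10, 256)

    def forward(self, x):
        return self.net(x)

summary = ModelSummary(LitModel(), max_depth=-1)
print(summary.total_flops)  # 2621440, rendered as "2.6 M" in the table
print(summary.flop_counts)  # per-layer breakdown; BatchNorm1d has no tracked ops, so it shows 0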