diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary.py b/src/lightning/pytorch/utilities/model_summary/model_summary.py
index 01b692abdc05f..4760b76af9cb8 100644
--- a/src/lightning/pytorch/utilities/model_summary/model_summary.py
+++ b/src/lightning/pytorch/utilities/model_summary/model_summary.py
@@ -17,6 +17,8 @@
 import logging
 import math
 from collections import OrderedDict
+from enum import Enum
+from itertools import islice
 from typing import Any, Optional, Union
 
 import torch
@@ -41,6 +43,12 @@
 NOT_APPLICABLE = "n/a"
 
 
+class ModelSummaryTrainingMode(Enum):
+    TRAIN = "train"
+    EVAL = "eval"
+    FREEZE = "freeze"
+
+
 class LayerSummary:
     """Summary class for a single layer in a :class:`~lightning.pytorch.core.LightningModule`. It collects the
     following information:
@@ -146,6 +154,13 @@ def training(self) -> bool:
         """Returns whether the module is in training mode."""
         return self._module.training
 
+    @property
+    def requires_grad(self) -> bool:
+        """Returns whether any of the module's parameters require grad; parameter-less modules default to True."""
+        if self.num_parameters > 0:
+            return any(param.requires_grad for param in self._module.parameters())
+        return True
+
 
 class ModelSummary:
     """Generates a summary of all layers in a :class:`~lightning.pytorch.core.LightningModule`.
@@ -281,14 +296,30 @@ def param_nums(self) -> list[int]:
         return [layer.num_parameters for layer in self._layer_summary.values()]
 
     @property
-    def training_modes(self) -> list[bool]:
-        return [layer.training for layer in self._layer_summary.values()]
+    def training_modes(self) -> list[ModelSummaryTrainingMode]:
+        return [
+            (ModelSummaryTrainingMode.TRAIN if layer.training else ModelSummaryTrainingMode.EVAL)
+            if layer.requires_grad
+            else ModelSummaryTrainingMode.FREEZE
+            for layer in self._layer_summary.values()
+        ]
 
     @property
     def total_training_modes(self) -> dict[str, int]:
-        modes = [layer.training for layer in self._model.modules()]
-        modes = modes[1:]  # exclude the root module
-        return {"train": modes.count(True), "eval": modes.count(False)}
+        modes = []
+        for layer in islice(self._model.modules(), 1, None):  # exclude the root module
+            if next(layer.parameters(), None) is not None and not any(p.requires_grad for p in layer.parameters()):
+                # the module has parameters and none of them require grad
+                modes.append(ModelSummaryTrainingMode.FREEZE)
+            elif layer.training:
+                modes.append(ModelSummaryTrainingMode.TRAIN)
+            else:
+                modes.append(ModelSummaryTrainingMode.EVAL)
+        return {
+            "train": modes.count(ModelSummaryTrainingMode.TRAIN),
+            "eval": modes.count(ModelSummaryTrainingMode.EVAL),
+            "freeze": modes.count(ModelSummaryTrainingMode.FREEZE),
+        }
 
     @property
     def total_parameters(self) -> int:
@@ -382,7 +413,7 @@ def _get_summary_data(self) -> list[tuple[str, list[str]]]:
             ("Name", self.layer_names),
             ("Type", self.layer_types),
             ("Params", list(map(get_human_readable_count, self.param_nums))),
-            ("Mode", ["train" if mode else "eval" for mode in self.training_modes]),
+            ("Mode", [mode.value for mode in self.training_modes]),
             ("FLOPs", list(map(get_human_readable_count, (sum(x.values()) for x in self.flop_counts.values())))),
         ]
         if self._model.example_input_array is not None:
@@ -487,6 +518,8 @@ def _format_summary_table(
     summary += "\n" + s.format(total_training_modes["train"], 10)
     summary += "Modules in train mode"
     summary += "\n" + s.format(total_training_modes["eval"], 10)
     summary += "Modules in eval mode"
+    summary += "\n" + s.format(total_training_modes["freeze"], 10)
+    summary += "Modules in freeze mode"
     summary += "\n" + s.format(get_human_readable_count(total_flops), 10)
     summary += "Total Flops"
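
Below is a minimal sketch of how the freeze-aware summary could be exercised once this patch is applied. The TinyModel module and its layer names are illustrative assumptions, not part of the patch; LightningModule, ModelSummary, and nn.Module.requires_grad_ are existing APIs.

    import torch
    from torch import nn

    from lightning.pytorch import LightningModule
    from lightning.pytorch.utilities.model_summary import ModelSummary


    class TinyModel(LightningModule):
        def __init__(self) -> None:
            super().__init__()
            self.backbone = nn.Linear(16, 8)
            self.head = nn.Linear(8, 2)
            # freeze the backbone: sets requires_grad=False on all of its parameters
            self.backbone.requires_grad_(False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.head(self.backbone(x))


    model = TinyModel()
    print(ModelSummary(model, max_depth=1))
    # with this patch, the backbone row should report "freeze" in the Mode column,
    # while the head still reports "train" and the footer gains a
    # "Modules in freeze mode" count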