Skip to content

update ModelSummary #20945

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 32 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
3d3ead1
update ModelSummary
YChienHung Jun 26, 2025
b8a452c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 26, 2025
6263a9c
Merge branch 'master' into freeze-model-summary
YChienHung Jun 26, 2025
e462627
type from bool to int
YChienHung Jun 26, 2025
5664f02
Merge branch 'freeze-model-summary' of github.com:YChienHung/pytorch-…
YChienHung Jun 26, 2025
da7fafb
Merge branch 'master' into freeze-model-summary
YChienHung Jun 28, 2025
567de9c
Merge branch 'master' into freeze-model-summary
deependujha Jul 26, 2025
6e6df34
Merge branch 'Lightning-AI:master' into freeze-model-summary
YChienHung Aug 9, 2025
afee472
use enum
YChienHung Aug 9, 2025
6d35cde
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 9, 2025
1c8157b
fix bugs
YChienHung Aug 9, 2025
e3014df
Merge branch 'freeze-model-summary' of github.com:YChienHung/pytorch-…
YChienHung Aug 9, 2025
a932382
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 9, 2025
0bf5cae
try fix errors
YChienHung Aug 9, 2025
4dafc88
Merge branch 'freeze-model-summary' of github.com:YChienHung/pytorch-…
YChienHung Aug 9, 2025
2bd762e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 9, 2025
971c6c0
retry
YChienHung Aug 9, 2025
674dd06
Merge branch 'freeze-model-summary' of github.com:YChienHung/pytorch-…
YChienHung Aug 9, 2025
1cae9c2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 9, 2025
422a1c7
retry
YChienHung Aug 9, 2025
beede57
Merge branch 'freeze-model-summary' of github.com:YChienHung/pytorch-…
YChienHung Aug 9, 2025
fe02458
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 9, 2025
2323179
fix bug
YChienHung Aug 9, 2025
6b3f1b2
retry
YChienHung Aug 9, 2025
e93c45e
remove first model
YChienHung Aug 9, 2025
bd30a60
fix remove root error
YChienHung Aug 9, 2025
0d30668
continue
YChienHung Aug 9, 2025
50cff87
fix mode error
YChienHung Aug 9, 2025
ab2a46e
im tired
YChienHung Aug 9, 2025
6dae87e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 9, 2025
49e1c32
fix bug
YChienHung Aug 9, 2025
76804ee
Merge branch 'freeze-model-summary' of github.com:YChienHung/pytorch-…
YChienHung Aug 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions src/lightning/pytorch/utilities/model_summary/model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import logging
import math
from collections import OrderedDict
from enum import Enum
from typing import Any, Optional, Union

import torch
Expand All @@ -41,6 +42,12 @@
NOT_APPLICABLE = "n/a"


class ModelSummaryTrainingMode(Enum):
TRAIN = "train"
EVAL = "eval"
FREEZE = "freeze"


class LayerSummary:
"""Summary class for a single layer in a :class:`~lightning.pytorch.core.LightningModule`. It collects the
following information:
Expand Down Expand Up @@ -146,6 +153,13 @@ def training(self) -> bool:
"""Returns whether the module is in training mode."""
return self._module.training

@property
def requires_grad(self) -> bool:
"""Returns whether the module requires grad."""
if self.num_parameters > 0:
return any(param.requires_grad for name, param in self._module.named_parameters())
return True


class ModelSummary:
"""Generates a summary of all layers in a :class:`~lightning.pytorch.core.LightningModule`.
Expand Down Expand Up @@ -281,14 +295,32 @@ def param_nums(self) -> list[int]:
return [layer.num_parameters for layer in self._layer_summary.values()]

@property
def training_modes(self) -> list[bool]:
return [layer.training for layer in self._layer_summary.values()]
def training_modes(self) -> list[int]:
return [
(ModelSummaryTrainingMode.TRAIN if layer.training else ModelSummaryTrainingMode.EVAL)
if layer.requires_grad
else ModelSummaryTrainingMode.FREEZE
for layer in self._layer_summary.values()
]

@property
def total_training_modes(self) -> dict[str, int]:
modes = [layer.training for layer in self._model.modules()]
modes = modes[1:] # exclude the root module
return {"train": modes.count(True), "eval": modes.count(False)}
# modes = [
# (
# (ModelSummaryTrainingMode.TRAIN if layer.training else ModelSummaryTrainingMode.EVAL)
# if any(p.requires_grad for p in layer.parameters())
# else ModelSummaryTrainingMode.FREEZE
# ).value
# for layer in islice(self._model.modules(), 1, None) # exclude the root module
# ]
# return {
# "train": modes.count(ModelSummaryTrainingMode.TRAIN.value),
# "eval": modes.count(ModelSummaryTrainingMode.EVAL.value),
# "freeze": modes.count(ModelSummaryTrainingMode.FREEZE.value),
# }

@property
def total_parameters(self) -> int:
Expand Down Expand Up @@ -382,7 +414,7 @@ def _get_summary_data(self) -> list[tuple[str, list[str]]]:
("Name", self.layer_names),
("Type", self.layer_types),
("Params", list(map(get_human_readable_count, self.param_nums))),
("Mode", ["train" if mode else "eval" for mode in self.training_modes]),
("Mode", [mode.value for mode in self.training_modes]),
("FLOPs", list(map(get_human_readable_count, (sum(x.values()) for x in self.flop_counts.values())))),
]
if self._model.example_input_array is not None:
Expand Down Expand Up @@ -487,6 +519,8 @@ def _format_summary_table(
summary += "Modules in train mode"
summary += "\n" + s.format(total_training_modes["eval"], 10)
summary += "Modules in eval mode"
# summary += "\n" + s.format(total_training_modes["freeze"], 10)
# summary += "Modules in freeze mode"
summary += "\n" + s.format(get_human_readable_count(total_flops), 10)
summary += "Total Flops"

Expand Down
Loading