Skip to content

Commit 3d3ead1

Browse files
committed
update ModelSummary
1 parent a651975 commit 3d3ead1

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

src/lightning/pytorch/utilities/model_summary/model_summary.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,14 @@ def num_parameters(self) -> int:
144144
def training(self) -> bool:
145145
"""Returns whether the module is in training mode."""
146146
return self._module.training
147-
147+
148+
@property
149+
def requires_grad(self) -> bool:
150+
"""Returns whether the module is requires grad."""
151+
if self.num_parameters > 0:
152+
return any([param.requires_grad for name, param in self._module.named_parameters()])
153+
else:
154+
return True
148155

149156
class ModelSummary:
150157
"""Generates a summary of all layers in a :class:`~lightning.pytorch.core.LightningModule`.
@@ -266,7 +273,7 @@ def param_nums(self) -> list[int]:
266273

267274
@property
268275
def training_modes(self) -> list[bool]:
269-
return [layer.training for layer in self._layer_summary.values()]
276+
return [(2 if layer.training else 1) if layer.requires_grad else 0 for layer in self._layer_summary.values()]
270277

271278
@property
272279
def total_training_modes(self) -> dict[str, int]:
@@ -361,12 +368,13 @@ def _get_summary_data(self) -> list[tuple[str, list[str]]]:
361368
Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size
362369
363370
"""
371+
param_mode = {0: "freeze", 1: "eval", 2: "train"}
364372
arrays = [
365373
(" ", list(map(str, range(len(self._layer_summary))))),
366374
("Name", self.layer_names),
367375
("Type", self.layer_types),
368376
("Params", list(map(get_human_readable_count, self.param_nums))),
369-
("Mode", ["train" if mode else "eval" for mode in self.training_modes]),
377+
("Mode", [param_mode[mode] for mode in self.training_modes]),
370378
("FLOPs", list(map(get_human_readable_count, (sum(x.values()) for x in self.flop_counts.values())))),
371379
]
372380
if self._model.example_input_array is not None:

0 commit comments

Comments
 (0)