@@ -144,7 +144,14 @@ def num_parameters(self) -> int:
def training(self) -> bool:
    """Whether the wrapped module is currently in training mode."""
    module = self._module
    return module.training
147-
147+
@property
def requires_grad(self) -> bool:
    """Return ``True`` if any parameter of the wrapped module requires gradients.

    Parameter-less modules (e.g. activations) are reported as ``True`` so they
    are not displayed as frozen in the summary.
    """
    if self.num_parameters > 0:
        # Generator instead of a list comprehension: avoids materializing all
        # parameters and lets ``any`` short-circuit on the first trainable one.
        return any(param.requires_grad for _, param in self._module.named_parameters())
    # No parameters at all -> treat as trainable by convention.
    return True
148155
149156class ModelSummary :
150157 """Generates a summary of all layers in a :class:`~lightning.pytorch.core.LightningModule`.
@@ -266,7 +273,7 @@ def param_nums(self) -> list[int]:
266273
@property
def training_modes(self) -> list[int]:
    """Return a per-layer mode code: ``0`` = frozen, ``1`` = eval, ``2`` = train.

    A layer whose parameters do not require gradients is reported as frozen
    regardless of its train/eval flag.  The annotation is ``list[int]`` (the
    previous ``list[bool]`` no longer matched the three-state encoding).
    """
    modes: list[int] = []
    for layer in self._layer_summary.values():
        if not layer.requires_grad:
            modes.append(0)  # frozen
        elif layer.training:
            modes.append(2)  # train
        else:
            modes.append(1)  # eval
    return modes
270277
271278 @property
272279 def total_training_modes (self ) -> dict [str , int ]:
@@ -361,12 +368,13 @@ def _get_summary_data(self) -> list[tuple[str, list[str]]]:
361368 Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size
362369
363370 """
371+ param_mode = {0 : "freeze" , 1 : "eval" , 2 : "train" }
364372 arrays = [
365373 (" " , list (map (str , range (len (self ._layer_summary ))))),
366374 ("Name" , self .layer_names ),
367375 ("Type" , self .layer_types ),
368376 ("Params" , list (map (get_human_readable_count , self .param_nums ))),
369- ("Mode" , ["train" if mode else "eval" for mode in self .training_modes ]),
377+ ("Mode" , [param_mode [ mode ] for mode in self .training_modes ]),
370378 ("FLOPs" , list (map (get_human_readable_count , (sum (x .values ()) for x in self .flop_counts .values ())))),
371379 ]
372380 if self ._model .example_input_array is not None :
0 commit comments