 import torch
 import torch.nn as nn
 from torch import Tensor
+from torch.utils.flop_counter import FlopCounterMode
 from torch.utils.hooks import RemovableHandle

 import lightning.pytorch as pl
 from lightning.fabric.utilities.distributed import _is_dtensor
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
 from lightning.pytorch.utilities.model_helpers import _ModuleMode
 from lightning.pytorch.utilities.rank_zero import WarningCache

@@ -180,29 +182,31 @@ class ModelSummary:
             ...
         >>> model = LitModel()
         >>> ModelSummary(model, max_depth=1)  # doctest: +NORMALIZE_WHITESPACE
-          | Name | Type       | Params | Mode  | In sizes  | Out sizes
-        --------------------------------------------------------------------
-        0 | net  | Sequential | 132 K  | train | [10, 256] | [10, 512]
-        --------------------------------------------------------------------
+          | Name | Type       | Params | Mode  | FLOPs | In sizes  | Out sizes
+        ----------------------------------------------------------------------------
+        0 | net  | Sequential | 132 K  | train | 2.6 M | [10, 256] | [10, 512]
+        ----------------------------------------------------------------------------
         132 K     Trainable params
         0         Non-trainable params
         132 K     Total params
         0.530     Total estimated model params size (MB)
         3         Modules in train mode
         0         Modules in eval mode
+        2.6 M     Total Flops
         >>> ModelSummary(model, max_depth=-1)  # doctest: +NORMALIZE_WHITESPACE
-          | Name  | Type        | Params | Mode  | In sizes  | Out sizes
-        ----------------------------------------------------------------------
-        0 | net   | Sequential  | 132 K  | train | [10, 256] | [10, 512]
-        1 | net.0 | Linear      | 131 K  | train | [10, 256] | [10, 512]
-        2 | net.1 | BatchNorm1d | 1.0 K  | train | [10, 512] | [10, 512]
-        ----------------------------------------------------------------------
+          | Name  | Type        | Params | Mode  | FLOPs | In sizes  | Out sizes
+        ------------------------------------------------------------------------------
+        0 | net   | Sequential  | 132 K  | train | 2.6 M | [10, 256] | [10, 512]
+        1 | net.0 | Linear      | 131 K  | train | 2.6 M | [10, 256] | [10, 512]
+        2 | net.1 | BatchNorm1d | 1.0 K  | train | 0     | [10, 512] | [10, 512]
+        ------------------------------------------------------------------------------
         132 K     Trainable params
         0         Non-trainable params
         132 K     Total params
         0.530     Total estimated model params size (MB)
         3         Modules in train mode
         0         Modules in eval mode
+        2.6 M     Total Flops

     """

@@ -212,6 +216,13 @@ def __init__(self, model: "pl.LightningModule", max_depth: int = 1) -> None:
         if not isinstance(max_depth, int) or max_depth < -1:
             raise ValueError(f"`max_depth` can be -1, 0 or > 0, got {max_depth}.")

+        # Pass max_depth + 1 because the root module is already counted as depth 0.
+        self._flop_counter = FlopCounterMode(
+            mods=None if _TORCH_GREATER_EQUAL_2_4 else self._model,
+            display=False,
+            depth=max_depth + 1,
+        )
+
         self._max_depth = max_depth
         self._layer_summary = self.summarize()
         # 1 byte -> 8 bits
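
For reference, the `FlopCounterMode` API used here can be exercised on its own; a minimal sketch (the model and input below are illustrative, not taken from the patch):

import torch
from torch.utils.flop_counter import FlopCounterMode

model = torch.nn.Linear(256, 512)
inp = torch.randn(10, 256)

counter = FlopCounterMode(display=False, depth=2)
with counter:
    model(inp)

# One Linear forward costs 2 * batch * in_features * out_features FLOPs:
# 2 * 10 * 256 * 512 = 2,621,440 -- the "2.6 M" shown in the docstring above.
print(counter.get_total_flops())
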
@@ -279,6 +290,22 @@ def total_layer_params(self) -> int:
     def model_size(self) -> float:
         return self.total_parameters * self._precision_megabytes

+    @property
+    def total_flops(self) -> int:
+        return self._flop_counter.get_total_flops()
+
+    @property
+    def flop_counts(self) -> dict[str, dict[Any, int]]:
+        flop_counts = self._flop_counter.get_flop_counts()
+        ret = {
+            name: flop_counts.get(
+                f"{type(self._model).__name__}.{name}",
+                {},
+            )
+            for name in self.layer_names
+        }
+        return ret
+
     def summarize(self) -> dict[str, LayerSummary]:
         summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
         if self._model.example_input_array is not None:
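
To make the lookup in `flop_counts` concrete: `get_flop_counts()` keys entries by dotted module path rooted at the model's class name, which is why the property prefixes each layer name with `type(self._model).__name__`. A sketch of the expected shape (assuming torch >= 2.4, where no `mods` argument is needed):

import torch
from torch.utils.flop_counter import FlopCounterMode

class LitModel(torch.nn.Module):  # stand-in for the LightningModule
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(256, 512)

    def forward(self, x):
        return self.net(x)

model = LitModel()
counter = FlopCounterMode(display=False)
with counter:
    model(torch.randn(10, 256))

counts = counter.get_flop_counts()
# Keys should include "Global", "LitModel", and "LitModel.net"; each value maps
# an ATen op to its count, e.g. {torch.ops.aten.addmm: 2621440}.
print(sum(counts["LitModel.net"].values()))
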
@@ -307,8 +334,18 @@ def _forward_example_input(self) -> None:
         mode.capture(model)
         model.eval()

+        # FlopCounterMode does not support ScriptModules before torch 2.4.0, so we use a null context
+        flop_context = (
+            contextlib.nullcontext()
+            if (
+                not _TORCH_GREATER_EQUAL_2_4
+                and any(isinstance(m, torch.jit.ScriptModule) for m in self._model.modules())
+            )
+            else self._flop_counter
+        )
+
         forward_context = contextlib.nullcontext() if trainer is None else trainer.precision_plugin.forward_context()
-        with torch.no_grad(), forward_context:
+        with torch.no_grad(), forward_context, flop_context:
             # let the model hooks collect the input- and output shapes
             if isinstance(input_, (list, tuple)):
                 model(*input_)
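
The `nullcontext` fallback used above is the usual pattern for conditionally disabling a context manager while keeping a single `with` statement; a small generic illustration (the names here are hypothetical):

import contextlib

@contextlib.contextmanager
def profiler():  # hypothetical stand-in for FlopCounterMode
    print("profiling on")
    yield
    print("profiling off")

enabled = False  # e.g. the feature is unsupported on this torch version
ctx = profiler() if enabled else contextlib.nullcontext()
with ctx:
    pass  # the body runs either way; nullcontext() contributes nothing
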
@@ -330,6 +367,7 @@ def _get_summary_data(self) -> list[tuple[str, list[str]]]:
330367 ("Type" , self .layer_types ),
331368 ("Params" , list (map (get_human_readable_count , self .param_nums ))),
332369 ("Mode" , ["train" if mode else "eval" for mode in self .training_modes ]),
370+ ("FLOPs" , list (map (get_human_readable_count , (sum (x .values ()) for x in self .flop_counts .values ())))),
333371 ]
334372 if self ._model .example_input_array is not None :
335373 arrays .append (("In sizes" , [str (x ) for x in self .in_sizes ]))
@@ -349,6 +387,7 @@ def _add_leftover_params_to_summary(self, arrays: list[tuple[str, list[str]]], t
         layer_summaries["Type"].append(NOT_APPLICABLE)
         layer_summaries["Params"].append(get_human_readable_count(total_leftover_params))
         layer_summaries["Mode"].append(NOT_APPLICABLE)
+        layer_summaries["FLOPs"].append(NOT_APPLICABLE)
         if "In sizes" in layer_summaries:
             layer_summaries["In sizes"].append(NOT_APPLICABLE)
         if "Out sizes" in layer_summaries:
@@ -361,8 +400,16 @@ def __str__(self) -> str:
         trainable_parameters = self.trainable_parameters
         model_size = self.model_size
         total_training_modes = self.total_training_modes
-
-        return _format_summary_table(total_parameters, trainable_parameters, model_size, total_training_modes, *arrays)
+        total_flops = self.total_flops
+
+        return _format_summary_table(
+            total_parameters,
+            trainable_parameters,
+            model_size,
+            total_training_modes,
+            total_flops,
+            *arrays,
+        )

     def __repr__(self) -> str:
         return str(self)
@@ -383,6 +430,7 @@ def _format_summary_table(
     trainable_parameters: int,
     model_size: float,
     total_training_modes: dict[str, int],
+    total_flops: int,
     *cols: tuple[str, list[str]],
 ) -> str:
     """Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one big
@@ -423,6 +471,8 @@ def _format_summary_table(
     summary += "Modules in train mode"
     summary += "\n" + s.format(total_training_modes["eval"], 10)
     summary += "Modules in eval mode"
+    summary += "\n" + s.format(get_human_readable_count(total_flops), 10)
+    summary += "Total Flops"

     return summary

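Taken together, with this patch applied, the new FLOPs column and "Total Flops" row surface as in the class docstring; a usage sketch (output values mirror the doctest rather than freshly verified numbers):

import torch
import lightning.pytorch as pl
from lightning.pytorch.utilities.model_summary import ModelSummary

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(torch.nn.Linear(256, 512), torch.nn.BatchNorm1d(512))
        self.example_input_array = torch.zeros(10, 256)

    def forward(self, x):
        return self.net(x)

summary = ModelSummary(LitModel(), max_depth=1)
print(summary)              # table now ends with a FLOPs column and a "Total Flops" row
print(summary.total_flops)  # 2621440, i.e. the "2.6 M" from the docstring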