2121import os
2222
2323# Third Party
24+ import pandas as pd
2425import torch
2526
2627# Local
@@ -433,7 +434,7 @@ def get_target_op_from_mod_or_str(mod_or_str, verbose=False):
433434#############
434435
435436
def model_size_Wb(mod, unit="MB", print_to_file=True, show_details=False):
    """Checks model size, only count weight and bias

    NOTE:
        1. When a module's `weight` attribute is callable, it is called and the
           first two items of the result are taken as the (weight, bias) tensors.
           Presumably this covers quantized modules that expose `weight()` as a
           method -- TODO confirm against the modules actually passed in.
        2. Modules with no `weight` attribute (or one that is neither callable nor
           a `torch.Tensor`) are skipped and do not contribute to the total.

    Args:
        mod: object exposing `named_modules()` (presumably a `torch.nn.Module`).
        unit (str, optional): "MB" (bytes / 1e6) or "GB" (bytes / 1e9). Any other
            value triggers a warning and falls back to "MB". Defaults to "MB".
        print_to_file (bool, optional): if True the summary is emitted through
            `logger.info`, otherwise through `print`. Defaults to True.
        show_details (bool, optional): if True also emit the per-layer table
            (layer name, shape, memory, dtype). Defaults to False.

    Returns:
        float: model size in desired unit
    """
    if unit not in ["MB", "GB"]:
        logger.warning(
            f"Unrecognized unit for memory summary: {unit}. Will use MB instead."
        )
        unit = "MB"

    bytes_per_unit = 1e9 if unit == "GB" else 1e6
    mem_col = f"mem ({unit})"

    summary_weights = {"layer": [], "shape": [], mem_col: [], "dtype": []}
    for n, m in mod.named_modules():
        w = getattr(m, "weight", None)
        if callable(w):  # see Note 1.
            w_mat, b_mat = w()[:2]
            mem_use = w_mat.numel() * w_mat.element_size()
            if b_mat is not None:
                mem_use += b_mat.numel() * b_mat.element_size()
            w_dtype, w_shape = w_mat.dtype, w_mat.shape
        elif isinstance(w, torch.Tensor):
            mem_use = w.numel() * w.element_size()
            if hasattr(m, "bias") and m.bias is not None:
                mem_use += m.bias.numel() * m.bias.element_size()
            w_dtype, w_shape = w.dtype, w.shape
        else:
            # Nothing to count for this module (see Note 2).
            continue

        summary_weights["layer"].append(n)
        summary_weights["shape"].append(w_shape)
        summary_weights[mem_col].append(mem_use / bytes_per_unit)
        # Store the dtype by name so the group keys used below are plain,
        # sortable values rather than torch.dtype objects.
        summary_weights["dtype"].append(str(w_dtype))

    df_summary_weights = pd.DataFrame(summary_weights)
    logger_or_print = logger.info if print_to_file else print

    if df_summary_weights.empty:
        # No W/b tensors at all: skip the pivot and report a zero size.
        logger_or_print("[check model size] No W/b tensors found in this model.")
        return 0.0

    per_dtype = pd.pivot_table(
        df_summary_weights,
        index="dtype",
        values=["layer", mem_col],
        aggfunc={"layer": "count", mem_col: "sum"},
    )
    logger_or_print("[check model size] Summary of W/b tensors in this model:")
    # Pre-format the message: `print` does not interpolate %-style arguments the
    # way `logger.info` does, so both sinks receive one finished string.
    logger_or_print(f"\n{per_dtype}")

    if show_details:
        try:
            details = df_summary_weights.to_markdown()
        except ImportError:
            # to_markdown() depends on the optional `tabulate` package.
            details = df_summary_weights.to_string()
        logger_or_print(details)

    return float(df_summary_weights[mem_col].sum())
485507
486508def plot_graph_module (
0 commit comments