File tree Expand file tree Collapse file tree 1 file changed +8
-2
lines changed
Expand file tree Collapse file tree 1 file changed +8
-2
lines changed Original file line number Diff line number Diff line change @@ -449,6 +449,8 @@ def model_size_Wb(mod, unit="MB"):
449449 """
450450 mem_use = 0
451451 Nint8 = 0
452+ Nfp16 = 0
453+ Nbf16 = 0
452454 Nfp32 = 0
453455 for n, m in mod.named_modules():
454456 w = getattr(m, "weight", None)
@@ -468,10 +470,14 @@ def model_size_Wb(mod, unit="MB"):
468470 mem_use += m.bias.numel() * m.bias.element_size()
469471 if w.dtype == torch.float32:
470472 Nfp32 += 1
473+ elif w.dtype == torch.float16:
474+ Nfp16 += 1
475+ elif w.dtype == torch.bfloat16:
476+ Nbf16 += 1
471477 else:
472- logger.info(f"Parameter {n} should be fp32 but is {w.dtype}")
478+ logger.warning(f"Detected parameter {n} of data type {w.dtype}.")
473479 logger.info(
474- f"[check model size] Found {Nint8} INT8 and {Nfp32} "
480+ f"[check model size] Found {Nint8} INT8, {Nfp16} FP16, {Nbf16} BF16, and {Nfp32} "
475481 "FP32 W/b tensors in this model."
476482 )
477483
You can’t perform that action at this time.
0 commit comments