
Commit 72f9c67

feature: expand detection of data types in model size estimation
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent a0c2aae commit 72f9c67

File tree

1 file changed (+8 −2 lines)

fms_mo/fx/utils.py

Lines changed: 8 additions & 2 deletions
@@ -449,6 +449,8 @@ def model_size_Wb(mod, unit="MB"):
     """
     mem_use = 0
     Nint8 = 0
+    Nfp16 = 0
+    Nbf16 = 0
     Nfp32 = 0
     for n, m in mod.named_modules():
         w = getattr(m, "weight", None)
@@ -468,10 +470,14 @@ def model_size_Wb(mod, unit="MB"):
             mem_use += m.bias.numel() * m.bias.element_size()
         if w.dtype == torch.float32:
             Nfp32 += 1
+        elif w.dtype == torch.float16:
+            Nfp16 += 1
+        elif w.dtype == torch.bfloat16:
+            Nbf16 += 1
         else:
-            logger.info(f"Parameter {n} should be fp32 but is {w.dtype}")
+            logger.warning(f"Detected parameter {n} of data type {w.dtype}.")
     logger.info(
-        f"[check model size] Found {Nint8} INT8 and {Nfp32} "
+        f"[check model size] Found {Nint8} INT8, {Nfp16} FP16, {Nbf16} BF16, and {Nfp32} "
         "FP32 W/b tensors in this model."
     )
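For context, here is a minimal, runnable sketch of the counting logic as it stands after this commit. The standalone function name count_wb_dtypes, the logging setup, and the exact control flow around the bias accumulation are assumptions for illustration; the INT8 counting and unit conversion of the real model_size_Wb are elided.

import logging

import torch

logger = logging.getLogger(__name__)

def count_wb_dtypes(mod: torch.nn.Module) -> int:
    """Count weight tensors per dtype and return total W/b memory in bytes."""
    mem_use = 0
    Nint8 = 0  # INT8 detection in the real function is elided from this sketch
    Nfp16 = 0
    Nbf16 = 0
    Nfp32 = 0
    for n, m in mod.named_modules():
        w = getattr(m, "weight", None)
        if w is None:
            continue
        # Accumulate memory for the weight and, if present, the bias.
        mem_use += w.numel() * w.element_size()
        if getattr(m, "bias", None) is not None:
            mem_use += m.bias.numel() * m.bias.element_size()
        # Tally the weight dtype; anything unexpected is logged as a warning.
        if w.dtype == torch.float32:
            Nfp32 += 1
        elif w.dtype == torch.float16:
            Nfp16 += 1
        elif w.dtype == torch.bfloat16:
            Nbf16 += 1
        else:
            logger.warning(f"Detected parameter {n} of data type {w.dtype}.")
    logger.info(
        f"[check model size] Found {Nint8} INT8, {Nfp16} FP16, {Nbf16} BF16, "
        f"and {Nfp32} FP32 W/b tensors in this model."
    )
    return mem_use

For example, a mixed-precision model exercises the new branches:

model = torch.nn.Sequential(
    torch.nn.Linear(4, 4),         # FP32 weight and bias
    torch.nn.Linear(4, 4).half(),  # FP16 weight and bias
)
print(count_wb_dtypes(model))  # 120 bytes: 80 (FP32) + 40 (FP16)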
