
Commit d29a1db

Bring over upgraded model_size_Wb from branch
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent 72f9c67 commit d29a1db


fms_mo/fx/utils.py

Lines changed: 43 additions & 27 deletions
@@ -21,6 +21,7 @@
 import os
 
 # Third Party
+import pandas as pd
 import torch
 
 # Local
@@ -433,7 +434,7 @@ def get_target_op_from_mod_or_str(mod_or_str, verbose=False):
 #############
 
 
-def model_size_Wb(mod, unit="MB"):
+def model_size_Wb(mod, unit="MB", print_to_file=True, show_details=False):
     """Checks model size, only count weight and bias
 
     NOTE:
@@ -447,46 +448,61 @@ def model_size_Wb(mod, unit="MB"):
     Returns:
         float: model size in desired unit
     """
+
     mem_use = 0
-    Nint8 = 0
-    Nfp16 = 0
-    Nbf16 = 0
-    Nfp32 = 0
+    if unit not in ["MB", "GB"]:
+        logger.warning(
+            f"Unrecognized unit for memory summary: {unit}. Will use MB instead."
+        )
+        unit = "MB"
+
+    summary_weights = {"layer": [], "shape": [], f"mem ({unit})": [], "dtype": []}
     for n, m in mod.named_modules():
         w = getattr(m, "weight", None)
+        w_dtype, w_shape = None, None
         if callable(w):  # see Note 1.
             w_mat, b_mat = w()[:2]
-            mem_use += (
+            mem_use = (
                 w_mat.numel() * w_mat.element_size()
                 + b_mat.numel() * b_mat.element_size()
             )
-            if w_mat.dtype in [torch.qint8, torch.quint8]:
-                Nint8 += 1
-            else:
-                logger.info(f"Parameter {n} should be int8 but is {w_mat.dtype}")
+            w_dtype = w_mat.dtype
+            w_shape = w_mat.shape
+
         elif isinstance(w, torch.Tensor):
-            mem_use += w.numel() * w.element_size()
+            mem_use = w.numel() * w.element_size()
             if hasattr(m, "bias") and m.bias is not None:
                 mem_use += m.bias.numel() * m.bias.element_size()
-            if w.dtype == torch.float32:
-                Nfp32 += 1
-            elif w.dtype == torch.float16:
-                Nfp16 += 1
-            elif w.dtype == torch.bfloat16:
-                Nbf16 += 1
-            else:
-                logger.warning(f"Detected parameter {n} of data type {w.dtype}.")
-    logger.info(
-        f"[check model size] Found {Nint8} INT8, {Nfp16} FP16, {Nbf16} BF16, and {Nfp32} "
-        "FP32 W/b tensors in this model."
+            w_dtype = w.dtype
+            w_shape = w.shape
+
+        if w_shape:
+            mem_use = mem_use / 1e9 if unit == "GB" else mem_use / 1e6
+
+            summary_weights["layer"].append(n)
+            summary_weights["shape"].append(w_shape)
+            summary_weights[f"mem ({unit})"].append(mem_use)
+            summary_weights["dtype"].append(w_dtype)
+
+    df_summary_weights = pd.DataFrame(summary_weights)
+    logger_or_print = logger.info if print_to_file else print
+    logger_or_print("[check model size] Summary of W/b tensors in this model:")
+    logger_or_print(
+        "\n%s",
+        str(
+            pd.pivot_table(
+                df_summary_weights,
+                index="dtype",
+                values=["layer", f"mem ({unit})"],
+                aggfunc={"layer": "count", f"mem ({unit})": "sum"},
+            )
+        ),
     )
 
-    if unit == "GB":
-        mem_use /= 1e9
-    else:
-        mem_use /= 1e6
+    if show_details:
+        logger_or_print(df_summary_weights.to_markdown())
 
-    return mem_use
+    return df_summary_weights[f"mem ({unit})"].sum().item()
 
 
 def plot_graph_module(
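
For context, the upgrade replaces the four hard-coded dtype counters (Nint8/Nfp16/Nbf16/Nfp32) with a pandas DataFrame of per-layer rows, so any weight dtype is reported without needing a warning branch, and the total is summed from the DataFrame rather than accumulated in the loop. A minimal usage sketch, assuming fms_mo is installed and model_size_Wb matches this diff; the two-layer model below is hypothetical:

# Minimal usage sketch (hypothetical toy model; assumes fms_mo is
# installed and fms_mo.fx.utils.model_size_Wb matches this diff).
import torch

from fms_mo.fx.utils import model_size_Wb

model = torch.nn.Sequential(
    torch.nn.Linear(128, 256),                     # FP32 weight + bias
    torch.nn.Linear(256, 64).to(torch.bfloat16),   # BF16 weight + bias
)

# With print_to_file=False the summary goes to stdout via print() instead
# of logger.info; show_details=True also emits the full per-layer table
# (DataFrame.to_markdown requires the tabulate package). The return value
# is the total W/b footprint as a plain float in the requested unit.
total_mb = model_size_Wb(model, unit="MB", print_to_file=False, show_details=True)
print(f"Total W/b size: {total_mb:.3f} MB")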
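
The logged summary comes from pd.pivot_table grouping the per-layer rows by dtype, counting layers and summing memory. A standalone sketch of that aggregation, using the same pivot_table call as the diff on hypothetical layer data:

# Reproduces the per-dtype summary the new code logs
# (hypothetical layer names, shapes, and sizes).
import pandas as pd

df = pd.DataFrame(
    {
        "layer": ["encoder.q", "encoder.k", "head"],
        "shape": [(256, 256), (256, 256), (256, 10)],
        "mem (MB)": [0.131072, 0.131072, 0.010240],  # numel * element_size / 1e6
        "dtype": ["torch.float16", "torch.float16", "torch.float32"],
    }
)

print(
    pd.pivot_table(
        df,
        index="dtype",
        values=["layer", "mem (MB)"],
        aggfunc={"layer": "count", "mem (MB)": "sum"},
    )
)
# Prints something like:
#                layer  mem (MB)
# dtype
# torch.float16      2  0.262144
# torch.float32      1  0.010240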
