Commit 88ac3dc

Merge pull request #88 from andrea-fasoli/util_update
feat: expand detection of data types in model size estimation
2 parents 51d4cf1 + d29a1db

File tree

1 file changed: +43 -21 lines

fms_mo/fx/utils.py (43 additions, 21 deletions)
@@ -21,6 +21,7 @@
 import os
 
 # Third Party
+import pandas as pd
 import torch
 
 # Local
@@ -433,7 +434,7 @@ def get_target_op_from_mod_or_str(mod_or_str, verbose=False):
 #############
 
 
-def model_size_Wb(mod, unit="MB"):
+def model_size_Wb(mod, unit="MB", print_to_file=True, show_details=False):
     """Checks model size, only count weight and bias
 
     NOTE:
@@ -447,40 +448,61 @@ def model_size_Wb(mod, unit="MB"):
     Returns:
         float: model size in desired unit
     """
+
     mem_use = 0
-    Nint8 = 0
-    Nfp32 = 0
+    if unit not in ["MB", "GB"]:
+        logger.warning(
+            f"Unrecognized unit for memory summary: {unit}. Will use MB instead."
+        )
+        unit = "MB"
+
+    summary_weights = {"layer": [], "shape": [], f"mem ({unit})": [], "dtype": []}
     for n, m in mod.named_modules():
         w = getattr(m, "weight", None)
+        w_dtype, w_shape = None, None
         if callable(w):  # see Note 1.
             w_mat, b_mat = w()[:2]
-            mem_use += (
+            mem_use = (
                 w_mat.numel() * w_mat.element_size()
                 + b_mat.numel() * b_mat.element_size()
            )
-            if w_mat.dtype in [torch.qint8, torch.quint8]:
-                Nint8 += 1
-            else:
-                logger.info(f"Parameter {n} should be int8 but is {w_mat.dtype}")
+            w_dtype = w_mat.dtype
+            w_shape = w_mat.shape
+
         elif isinstance(w, torch.Tensor):
-            mem_use += w.numel() * w.element_size()
+            mem_use = w.numel() * w.element_size()
             if hasattr(m, "bias") and m.bias is not None:
                 mem_use += m.bias.numel() * m.bias.element_size()
-            if w.dtype == torch.float32:
-                Nfp32 += 1
-            else:
-                logger.info(f"Parameter {n} should be fp32 but is {w.dtype}")
-    logger.info(
-        f"[check model size] Found {Nint8} INT8 and {Nfp32} "
-        "FP32 W/b tensors in this model."
+            w_dtype = w.dtype
+            w_shape = w.shape
+
+        if w_shape:
+            mem_use = mem_use / 1e9 if unit == "GB" else mem_use / 1e6
+
+            summary_weights["layer"].append(n)
+            summary_weights["shape"].append(w_shape)
+            summary_weights[f"mem ({unit})"].append(mem_use)
+            summary_weights["dtype"].append(w_dtype)
+
+    df_summary_weights = pd.DataFrame(summary_weights)
+    logger_or_print = logger.info if print_to_file else print
+    logger_or_print("[check model size] Summary of W/b tensors in this model:")
+    logger_or_print(
+        "\n%s",
+        str(
+            pd.pivot_table(
+                df_summary_weights,
+                index="dtype",
+                values=["layer", f"mem ({unit})"],
+                aggfunc={"layer": "count", f"mem ({unit})": "sum"},
+            )
+        ),
    )
 
-    if unit == "GB":
-        mem_use /= 1e9
-    else:
-        mem_use /= 1e6
+    if show_details:
+        logger_or_print(df_summary_weights.to_markdown())
 
-    return mem_use
+    return df_summary_weights[f"mem ({unit})"].sum().item()
 
 
 def plot_graph_module(
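
A minimal usage sketch of the updated function, for orientation. It assumes fms_mo is installed so that model_size_Wb is importable from fms_mo.fx.utils as shown in the diff; the toy model and variable names below are illustrative only.

# Usage sketch: illustrative only. Assumes fms_mo is installed and that
# model_size_Wb is importable from fms_mo.fx.utils, per the diff above.
import torch

from fms_mo.fx.utils import model_size_Wb

# A toy model stands in for a real network; any nn.Module works.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 10),
)

# Logs a per-dtype pivot table (layer count and summed memory) and returns
# the total weight/bias footprint in the requested unit.
total_mb = model_size_Wb(model, unit="MB")

# print_to_file=False routes the summary through print() instead of the
# logger; show_details=True additionally emits the full per-layer table
# (DataFrame.to_markdown() needs the optional tabulate package).
total_gb = model_size_Wb(model, unit="GB", print_to_file=False, show_details=True)

The return value is still the total W/b size as a float, but it is now derived by summing the per-layer memory column of the summary DataFrame rather than accumulating a running total, and the per-dtype breakdown replaces the old hard-coded INT8/FP32 counters.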
