# memory_estimator_moe.py
# Forked from rasbt/LLMs-from-scratch.
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
import argparse
# Size in bytes of one stored element for each supported dtype name.
# The insertion order is kept as-is because the keys are also offered as the
# choices of the --dtype command-line option.
DTYPE_BYTES = dict(
    fp32=4,
    bf16=2,
    fp16=2,
    fp8=1,
    int8=1,
)
def bytes_convert(n):
    """Format a byte count as decimal gigabytes.

    Args:
        n: Number of bytes.

    Returns:
        str: ``n`` expressed in GB (1 GB = 1000**3 bytes) with thousands
        separators and two decimal places, e.g. ``"1.50 GB"``.
    """
    gb = n / (1000 ** 3)
    return f"{gb:,.2f} GB"
def get_num_param_matrices(ffn_type):
    """Return how many weight matrices are counted for one feed-forward block.

    Args:
        ffn_type: Either ``"gelu"`` or ``"swiglu"``.

    Returns:
        int: 2 for ``"gelu"``, 3 for ``"swiglu"``.

    Raises:
        ValueError: If ``ffn_type`` is neither of the supported values.
    """
    if ffn_type == "gelu":
        return 2
    if ffn_type == "swiglu":
        return 3
    raise ValueError("--ffn_type must be 'gelu' or 'swiglu'")
def ffn_params(emb_dim, hidden_dim, ffn_type):
    """Count the weight parameters of a single feed-forward block.

    The count is ``num_matrices * emb_dim * hidden_dim``, where
    ``num_matrices`` comes from ``get_num_param_matrices(ffn_type)``.
    Nothing else (e.g. bias terms) is added by this formula.

    Args:
        emb_dim: Model embedding dimension.
        hidden_dim: Intermediate (hidden) size of the feed-forward block.
        ffn_type: ``"gelu"`` or ``"swiglu"``.

    Returns:
        int: Number of parameters.

    Raises:
        ValueError: Propagated from ``get_num_param_matrices`` for an
            unsupported ``ffn_type``.
    """
    return get_num_param_matrices(ffn_type) * emb_dim * hidden_dim
def router_params(emb_dim, num_experts):
    """Count the router parameters of an MoE layer.

    The router is counted as ``emb_dim * num_experts`` parameters, i.e. one
    weight per embedding dimension per expert.

    Args:
        emb_dim: Model embedding dimension.
        num_experts: Number of experts the router chooses between.

    Returns:
        int: Number of router parameters.
    """
    return emb_dim * num_experts
def estimate_params_and_hidden(
    emb_dim, hidden_dim, ffn_type, num_experts, match_dense=False
):
    """Estimate parameter counts for a dense FFN and an MoE replacement.

    Args:
        emb_dim: Model embedding dimension.
        hidden_dim: Hidden size of the dense feed-forward block.
        ffn_type: ``"gelu"`` or ``"swiglu"``.
        num_experts: Number of experts in the MoE layer.
        match_dense: If true, choose the per-expert hidden size so that
            ``num_experts`` experts plus the router hold approximately as
            many parameters as the dense block. Otherwise every expert uses
            ``hidden_dim`` unchanged.

    Returns:
        dict: With the keys
            ``"dense_params"`` (dense FFN parameter count),
            ``"router"`` (router parameter count),
            ``"moe_hidden_dim"`` (hidden size used per expert),
            ``"per_expert_params"`` (parameters of one expert) and
            ``"moe_total"`` (all experts plus router).

    Raises:
        ValueError: For an unsupported ``ffn_type``; or, when ``match_dense``
            is set, if ``num_experts`` is not positive or the dense block is
            too small to give every expert at least one hidden unit.
    """
    dense_params = ffn_params(emb_dim, hidden_dim, ffn_type)
    router = router_params(emb_dim, num_experts)

    if match_dense:
        # Guard first: zero experts would otherwise divide by zero below and
        # a negative count would yield a negative hidden size.
        if num_experts <= 0:
            raise ValueError("num_experts must be a positive integer.")

        # Parameters left for the experts once the router is paid for.
        expert_budget = dense_params - router
        if expert_budget <= 0:
            raise ValueError("Dense layer too small for requested num_experts.")

        # Parameters added across all experts by one extra hidden unit.
        params_per_hidden_unit = (
            num_experts * get_num_param_matrices(ffn_type) * emb_dim
        )
        moe_hidden_dim = int(round(expert_budget / params_per_hidden_unit))

        # A budget smaller than half a hidden unit per expert rounds to zero;
        # that is the same "too small" condition, not a usable layer.
        if moe_hidden_dim < 1:
            raise ValueError("Dense layer too small for requested num_experts.")
    else:
        moe_hidden_dim = hidden_dim

    per_expert_params = ffn_params(emb_dim, moe_hidden_dim, ffn_type)
    moe_total = num_experts * per_expert_params + router

    return {
        "dense_params": dense_params,
        "router": router,
        "moe_hidden_dim": moe_hidden_dim,
        "per_expert_params": per_expert_params,
        "moe_total": moe_total,
    }
def main(argv=None):
    """Parse command-line options and print a dense-FFN vs. MoE memory report.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read the process command line.

    The report lists the chosen configuration followed by parameter counts
    and their memory footprint (count * bytes per element of ``--dtype``)
    for the dense FFN, one expert, the router, the whole MoE layer, and the
    parameters active per token (router + ``top_k`` experts).

    Invalid ``--num_experts`` / ``--top_k`` combinations are rejected through
    ``ArgumentParser.error``, which prints a usage message and exits.
    """
    p = argparse.ArgumentParser(
        description="Estimate FFN vs MoE parameter memory"
    )
    p.add_argument("--emb_dim", type=int, required=True,
                   help="Model embedding dimension.")
    p.add_argument("--hidden_dim", type=int, required=True,
                   help="Dense FFN intermediate size (hidden dimension).")
    p.add_argument("--ffn_type", choices=["gelu", "swiglu"], default="swiglu")
    p.add_argument("--num_experts", type=int, default=8)
    p.add_argument("--top_k", type=int, default=2)
    p.add_argument("--dtype", choices=list(DTYPE_BYTES), default="bf16")
    p.add_argument(
        "--match_dense",
        action="store_true",
        help=("Auto-set per-expert hidden so MoE total params ~= dense FFN params "
              "(router included)."),
    )
    args = p.parse_args(argv)

    # Reject configurations that cannot describe a real MoE layer; without
    # this, top_k > num_experts would report more active than total params.
    if args.num_experts < 1:
        p.error("--num_experts must be >= 1")
    if not 1 <= args.top_k <= args.num_experts:
        p.error("--top_k must be between 1 and --num_experts")

    bytes_per_elem = DTYPE_BYTES[args.dtype]

    res = estimate_params_and_hidden(
        emb_dim=args.emb_dim,
        hidden_dim=args.hidden_dim,
        ffn_type=args.ffn_type,
        num_experts=args.num_experts,
        match_dense=args.match_dense,
    )

    # Per token, the router plus the top_k selected experts are used.
    moe_active_params_per_token = (
        res["router"] + args.top_k * res["per_expert_params"]
    )

    print("==== Config ====")
    print(f"{'emb_dim':23}: {args.emb_dim}")
    print(f"{'hidden_dim':23}: {args.hidden_dim}")
    print(f"{'ffn_type':23}: {args.ffn_type}")
    print(f"{'num_experts':23}: {args.num_experts}")
    print(f"{'top_k':23}: {args.top_k}")
    print(f"{'dtype':23}: {args.dtype} ({bytes_per_elem} Bytes/elem)")
    print(f"{'match_dense':23}: {args.match_dense}")
    print()

    print("==== Model weights (parameters) ====")
    report_rows = [
        ("Dense FFN params", res["dense_params"]),
        ("Per-expert params", res["per_expert_params"]),
        ("Router params", res["router"]),
        ("MoE TOTAL params", res["moe_total"]),
        ("MoE ACTIVE/Token", moe_active_params_per_token),
    ]
    for label, count in report_rows:
        print(f"{label:23}: {count:,} "
              f"({bytes_convert(count * bytes_per_elem)})")
    print(f"{'moe_hidden_dim':23}: {res['moe_hidden_dim']}")
    print()


if __name__ == "__main__":
    main()