
Commit 1db5cd3

ananthsub and ko3n1g authored
cp: Cherry pick NVIDIA#2114 into core_dev_r0.15.0 (NVIDIA#2150)
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
Co-authored-by: oliver könig <okoenig@nvidia.com>
1 parent ae99e0a · commit 1db5cd3

File tree

3 files changed: +55, -20 lines


megatron/core/optimizer/distrib_optimizer.py

Lines changed: 2 additions & 1 deletion
@@ -1153,7 +1153,8 @@ def _param_name(self, param: torch.nn.Parameter) -> str:
                 "Ensure that each model chunk has unique parameter names."
             )
             name_to_param.update(_name_to_param)
-        name_to_param = handle_experts_in_state_dict(name_to_param)
+        num_experts = self.model_chunks[0].config.num_moe_experts if self.model_chunks else None
+        name_to_param = handle_experts_in_state_dict(name_to_param, num_experts)
         self.param_to_name = {param: name for name, param in name_to_param.items()}
         assert (
             param in self.param_to_name
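
The change above stops looking the expert count up through the global training args: num_experts is read from the first model chunk's config (num_moe_experts) and passed explicitly into handle_experts_in_state_dict, which is what lets the next file drop its megatron.training.global_vars import. A minimal, self-contained sketch of that pattern, using hypothetical stand-ins (_Config, _ModelChunk, _handle_experts) rather than real Megatron classes:

# A minimal sketch of the pattern above, with hypothetical stand-ins: num_experts comes from
# the first model chunk's config and is passed to the state-dict helper explicitly.
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class _Config:  # stand-in for the model chunk's TransformerConfig
    num_moe_experts: Optional[int] = None


@dataclass
class _ModelChunk:  # stand-in for a Megatron model chunk
    config: _Config


def _handle_experts(name_to_param: Dict[str, object], num_experts: Optional[int]) -> Dict[str, object]:
    # Stand-in for handle_experts_in_state_dict; the real helper rewrites expert keys.
    return name_to_param


model_chunks = [_ModelChunk(config=_Config(num_moe_experts=8))]
num_experts = model_chunks[0].config.num_moe_experts if model_chunks else None
name_to_param = _handle_experts({"decoder.layers.0.mlp.experts.0.weight": None}, num_experts)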

megatron/core/transformer/fsdp_dtensor_checkpoint.py

Lines changed: 52 additions & 18 deletions
@@ -47,29 +47,35 @@
 from megatron.core.transformer.transformer_layer import TransformerLayer


-def get_ep_layer_offset():
+def get_ep_layer_offset(num_experts: int | None = None) -> int:
     """
     Get the expert layer offset for the current model.
-    """
-    from megatron.training.global_vars import get_args

-    args = get_args()
+    Args:
+        num_experts: Total number of experts in the model. If None, returns 0.
+
+    Returns:
+        The expert layer offset for the current EP rank.
+    """
     ep_size = parallel_state.get_expert_model_parallel_world_size()
     ep_rank = parallel_state.get_expert_model_parallel_rank()
-    num_local_experts = args.num_experts // ep_size if args.num_experts else 0
+    num_local_experts = num_experts // ep_size if num_experts else 0
     local_expert_offset = ep_rank * num_local_experts

     return local_expert_offset


-def get_total_num_experts():
+def get_total_num_experts(num_experts: int | None = None) -> int:
     """
     Get the total number of experts for the current model.
-    """
-    from megatron.training.global_vars import get_args

-    args = get_args()
-    return args.num_experts if args.num_experts else 0
+    Args:
+        num_experts: Total number of experts in the model. If None, returns 0.
+
+    Returns:
+        The total number of experts.
+    """
+    return num_experts if num_experts else 0


 def get_expert_index_from_key(key):
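
With the new signatures, the offset is a pure function of num_experts plus the expert-parallel rank and world size. A standalone sketch of that arithmetic, with ep_size and ep_rank passed in directly so it runs without a distributed setup (in the real function they come from parallel_state):

# Standalone sketch of the offset arithmetic in get_ep_layer_offset; ep_size and ep_rank are
# plain arguments here instead of being read from parallel_state, so the example runs on its own.
from typing import Optional


def ep_layer_offset(num_experts: Optional[int], ep_size: int, ep_rank: int) -> int:
    num_local_experts = num_experts // ep_size if num_experts else 0
    return ep_rank * num_local_experts


# 8 experts over 4 expert-parallel ranks -> 2 local experts per rank; rank 2 owns global
# experts 4 and 5, so its offset is 4. Without MoE (num_experts=None) the offset is 0.
assert ep_layer_offset(8, ep_size=4, ep_rank=2) == 4
assert ep_layer_offset(None, ep_size=4, ep_rank=2) == 0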
@@ -96,12 +102,19 @@ def get_expert_index_from_key(key):
     return None


-def handle_experts_in_state_dict(state_dict):
+def handle_experts_in_state_dict(state_dict, num_experts: int | None = None):
     """
     Rewrite expert keys in state dict.
+
+    Args:
+        state_dict: The state dictionary to process.
+        num_experts: Total number of experts in the model. If None, no expert processing occurs.
+
+    Returns:
+        The processed state dictionary with rewritten expert keys.
     """
-    local_expert_start = get_ep_layer_offset()
-    local_expert_end = get_total_num_experts()
+    local_expert_start = get_ep_layer_offset(num_experts)
+    local_expert_end = get_total_num_experts(num_experts)

     def should_keep_expert_key(expert_index):
         """Determine if this rank should keep this expert key based on expert index"""
@@ -147,9 +160,17 @@ def replace_expert_index_in_key(key, expert_index, state_dict):
     return state_dict


-def expert_param_local_key(key):
-    """Get the module parameter corresponding to the key."""
-    local_expert_offset = get_ep_layer_offset()
+def expert_param_local_key(key: str, num_experts: int | None = None) -> str:
+    """Get the module parameter corresponding to the key.
+
+    Args:
+        key: The parameter key to process.
+        num_experts: Total number of experts in the model. If None, no expert processing occurs.
+
+    Returns:
+        The local parameter key with adjusted expert indices.
+    """
+    local_expert_offset = get_ep_layer_offset(num_experts)
     expert_index = get_expert_index_from_key(key)
     if expert_index is not None:
         new_expert_index = expert_index - local_expert_offset
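
expert_param_local_key goes the other way, subtracting the local offset from the expert index found in the key so the parameter can be fetched from the local module. A companion sketch under the same illustrative key-pattern assumption as above:

# Hypothetical companion sketch: map a key carrying a global expert index back to the rank-local
# index used to look the parameter up on the module. Same illustrative key pattern as before.
import re


def localize_expert_key(key: str, local_expert_offset: int) -> str:
    match = re.search(r"experts\.(\d+)\.", key)
    if match is None:
        return key
    local_index = int(match.group(1)) - local_expert_offset
    return re.sub(r"experts\.(\d+)\.", f"experts.{local_index}.", key, count=1)


# On EP rank 1 of 2 with 4 experts (offset 2), global expert 3 is local expert 1.
print(localize_expert_key("mlp.experts.3.weight", local_expert_offset=2))  # mlp.experts.1.weight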
@@ -174,6 +195,9 @@ def handle_swiglu_in_state_dict(model, model_state_dict, optimizer_state_dict):
     """
     assert HAVE_MEGATRON_FSDP, "This function requires Megatron-FSDP to be installed."

+    # Extract num_experts from model config for expert parameter processing
+    num_experts = model.config.num_moe_experts if hasattr(model, 'config') else None
+
     def intersection(s1, s2):
         # Only works for step=1
         start = max(s1.start, s2.start)
@@ -297,7 +321,9 @@ def split_swiglu_linear_fc1(data, dist_param, swiglu_shard_axis, is_expert_param
             new_opt_state_dict[f"{key}_w"] = opt_state_dict[key].copy()
             new_opt_state_dict[f"{key}_v"] = opt_state_dict[key].copy()
             for subkey in ["exp_avg", "exp_avg_sq"]:
-                dist_param = model.get_parameter(expert_param_local_key(key[len("module.") :]))
+                dist_param = model.get_parameter(
+                    expert_param_local_key(key[len("module.") :], num_experts)
+                )
                 weight_w, weight_v = split_swiglu_linear_fc1(
                     opt_state_dict[key][subkey],
                     dist_param,
@@ -426,6 +452,13 @@ def validate_loaded_state_dict(state_dict, checkpoint_path):
 def get_global_unique_param_name(model_chunks, param):
     """
     Get the global unique parameter name for a given model and parameter.
+
+    Args:
+        model_chunks: List of model chunks to search for the parameter.
+        param: The parameter to find the name for.
+
+    Returns:
+        The global unique parameter name.
     """
     param_name = None
     for model in model_chunks:
@@ -450,6 +483,7 @@ def get_global_unique_param_name(model_chunks, param):
         param_name = re.sub(r"layers\.(\d+)", f"layers.{tf_layer_number - 1}", param_name)

     # Get EP unique parameter name
-    param_name = list(handle_experts_in_state_dict({param_name: None}).keys())[0]
+    num_experts = model_chunks[0].config.num_moe_experts if model_chunks else None
+    param_name = list(handle_experts_in_state_dict({param_name: None}, num_experts).keys())[0]

     return param_name
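
get_global_unique_param_name composes two renames: the re.sub in the context above maps the pipeline-local layer index to the global transformer-layer number, and the single-entry-dict call to handle_experts_in_state_dict then makes the expert index EP-unique, now with num_experts taken from the model chunk config rather than global args. A self-contained sketch of that composition; the underscore-prefixed helper, the example key, the layer number, and the expert offset of 2 are all hypothetical:

# Hypothetical sketch of the two-step renaming: layer index first, expert index second. The
# expert rewrite reuses a dict-keyed helper on a single-entry dict, the same shape as the
# handle_experts_in_state_dict({param_name: None}, num_experts) call above.
import re


def _handle_experts_in_state_dict(state_dict, num_experts, ep_offset=2):
    # Stand-in: shift local expert indices by an assumed rank offset of 2.
    return {
        re.sub(r"experts\.(\d+)", lambda m: f"experts.{int(m.group(1)) + ep_offset}", key): value
        for key, value in state_dict.items()
    }


param_name = "decoder.layers.1.mlp.experts.0.weight"  # pipeline-local layer 1, local expert 0
tf_layer_number = 6  # assumed 1-based global transformer-layer number
param_name = re.sub(r"layers\.(\d+)", f"layers.{tf_layer_number - 1}", param_name)
param_name = list(_handle_experts_in_state_dict({param_name: None}, num_experts=8).keys())[0]
print(param_name)  # decoder.layers.5.mlp.experts.2.weight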

megatron/training/checkpointing.py

Lines changed: 1 addition & 1 deletion
@@ -870,7 +870,7 @@ def preprocess_fsdp_dtensor_state_dict(args, raw_state_dict, model):
     )
     state_dict["model"] = model_state_dict
     if args.num_experts:
-        state_dict["model"] = handle_experts_in_state_dict(state_dict["model"])
+        state_dict["model"] = handle_experts_in_state_dict(state_dict["model"], args.num_experts)
     preprocess_state_dict_for_uneven_dtensor(state_dict)

     return state_dict
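
At this call site the expert count is already on the parsed training args, so it is simply forwarded; after this commit the core helper no longer has to import megatron.training.global_vars to find it. A minimal sketch of the call shape, with a stand-in for the helper so the snippet runs on its own:

# Minimal sketch of the call site: num_experts is forwarded from the training args instead of
# being re-read via get_args() inside the helper. The helper below is a no-op stand-in.
from argparse import Namespace


def handle_experts_in_state_dict(state_dict, num_experts=None):
    # Stand-in for megatron.core.transformer.fsdp_dtensor_checkpoint.handle_experts_in_state_dict
    return state_dict


args = Namespace(num_experts=8)
state_dict = {"model": {"decoder.layers.0.mlp.experts.0.weight": None}}
if args.num_experts:
    state_dict["model"] = handle_experts_in_state_dict(state_dict["model"], args.num_experts)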
