 from megatron.core.transformer.transformer_layer import TransformerLayer


-def get_ep_layer_offset():
+def get_ep_layer_offset(num_experts: int | None = None) -> int:
     """
     Get the expert layer offset for the current model.
-    """
-    from megatron.training.global_vars import get_args

-    args = get_args()
+    Args:
+        num_experts: Total number of experts in the model. If None, returns 0.
+
+    Returns:
+        The expert layer offset for the current EP rank.
+    """
     ep_size = parallel_state.get_expert_model_parallel_world_size()
     ep_rank = parallel_state.get_expert_model_parallel_rank()
-    num_local_experts = args.num_experts // ep_size if args.num_experts else 0
+    num_local_experts = num_experts // ep_size if num_experts else 0
     local_expert_offset = ep_rank * num_local_experts

     return local_expert_offset


-def get_total_num_experts():
+def get_total_num_experts(num_experts: int | None = None) -> int:
     """
     Get the total number of experts for the current model.
-    """
-    from megatron.training.global_vars import get_args

-    args = get_args()
-    return args.num_experts if args.num_experts else 0
+    Args:
+        num_experts: Total number of experts in the model. If None, returns 0.
+
+    Returns:
+        The total number of experts.
+    """
+    return num_experts if num_experts else 0


 def get_expert_index_from_key(key):
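The offset arithmetic above splits the global expert list evenly across expert-parallel ranks. A minimal worked example with hypothetical values (not taken from this PR), runnable on its own:

# Hypothetical MoE setup: 8 experts sharded over an expert-parallel world size of 4, viewed from EP rank 2.
num_experts, ep_size, ep_rank = 8, 4, 2
num_local_experts = num_experts // ep_size if num_experts else 0  # -> 2 experts per rank
local_expert_offset = ep_rank * num_local_experts                 # -> 4, i.e. rank 2 owns global experts 4 and 5
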
@@ -96,12 +102,19 @@ def get_expert_index_from_key(key):
     return None


-def handle_experts_in_state_dict(state_dict):
+def handle_experts_in_state_dict(state_dict, num_experts: int | None = None):
     """
     Rewrite expert keys in state dict.
+
+    Args:
+        state_dict: The state dictionary to process.
+        num_experts: Total number of experts in the model. If None, no expert processing occurs.
+
+    Returns:
+        The processed state dictionary with rewritten expert keys.
     """
-    local_expert_start = get_ep_layer_offset()
-    local_expert_end = get_total_num_experts()
+    local_expert_start = get_ep_layer_offset(num_experts)
+    local_expert_end = get_total_num_experts(num_experts)

     def should_keep_expert_key(expert_index):
         """Determine if this rank should keep this expert key based on expert index"""
@@ -147,9 +160,17 @@ def replace_expert_index_in_key(key, expert_index, state_dict):
     return state_dict


-def expert_param_local_key(key):
-    """Get the module parameter corresponding to the key."""
-    local_expert_offset = get_ep_layer_offset()
+def expert_param_local_key(key: str, num_experts: int | None = None) -> str:
+    """Get the module parameter corresponding to the key.
+
+    Args:
+        key: The parameter key to process.
+        num_experts: Total number of experts in the model. If None, no expert processing occurs.
+
+    Returns:
+        The local parameter key with adjusted expert indices.
+    """
+    local_expert_offset = get_ep_layer_offset(num_experts)
     expert_index = get_expert_index_from_key(key)
     if expert_index is not None:
         new_expert_index = expert_index - local_expert_offset
@@ -174,6 +195,9 @@ def handle_swiglu_in_state_dict(model, model_state_dict, optimizer_state_dict):
174195 """
175196 assert HAVE_MEGATRON_FSDP , "This function requires Megatron-FSDP to be installed."
176197
198+ # Extract num_experts from model config for expert parameter processing
199+ num_experts = model .config .num_moe_experts if hasattr (model , 'config' ) else None
200+
177201 def intersection (s1 , s2 ):
178202 # Only works for step=1
179203 start = max (s1 .start , s2 .start )
@@ -297,7 +321,9 @@ def split_swiglu_linear_fc1(data, dist_param, swiglu_shard_axis, is_expert_param
             new_opt_state_dict[f"{key}_w"] = opt_state_dict[key].copy()
             new_opt_state_dict[f"{key}_v"] = opt_state_dict[key].copy()
             for subkey in ["exp_avg", "exp_avg_sq"]:
-                dist_param = model.get_parameter(expert_param_local_key(key[len("module.") :]))
+                dist_param = model.get_parameter(
+                    expert_param_local_key(key[len("module.") :], num_experts)
+                )
                 weight_w, weight_v = split_swiglu_linear_fc1(
                     opt_state_dict[key][subkey],
                     dist_param,
@@ -426,6 +452,13 @@ def validate_loaded_state_dict(state_dict, checkpoint_path):
 def get_global_unique_param_name(model_chunks, param):
     """
     Get the global unique parameter name for a given model and parameter.
+
+    Args:
+        model_chunks: List of model chunks to search for the parameter.
+        param: The parameter to find the name for.
+
+    Returns:
+        The global unique parameter name.
     """
     param_name = None
     for model in model_chunks:
@@ -450,6 +483,7 @@ def get_global_unique_param_name(model_chunks, param):
         param_name = re.sub(r"layers\.(\d+)", f"layers.{tf_layer_number - 1}", param_name)

     # Get EP unique parameter name
-    param_name = list(handle_experts_in_state_dict({param_name: None}).keys())[0]
+    num_experts = model_chunks[0].config.num_moe_experts if model_chunks else None
+    param_name = list(handle_experts_in_state_dict({param_name: None}, num_experts).keys())[0]

     return param_name
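A minimal usage sketch of the refactored helpers, assuming a hypothetical model and state dict; num_experts is now threaded in explicitly (taken from model.config.num_moe_experts in this PR) rather than read from megatron.training.global_vars.get_args():

# Sketch only; `model` and `state_dict` are hypothetical stand-ins, not objects defined in this PR.
num_experts = model.config.num_moe_experts if hasattr(model, 'config') else None
offset = get_ep_layer_offset(num_experts)           # first global expert index owned by this EP rank
total = get_total_num_experts(num_experts)          # 0 when the model has no MoE experts
state_dict = handle_experts_in_state_dict(state_dict, num_experts)  # rewrite expert keys for this rank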