1111from megatron .core .transformer .moe .moe_utils import get_align_size_for_quantization
1212from megatron .core .transformer .moe .experts import TEGroupedMLP
1313from megatron .core .transformer .moe .paged_stash import (
14+ check_paged_stash_overflow ,
1415 paged_stash_init_chunk_handler ,
1516 paged_stash_reset ,
1617)
1920from tests .unit_tests .test_utilities import Utils
2021
2122
23+ def _global_tokens_per_expert_from_local_routing_map (routing_map : torch .Tensor ) -> torch .Tensor :
24+ """Per-expert token counts from a local routing map, summed across the default process group.
25+
26+ ``routing_map`` is shaped [num_local_token_rows, num_experts] (as in
27+ ``_HybridEPManager``). Tests here assume world size equals expert-parallel size (all GPUs
28+ are EP ranks); ``all_reduce`` on the world group aggregates disjoint local maps.
29+ """
30+ counts = routing_map .sum (dim = 0 ).to (torch .int64 )
31+ if torch .distributed .is_initialized () and torch .distributed .get_world_size () > 1 :
32+ torch .distributed .all_reduce (counts , op = torch .distributed .ReduceOp .SUM )
33+ return counts
34+
35+
def _tokens_per_expert_from_routing_map(routing_map: torch.Tensor, layer: MoELayer) -> torch.Tensor:
    """Select this EP rank's per-local-expert counts out of the global per-expert counts.

    Reduces ``routing_map`` to global per-expert counts, then gathers the columns named by
    ``layer.local_expert_indices``. The gather already copies, but ``clone()`` is kept so the
    result is guaranteed independent of the reduced counts tensor.
    """
    global_counts = _global_tokens_per_expert_from_local_routing_map(routing_map)
    local_indices = torch.as_tensor(
        layer.local_expert_indices, device=global_counts.device, dtype=torch.long
    )
    selected = global_counts[local_indices]
    return selected.to(torch.int64).clone()
41+
42+
43+ def _pad_token_counts_to_align_size (
44+ tokens_per_expert : torch .Tensor , pad_multiple : int
45+ ) -> torch .Tensor :
46+ """Round each count up to a multiple of ``pad_multiple`` (``n + (-n % m)`` like budget)."""
47+ t = tokens_per_expert .to (torch .int64 )
48+ return t + (- t % pad_multiple )
49+
50+
2251class MoEModelTestContainer :
2352 def __init__ (
2453 self ,
@@ -92,12 +121,19 @@ def __init__(
92121 moe_router_padding_for_fp8 = kwargs .get ("moe_router_padding_for_fp8" , True ),
93122 use_transformer_engine_op_fuser = kwargs .get ("use_transformer_engine_op_fuser" , False ),
94123 moe_mlp_glu_interleave_size = kwargs .get ("moe_mlp_glu_interleave_size" , None ),
95- moe_router_padding_for_quantization = kwargs .get ("moe_router_padding_for_quantization" , False ),
124+ moe_router_padding_for_quantization = kwargs .get (
125+ "moe_router_padding_for_quantization" , False
126+ ),
96127 gated_linear_unit = kwargs .get ("gated_linear_unit" , False ),
97128 activation_func = kwargs .get ("activation_func" , F .gelu ),
98129 moe_router_force_biased = kwargs .get ("moe_router_force_biased" , None ),
130+ stash_buffer_size_factor_cuda = 0.5 ,
131+ stash_buffer_size_factor_cpu = 1.5 ,
99132 )
100- self .moe_layer = self ._create_moe_layer (layer_number = 0 )
133+ self .moe_layers = [
134+ self ._create_moe_layer (layer_number = i ) for i in range (num_layers )
135+ ]
136+ self .moe_layer = self .moe_layers [0 ]
101137
102138 def _create_moe_layer (self , layer_number = 0 ):
103139 transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec (
@@ -114,43 +150,44 @@ def _create_moe_layer(self, layer_number=0):
114150 return moe_layer
115151
116152 def zero_grad (self ):
117- self .moe_layer .zero_grad ()
153+ for layer in self .moe_layers :
154+ layer .zero_grad ()
118155
119156 def __del__ (self ):
120157 torch .distributed .barrier ()
121158 torch .cuda .synchronize ()
122159 Utils .destroy_model_parallel ()
123160
124- def forward_backward (self , hidden_states ):
125- """Run one forward and backward pass through the MoE layer.
126-
127- Returns:
128- output: MoE layer output (detached).
129- hidden_states_grad: Gradient w.r.t. hidden_states.
130- routing_map: Token-to-expert routing map from the dispatcher (after forward).
131- tokens_per_expert: Number of tokens per local expert on this EP rank (after forward).
132- """
133- hidden_states = hidden_states .cuda ().requires_grad_ (True )
134- quantization_context = get_fp8_context (self .config )
135- with quantization_context :
136- output , _ = self .moe_layer (hidden_states )
137- # Capture routing_map and tokens_per_expert after forward (before backward)
138- comm = getattr (self .moe_layer .token_dispatcher , "_comm_manager" , None )
139- routing_map = getattr (comm , "routing_map" , None )
140- tokens_per_expert = (
141- comm .get_number_of_tokens_per_expert ()
142- if comm is not None and hasattr (comm , "get_number_of_tokens_per_expert" )
143- else None
144- )
145- # Use contiguous gradient to avoid non-contiguous grad in HybridEP combine backward
146- # (output.sum().backward() produces a broadcast gradient that is non-contiguous)
147- output .backward (torch .ones_like (output ))
148- return output .detach (), hidden_states .grad , routing_map , tokens_per_expert
149-
150161 def destroy (self ):
151162 Utils .destroy_model_parallel ()
152163
153164
def _forward_backward_all_layers(container: MoEModelTestContainer, hidden_states: torch.Tensor):
    """Run every MoE layer forward, then backward; report output, input grad, routing state.

    The input is moved to CUDA with gradients enabled and threaded through each layer in
    ``container.moe_layers`` inside the FP8/quantization context. Routing state
    (``routing_map`` and per-expert token counts) is captured from the LAST layer's
    dispatcher comm manager after forward, before backward. A contiguous all-ones gradient
    drives backward (a broadcast gradient from ``sum()`` can be non-contiguous for the
    HybridEP combine backward).

    Returns:
        Tuple of (detached final output, gradient w.r.t. the original input, last-layer
        routing map or None, last-layer tokens-per-expert or None).
    """
    model_input = hidden_states.cuda().requires_grad_(True)
    activation = model_input
    with get_fp8_context(container.config):
        for moe_layer in container.moe_layers:
            activation, _ = moe_layer(activation)
    final_output = activation

    # Read routing state after forward, before backward may release dispatcher buffers.
    comm_manager = getattr(container.moe_layers[-1].token_dispatcher, "_comm_manager", None)
    routing_map = getattr(comm_manager, "routing_map", None)
    if comm_manager is not None and hasattr(comm_manager, "get_number_of_tokens_per_expert"):
        tokens_per_expert = comm_manager.get_number_of_tokens_per_expert()
    else:
        tokens_per_expert = None

    final_output.backward(torch.ones_like(final_output))
    return final_output.detach(), model_input.grad, routing_map, tokens_per_expert
189+
190+
def is_hybrid_ep_available():
    """Report whether the HybridEP fused all-to-all backend is importable and enabled."""
    from megatron.core.transformer.moe import fused_a2a

    return fused_a2a.HAVE_HYBRIDEP
@@ -166,7 +203,8 @@ def teardown_method(self, method):
166203
167204 @pytest .mark .skipif (not torch .cuda .is_available (), reason = "CUDA not available" )
168205 @pytest .mark .internal
169- def test_forward_backward (self ):
206+ def test_forward_backward_4_layers (self ):
207+ """Test paged stashing with 4 MoE layers: ref run vs paged run match."""
170208 if not is_hybrid_ep_available ():
171209 pytest .skip ("Hybrid EP is not available" )
172210
@@ -177,7 +215,7 @@ def test_forward_backward(self):
177215 ep_size = 4 ,
178216 pp_size = 1 ,
179217 num_moe_experts = 8 ,
180- num_layers = 2 ,
218+ num_layers = 4 ,
181219 moe_router_topk = 2 ,
182220 moe_router_load_balancing_type = "aux_loss" ,
183221 moe_token_dispatcher_type = "flex" ,
@@ -197,11 +235,12 @@ def test_forward_backward(self):
197235 gated_linear_unit = True ,
198236 activation_func = F .silu ,
199237 )
200- if not isinstance (container .moe_layer .experts , TEGroupedMLP ) or not container .moe_layer .experts ._is_fused_impl_supported ():
238+ experts = container .moe_layer .experts
239+ fused_ok = isinstance (experts , TEGroupedMLP ) and experts ._is_fused_impl_supported ()
240+ if not fused_ok :
201241 container .destroy ()
202242 pytest .skip ("TEGroupedMLP fused impl not supported" )
203243
204- # [sequence_length, batch_size, hidden_size] for MoELayer.forward
205244 seq_length = 1024
206245 batch_size = 1
207246 hidden_size = container .config .hidden_size
@@ -210,32 +249,42 @@ def test_forward_backward(self):
210249 )
211250
212251 # First iteration: capture schedule, capacity, etc.
213- paged_stash_reset (True )
252+ paged_stash_reset (True , config = container . config )
214253 paged_stash_init_chunk_handler (1 , 0 )
215254 output_ref , hidden_states_grad_ref , routing_map_ref , tokens_per_expert_ref = (
216- container . forward_backward ( hidden_states )
255+ _forward_backward_all_layers ( container , hidden_states )
217256 )
218257
219258 container .zero_grad ()
220259
221260 # Second iteration: run with paged stash.
222- paged_stash_reset (True )
261+ paged_stash_reset (True , config = container . config )
223262 paged_stash_init_chunk_handler (1 , 0 )
224- output , hidden_states_grad , routing_map , tokens_per_expert = container . forward_backward (
225- hidden_states
263+ output , hidden_states_grad , routing_map , tokens_per_expert = _forward_backward_all_layers (
264+ container , hidden_states
226265 )
227266
228- # Verify output and input gradient match the first iteration.
229- torch .testing .assert_close (output , output_ref , atol = 1e-4 , rtol = 1e-4 )
230- torch .testing .assert_close (
231- hidden_states_grad , hidden_states_grad_ref , atol = 1e-4 , rtol = 1e-4
267+ overflow = check_paged_stash_overflow ()
268+ assert overflow .any ().item () == 0
269+
270+ assert torch .allclose (output , output_ref , atol = 1e-4 , rtol = 1e-4 ), (
271+ f"output != output_ref: max diff = { (output - output_ref ).abs ().max ().item ()} "
272+ )
273+ assert torch .allclose (hidden_states_grad , hidden_states_grad_ref , atol = 1e-4 , rtol = 1e-4 ), (
274+ f"hidden_states_grad != ref: max diff = "
275+ f"{ (hidden_states_grad - hidden_states_grad_ref ).abs ().max ().item ()} "
232276 )
233- # Routing and token counts available after forward (e.g. for debugging or further checks)
234277 if routing_map is not None and tokens_per_expert is not None :
235278 num_tokens_per_ep_rank = tokens_per_expert .sum ().item ()
236- assert num_tokens_per_ep_rank > 0
279+ assert num_tokens_per_ep_rank > 0 , (
280+ f"num_tokens_per_ep_rank={ num_tokens_per_ep_rank } (expected > 0)"
281+ )
237282 assert routing_map_ref is not None and tokens_per_expert_ref is not None
238- torch .testing .assert_close (tokens_per_expert , tokens_per_expert_ref )
283+ tpe_f = tokens_per_expert .float ()
284+ ref_f = tokens_per_expert_ref .float ()
285+ assert torch .allclose (tpe_f , ref_f , atol = 1e-4 , rtol = 1e-4 ), (
286+ f"tokens_per_expert != ref: max diff = { (tpe_f - ref_f ).abs ().max ().item ()} "
287+ )
239288
240289
241290@pytest .mark .skipif (not is_hybrid_ep_available (), reason = "Hybrid EP are not available" )
@@ -249,8 +298,7 @@ def teardown_method(self, method):
249298 @pytest .mark .skipif (not torch .cuda .is_available (), reason = "CUDA not available" )
250299 @pytest .mark .internal
251300 def test_overload_factor_and_over_budget (self ):
252- """Test budget computation (same as token_dispatcher lines 1017-1025) and assert
253- over_budget flag is set when tokens_per_ep_rank exceeds budget."""
301+ """Budget matches HybridEP setup_metadata; over_budget matches map-derived load."""
254302 if not is_hybrid_ep_available ():
255303 pytest .skip ("Hybrid EP is not available" )
256304
@@ -261,8 +309,8 @@ def test_overload_factor_and_over_budget(self):
261309 ep_size = 4 ,
262310 pp_size = 1 ,
263311 num_moe_experts = 8 ,
264- num_layers = 1 ,
265- moe_router_topk = 4 ,
312+ num_layers = 4 ,
313+ moe_router_topk = 2 ,
266314 moe_router_load_balancing_type = "aux_loss" ,
267315 moe_token_dispatcher_type = "flex" ,
268316 moe_permute_fusion = True ,
@@ -274,50 +322,84 @@ def test_overload_factor_and_over_budget(self):
274322 moe_use_legacy_grouped_gemm = False ,
275323 moe_paged_stash = True ,
276324 stash_modules = ["expert_fc1" , "moe_act" , "expert_fc2" ],
277- moe_expert_rank_capacity_factor = 1.0 ,
325+ moe_expert_rank_capacity_factor = 1.5 ,
278326 use_transformer_engine_op_fuser = True ,
279327 moe_mlp_glu_interleave_size = 32 ,
280328 moe_router_padding_for_quantization = True ,
281329 gated_linear_unit = True ,
282330 activation_func = F .silu ,
283331 moe_router_force_biased = 1 ,
284332 )
285- if not isinstance (container .moe_layer .experts , TEGroupedMLP ) or not container .moe_layer .experts ._is_fused_impl_supported ():
333+ experts = container .moe_layer .experts
334+ fused_ok = isinstance (experts , TEGroupedMLP ) and experts ._is_fused_impl_supported ()
335+ if not fused_ok :
286336 container .destroy ()
287337 pytest .skip ("TEGroupedMLP fused impl not supported" )
288338
289- seq_length = 4096
339+ seq_length = 1024
290340 batch_size = 1
291341 topk = container .config .moe_router_topk
292342 capacity_factor = container .config .moe_expert_rank_capacity_factor
293- hidden_size = container .config .hidden_size
294343 hidden_states = torch .randn (
295- (seq_length , batch_size , hidden_size ), dtype = torch .bfloat16
344+ (seq_length , batch_size , container . config . hidden_size ), dtype = torch .bfloat16
296345 )
297346
298- # Budget computed like token_dispatcher._HybridEPManager.setup_metadata (lines 1017-1025)
299- num_tokens = seq_length * batch_size
347+ num_tokens = seq_length * batch_size * topk
300348 pad_multiple = get_align_size_for_quantization (container .config )
301- budget = int (num_tokens * topk * capacity_factor )
349+ budget = int (num_tokens * capacity_factor )
302350 budget += - budget % pad_multiple
303351
304- paged_stash_reset (True )
352+ paged_stash_reset (True , config = container . config )
305353 paged_stash_init_chunk_handler (1 , 0 )
306- _ , _ , _ , tokens_per_expert = container .forward_backward (hidden_states )
307-
308- assert tokens_per_expert is not None
309- tokens_per_ep_rank = tokens_per_expert .sum ().item ()
310- over_budget_tensor = container .moe_layer .token_dispatcher .check_over_budget ()
311- over_budget = over_budget_tensor .item () if over_budget_tensor is not None else False
312-
313- # When tokens_per_ep_rank > budget, over_budget flag must be raised
314- if tokens_per_ep_rank >= budget :
315- assert over_budget , (
316- f"tokens_per_ep_rank ({ tokens_per_ep_rank } ) > budget ({ budget } ), "
317- "but over_budget flag was not set"
354+ _forward_backward_all_layers (container , hidden_states )
355+
356+ overflow = check_paged_stash_overflow ()
357+ num_layers = len (container .moe_layers )
358+ stash_cuda = container .config .stash_buffer_size_factor_cuda
359+ stash_cpu = container .config .stash_buffer_size_factor_cpu
360+ stash_buffer_size = num_tokens * num_layers * (stash_cuda + stash_cpu )
361+
362+ total_tokens = 0
363+ for layer_idx , layer in enumerate (container .moe_layers ):
364+ comm = getattr (layer .token_dispatcher , "_comm_manager" , None )
365+ routing_map = getattr (comm , "routing_map" , None ) if comm is not None else None
366+ over_budget_tensor = (
367+ layer .token_dispatcher .check_over_budget ()
368+ if hasattr (layer .token_dispatcher , "check_over_budget" )
369+ else None
318370 )
319- else :
320- assert not over_budget , (
321- f"tokens_per_ep_rank ({ tokens_per_ep_rank } ) <= budget ({ budget } ), "
322- "but over_budget flag was set"
371+ over_budget = over_budget_tensor .item () if over_budget_tensor is not None else False
372+
373+ assert routing_map is not None , f"layer { layer_idx } : routing_map is None"
374+ assert routing_map .dim () == 2 , f"layer { layer_idx } : expected 2D routing_map"
375+ assert routing_map .shape [1 ] == container .config .num_moe_experts , (
376+ f"layer { layer_idx } : routing_map has { routing_map .shape [1 ]} experts, "
377+ f"expected { container .config .num_moe_experts } "
378+ )
379+ tokens_per_expert_from_map = _tokens_per_expert_from_routing_map (routing_map , layer )
380+ tokens_per_expert_from_map_padded = _pad_token_counts_to_align_size (
381+ tokens_per_expert_from_map , pad_multiple
323382 )
383+ tokens_per_ep_rank_from_map = tokens_per_expert_from_map_padded .sum ().item ()
384+ total_tokens += tokens_per_ep_rank_from_map
385+
386+ # Padded map-derived tokens strictly over budget iff dispatcher reports over_budget
387+ if tokens_per_ep_rank_from_map > budget :
388+ assert over_budget , (
389+ f"layer { layer_idx } : tokens_per_ep_rank_from_map "
390+ f"({ tokens_per_ep_rank_from_map } ) > budget ({ budget } ), "
391+ f"but over_budget flag was not set"
392+ )
393+ else :
394+ assert not over_budget , (
395+ f"layer { layer_idx } : tokens_per_ep_rank_from_map "
396+ f"({ tokens_per_ep_rank_from_map } ) <= budget ({ budget } ), "
397+ f"but over_budget flag was set"
398+ )
399+
400+ overflow_set = overflow .any ().item ()
401+ stash_exceeded = total_tokens > stash_buffer_size
402+ assert overflow_set == stash_exceeded , (
403+ f"overflow { overflow_set } should match total_tokens > stash_buffer_size "
404+ f"({ total_tokens } > { stash_buffer_size } )"
405+ )
0 commit comments