Skip to content

Commit ba56947

Browse files
committed
Refactor the unit test for paged stashing
1 parent 741052b commit ba56947

File tree

1 file changed

+156
-74
lines changed

1 file changed

+156
-74
lines changed

tests/unit_tests/transformer/moe/test_paged_stashing.py

Lines changed: 156 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from megatron.core.transformer.moe.moe_utils import get_align_size_for_quantization
1212
from megatron.core.transformer.moe.experts import TEGroupedMLP
1313
from megatron.core.transformer.moe.paged_stash import (
14+
check_paged_stash_overflow,
1415
paged_stash_init_chunk_handler,
1516
paged_stash_reset,
1617
)
@@ -19,6 +20,34 @@
1920
from tests.unit_tests.test_utilities import Utils
2021

2122

23+
def _global_tokens_per_expert_from_local_routing_map(routing_map: torch.Tensor) -> torch.Tensor:
24+
"""Per-expert token counts from a local routing map, summed across the default process group.
25+
26+
``routing_map`` is shaped [num_local_token_rows, num_experts] (as in
27+
``_HybridEPManager``). Tests here assume world size equals expert-parallel size (all GPUs
28+
are EP ranks); ``all_reduce`` on the world group aggregates disjoint local maps.
29+
"""
30+
counts = routing_map.sum(dim=0).to(torch.int64)
31+
if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
32+
torch.distributed.all_reduce(counts, op=torch.distributed.ReduceOp.SUM)
33+
return counts
34+
35+
36+
def _tokens_per_expert_from_routing_map(routing_map: torch.Tensor, layer: MoELayer) -> torch.Tensor:
37+
"""Per-local-expert assignment counts from the routing map (columns for this EP rank)."""
38+
counts = _global_tokens_per_expert_from_local_routing_map(routing_map)
39+
idx = torch.as_tensor(layer.local_expert_indices, device=counts.device, dtype=torch.long)
40+
return counts[idx].to(torch.int64).clone()
41+
42+
43+
def _pad_token_counts_to_align_size(
44+
tokens_per_expert: torch.Tensor, pad_multiple: int
45+
) -> torch.Tensor:
46+
"""Round each count up to a multiple of ``pad_multiple`` (``n + (-n % m)`` like budget)."""
47+
t = tokens_per_expert.to(torch.int64)
48+
return t + (-t % pad_multiple)
49+
50+
2251
class MoEModelTestContainer:
2352
def __init__(
2453
self,
@@ -92,12 +121,19 @@ def __init__(
92121
moe_router_padding_for_fp8=kwargs.get("moe_router_padding_for_fp8", True),
93122
use_transformer_engine_op_fuser=kwargs.get("use_transformer_engine_op_fuser", False),
94123
moe_mlp_glu_interleave_size=kwargs.get("moe_mlp_glu_interleave_size", None),
95-
moe_router_padding_for_quantization=kwargs.get("moe_router_padding_for_quantization", False),
124+
moe_router_padding_for_quantization=kwargs.get(
125+
"moe_router_padding_for_quantization", False
126+
),
96127
gated_linear_unit=kwargs.get("gated_linear_unit", False),
97128
activation_func=kwargs.get("activation_func", F.gelu),
98129
moe_router_force_biased=kwargs.get("moe_router_force_biased", None),
130+
stash_buffer_size_factor_cuda=0.5,
131+
stash_buffer_size_factor_cpu=1.5,
99132
)
100-
self.moe_layer = self._create_moe_layer(layer_number=0)
133+
self.moe_layers = [
134+
self._create_moe_layer(layer_number=i) for i in range(num_layers)
135+
]
136+
self.moe_layer = self.moe_layers[0]
101137

102138
def _create_moe_layer(self, layer_number=0):
103139
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
@@ -114,43 +150,44 @@ def _create_moe_layer(self, layer_number=0):
114150
return moe_layer
115151

116152
def zero_grad(self):
117-
self.moe_layer.zero_grad()
153+
for layer in self.moe_layers:
154+
layer.zero_grad()
118155

119156
def __del__(self):
120157
torch.distributed.barrier()
121158
torch.cuda.synchronize()
122159
Utils.destroy_model_parallel()
123160

124-
def forward_backward(self, hidden_states):
125-
"""Run one forward and backward pass through the MoE layer.
126-
127-
Returns:
128-
output: MoE layer output (detached).
129-
hidden_states_grad: Gradient w.r.t. hidden_states.
130-
routing_map: Token-to-expert routing map from the dispatcher (after forward).
131-
tokens_per_expert: Number of tokens per local expert on this EP rank (after forward).
132-
"""
133-
hidden_states = hidden_states.cuda().requires_grad_(True)
134-
quantization_context = get_fp8_context(self.config)
135-
with quantization_context:
136-
output, _ = self.moe_layer(hidden_states)
137-
# Capture routing_map and tokens_per_expert after forward (before backward)
138-
comm = getattr(self.moe_layer.token_dispatcher, "_comm_manager", None)
139-
routing_map = getattr(comm, "routing_map", None)
140-
tokens_per_expert = (
141-
comm.get_number_of_tokens_per_expert()
142-
if comm is not None and hasattr(comm, "get_number_of_tokens_per_expert")
143-
else None
144-
)
145-
# Use contiguous gradient to avoid non-contiguous grad in HybridEP combine backward
146-
# (output.sum().backward() produces a broadcast gradient that is non-contiguous)
147-
output.backward(torch.ones_like(output))
148-
return output.detach(), hidden_states.grad, routing_map, tokens_per_expert
149-
150161
def destroy(self):
151162
Utils.destroy_model_parallel()
152163

153164

165+
def _forward_backward_all_layers(container: MoEModelTestContainer, hidden_states: torch.Tensor):
166+
"""Forward/backward all MoE layers; returns output, input grad, last layer routing state."""
167+
initial_hidden_states = hidden_states.cuda().requires_grad_(True)
168+
hidden_states = initial_hidden_states
169+
quantization_context = get_fp8_context(container.config)
170+
with quantization_context:
171+
for layer in container.moe_layers:
172+
hidden_states, _ = layer(hidden_states)
173+
output = hidden_states
174+
last_layer = container.moe_layers[-1]
175+
comm = getattr(last_layer.token_dispatcher, "_comm_manager", None)
176+
routing_map = getattr(comm, "routing_map", None)
177+
tokens_per_expert = (
178+
comm.get_number_of_tokens_per_expert()
179+
if comm is not None and hasattr(comm, "get_number_of_tokens_per_expert")
180+
else None
181+
)
182+
output.backward(torch.ones_like(output))
183+
return (
184+
output.detach(),
185+
initial_hidden_states.grad,
186+
routing_map,
187+
tokens_per_expert,
188+
)
189+
190+
154191
def is_hybrid_ep_available():
155192
from megatron.core.transformer.moe.fused_a2a import HAVE_HYBRIDEP
156193
return HAVE_HYBRIDEP
@@ -166,7 +203,8 @@ def teardown_method(self, method):
166203

167204
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
168205
@pytest.mark.internal
169-
def test_forward_backward(self):
206+
def test_forward_backward_4_layers(self):
207+
"""Test paged stashing with 4 MoE layers: ref run vs paged run match."""
170208
if not is_hybrid_ep_available():
171209
pytest.skip("Hybrid EP is not available")
172210

@@ -177,7 +215,7 @@ def test_forward_backward(self):
177215
ep_size=4,
178216
pp_size=1,
179217
num_moe_experts=8,
180-
num_layers=2,
218+
num_layers=4,
181219
moe_router_topk=2,
182220
moe_router_load_balancing_type="aux_loss",
183221
moe_token_dispatcher_type="flex",
@@ -197,11 +235,12 @@ def test_forward_backward(self):
197235
gated_linear_unit=True,
198236
activation_func=F.silu,
199237
)
200-
if not isinstance(container.moe_layer.experts, TEGroupedMLP) or not container.moe_layer.experts._is_fused_impl_supported():
238+
experts = container.moe_layer.experts
239+
fused_ok = isinstance(experts, TEGroupedMLP) and experts._is_fused_impl_supported()
240+
if not fused_ok:
201241
container.destroy()
202242
pytest.skip("TEGroupedMLP fused impl not supported")
203243

204-
# [sequence_length, batch_size, hidden_size] for MoELayer.forward
205244
seq_length = 1024
206245
batch_size = 1
207246
hidden_size = container.config.hidden_size
@@ -210,32 +249,42 @@ def test_forward_backward(self):
210249
)
211250

212251
# First iteration: capture schedule, capacity, etc.
213-
paged_stash_reset(True)
252+
paged_stash_reset(True, config=container.config)
214253
paged_stash_init_chunk_handler(1, 0)
215254
output_ref, hidden_states_grad_ref, routing_map_ref, tokens_per_expert_ref = (
216-
container.forward_backward(hidden_states)
255+
_forward_backward_all_layers(container, hidden_states)
217256
)
218257

219258
container.zero_grad()
220259

221260
# Second iteration: run with paged stash.
222-
paged_stash_reset(True)
261+
paged_stash_reset(True, config=container.config)
223262
paged_stash_init_chunk_handler(1, 0)
224-
output, hidden_states_grad, routing_map, tokens_per_expert = container.forward_backward(
225-
hidden_states
263+
output, hidden_states_grad, routing_map, tokens_per_expert = _forward_backward_all_layers(
264+
container, hidden_states
226265
)
227266

228-
# Verify output and input gradient match the first iteration.
229-
torch.testing.assert_close(output, output_ref, atol=1e-4, rtol=1e-4)
230-
torch.testing.assert_close(
231-
hidden_states_grad, hidden_states_grad_ref, atol=1e-4, rtol=1e-4
267+
overflow = check_paged_stash_overflow()
268+
assert overflow.any().item() == 0
269+
270+
assert torch.allclose(output, output_ref, atol=1e-4, rtol=1e-4), (
271+
f"output != output_ref: max diff = {(output - output_ref).abs().max().item()}"
272+
)
273+
assert torch.allclose(hidden_states_grad, hidden_states_grad_ref, atol=1e-4, rtol=1e-4), (
274+
f"hidden_states_grad != ref: max diff = "
275+
f"{(hidden_states_grad - hidden_states_grad_ref).abs().max().item()}"
232276
)
233-
# Routing and token counts available after forward (e.g. for debugging or further checks)
234277
if routing_map is not None and tokens_per_expert is not None:
235278
num_tokens_per_ep_rank = tokens_per_expert.sum().item()
236-
assert num_tokens_per_ep_rank > 0
279+
assert num_tokens_per_ep_rank > 0, (
280+
f"num_tokens_per_ep_rank={num_tokens_per_ep_rank} (expected > 0)"
281+
)
237282
assert routing_map_ref is not None and tokens_per_expert_ref is not None
238-
torch.testing.assert_close(tokens_per_expert, tokens_per_expert_ref)
283+
tpe_f = tokens_per_expert.float()
284+
ref_f = tokens_per_expert_ref.float()
285+
assert torch.allclose(tpe_f, ref_f, atol=1e-4, rtol=1e-4), (
286+
f"tokens_per_expert != ref: max diff = {(tpe_f - ref_f).abs().max().item()}"
287+
)
239288

240289

241290
@pytest.mark.skipif(not is_hybrid_ep_available(), reason="Hybrid EP are not available")
@@ -249,8 +298,7 @@ def teardown_method(self, method):
249298
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
250299
@pytest.mark.internal
251300
def test_overload_factor_and_over_budget(self):
252-
"""Test budget computation (same as token_dispatcher lines 1017-1025) and assert
253-
over_budget flag is set when tokens_per_ep_rank exceeds budget."""
301+
"""Budget matches HybridEP setup_metadata; over_budget matches map-derived load."""
254302
if not is_hybrid_ep_available():
255303
pytest.skip("Hybrid EP is not available")
256304

@@ -261,8 +309,8 @@ def test_overload_factor_and_over_budget(self):
261309
ep_size=4,
262310
pp_size=1,
263311
num_moe_experts=8,
264-
num_layers=1,
265-
moe_router_topk=4,
312+
num_layers=4,
313+
moe_router_topk=2,
266314
moe_router_load_balancing_type="aux_loss",
267315
moe_token_dispatcher_type="flex",
268316
moe_permute_fusion=True,
@@ -274,50 +322,84 @@ def test_overload_factor_and_over_budget(self):
274322
moe_use_legacy_grouped_gemm=False,
275323
moe_paged_stash=True,
276324
stash_modules=["expert_fc1", "moe_act", "expert_fc2"],
277-
moe_expert_rank_capacity_factor=1.0,
325+
moe_expert_rank_capacity_factor=1.5,
278326
use_transformer_engine_op_fuser=True,
279327
moe_mlp_glu_interleave_size=32,
280328
moe_router_padding_for_quantization=True,
281329
gated_linear_unit=True,
282330
activation_func=F.silu,
283331
moe_router_force_biased=1,
284332
)
285-
if not isinstance(container.moe_layer.experts, TEGroupedMLP) or not container.moe_layer.experts._is_fused_impl_supported():
333+
experts = container.moe_layer.experts
334+
fused_ok = isinstance(experts, TEGroupedMLP) and experts._is_fused_impl_supported()
335+
if not fused_ok:
286336
container.destroy()
287337
pytest.skip("TEGroupedMLP fused impl not supported")
288338

289-
seq_length = 4096
339+
seq_length = 1024
290340
batch_size = 1
291341
topk = container.config.moe_router_topk
292342
capacity_factor = container.config.moe_expert_rank_capacity_factor
293-
hidden_size = container.config.hidden_size
294343
hidden_states = torch.randn(
295-
(seq_length, batch_size, hidden_size), dtype=torch.bfloat16
344+
(seq_length, batch_size, container.config.hidden_size), dtype=torch.bfloat16
296345
)
297346

298-
# Budget computed like token_dispatcher._HybridEPManager.setup_metadata (lines 1017-1025)
299-
num_tokens = seq_length * batch_size
347+
num_tokens = seq_length * batch_size * topk
300348
pad_multiple = get_align_size_for_quantization(container.config)
301-
budget = int(num_tokens * topk * capacity_factor)
349+
budget = int(num_tokens * capacity_factor)
302350
budget += -budget % pad_multiple
303351

304-
paged_stash_reset(True)
352+
paged_stash_reset(True, config=container.config)
305353
paged_stash_init_chunk_handler(1, 0)
306-
_, _, _, tokens_per_expert = container.forward_backward(hidden_states)
307-
308-
assert tokens_per_expert is not None
309-
tokens_per_ep_rank = tokens_per_expert.sum().item()
310-
over_budget_tensor = container.moe_layer.token_dispatcher.check_over_budget()
311-
over_budget = over_budget_tensor.item() if over_budget_tensor is not None else False
312-
313-
# When tokens_per_ep_rank > budget, over_budget flag must be raised
314-
if tokens_per_ep_rank >= budget:
315-
assert over_budget, (
316-
f"tokens_per_ep_rank ({tokens_per_ep_rank}) > budget ({budget}), "
317-
"but over_budget flag was not set"
354+
_forward_backward_all_layers(container, hidden_states)
355+
356+
overflow = check_paged_stash_overflow()
357+
num_layers = len(container.moe_layers)
358+
stash_cuda = container.config.stash_buffer_size_factor_cuda
359+
stash_cpu = container.config.stash_buffer_size_factor_cpu
360+
stash_buffer_size = num_tokens * num_layers * (stash_cuda + stash_cpu)
361+
362+
total_tokens = 0
363+
for layer_idx, layer in enumerate(container.moe_layers):
364+
comm = getattr(layer.token_dispatcher, "_comm_manager", None)
365+
routing_map = getattr(comm, "routing_map", None) if comm is not None else None
366+
over_budget_tensor = (
367+
layer.token_dispatcher.check_over_budget()
368+
if hasattr(layer.token_dispatcher, "check_over_budget")
369+
else None
318370
)
319-
else:
320-
assert not over_budget, (
321-
f"tokens_per_ep_rank ({tokens_per_ep_rank}) <= budget ({budget}), "
322-
"but over_budget flag was set"
371+
over_budget = over_budget_tensor.item() if over_budget_tensor is not None else False
372+
373+
assert routing_map is not None, f"layer {layer_idx}: routing_map is None"
374+
assert routing_map.dim() == 2, f"layer {layer_idx}: expected 2D routing_map"
375+
assert routing_map.shape[1] == container.config.num_moe_experts, (
376+
f"layer {layer_idx}: routing_map has {routing_map.shape[1]} experts, "
377+
f"expected {container.config.num_moe_experts}"
378+
)
379+
tokens_per_expert_from_map = _tokens_per_expert_from_routing_map(routing_map, layer)
380+
tokens_per_expert_from_map_padded = _pad_token_counts_to_align_size(
381+
tokens_per_expert_from_map, pad_multiple
323382
)
383+
tokens_per_ep_rank_from_map = tokens_per_expert_from_map_padded.sum().item()
384+
total_tokens += tokens_per_ep_rank_from_map
385+
386+
# Padded map-derived tokens strictly over budget iff dispatcher reports over_budget
387+
if tokens_per_ep_rank_from_map > budget:
388+
assert over_budget, (
389+
f"layer {layer_idx}: tokens_per_ep_rank_from_map "
390+
f"({tokens_per_ep_rank_from_map}) > budget ({budget}), "
391+
f"but over_budget flag was not set"
392+
)
393+
else:
394+
assert not over_budget, (
395+
f"layer {layer_idx}: tokens_per_ep_rank_from_map "
396+
f"({tokens_per_ep_rank_from_map}) <= budget ({budget}), "
397+
f"but over_budget flag was set"
398+
)
399+
400+
overflow_set = overflow.any().item()
401+
stash_exceeded = total_tokens > stash_buffer_size
402+
assert overflow_set == stash_exceeded, (
403+
f"overflow {overflow_set} should match total_tokens > stash_buffer_size "
404+
f"({total_tokens} > {stash_buffer_size})"
405+
)

0 commit comments

Comments
 (0)