
Commit 6c38eae

fix: handle zero active experts for 1 ep rank in GroupedExperts (#935)
Signed-off-by: Hemil Desai <[email protected]>
1 parent: c05cb81

File tree

3 files changed: +474 -0 lines changed


nemo_automodel/components/moe/layers.py

Lines changed: 23 additions & 0 deletions
@@ -303,6 +303,29 @@ def get_local_proj(proj, expert_id):
 
             y.scatter_add_(dim=0, index=idx_b, src=expert_out.to(x.dtype))
 
+        # Handle the edge case where no tokens are routed to any local experts.
+        # This can occur during expert parallelism when all tokens on a particular
+        # rank happen to select experts hosted on other ranks. We perform a dummy
+        # computation through the local expert weights to ensure:
+        # 1. Gradient flow through local expert parameters during backpropagation
+        # 2. Proper participation in collective operations (reduce-scatter)
+        # The computation is a no-op since we multiply by zero (using zeros_like input).
+        if active_local_experts == 0:
+            gate_and_up_proj = get_local_proj(self.gate_and_up_projs, experts_start_idx)
+            down_proj = get_local_proj(self.down_projs, experts_start_idx)
+            gate_up_proj_bias = get_local_proj(self.gate_up_proj_bias, experts_start_idx) if self.expert_bias else None
+            down_proj_bias = get_local_proj(self.down_proj_bias, experts_start_idx) if self.expert_bias else None
+
+            expert_out = (
+                self.expert_activation(
+                    torch.zeros_like(x[0]).unsqueeze(0),
+                    gate_and_up_proj=gate_and_up_proj,
+                    down_proj=down_proj,
+                )
+                * weights[0, 0, None]
+            )
+            y[0] += expert_out[0]
+
         if ep_size > 1:
             y = DTensor.from_local(y, device_mesh=ep_mesh, placements=[Partial()])
             y = y.redistribute(placements=[Shard(0)]).to_local()
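
The comment block in the hunk above describes the pattern the fix relies on: when a rank's tokens all route to experts hosted elsewhere, a dummy forward through the local weights with an all-zeros input adds nothing to the output but keeps those parameters in the autograd graph, so gradient collectives such as the reduce-scatter still see them. The snippet below is a minimal standalone sketch of that idea; the simplified bias-free swiglu helper and the tensor shapes are illustrative assumptions, not the GroupedExperts API.

import torch
import torch.nn.functional as F

# Standalone sketch (illustrative shapes, simplified bias-free swiglu): an
# all-zeros input produces an exactly-zero output, so accumulating it changes
# nothing numerically, yet both weight tensors end up with populated .grad
# buffers after backward.
dim, inter_dim = 8, 16
gate_and_up_proj = torch.randn(dim, 2 * inter_dim, requires_grad=True)
down_proj = torch.randn(inter_dim, dim, requires_grad=True)

def swiglu(x, gate_and_up_proj, down_proj):
    gate, up = torch.chunk(x @ gate_and_up_proj, 2, -1)
    return (F.silu(gate) * up) @ down_proj

routing_weight = torch.rand(1)  # stands in for weights[0, 0, None]
dummy_out = swiglu(torch.zeros(1, dim), gate_and_up_proj, down_proj) * routing_weight

assert torch.all(dummy_out == 0)          # the added contribution is a numerical no-op
dummy_out.sum().backward()
print(gate_and_up_proj.grad is not None)  # True: weights stayed in the autograd graph
print(down_proj.grad is not None)         # True

In the actual diff the same effect comes from feeding torch.zeros_like(x[0]).unsqueeze(0) through the rank's first local expert weights and adding the result into y[0].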
Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import torch
import torch.nn.functional as F

from nemo_automodel.components.moe.layers import GroupedExperts, MoEConfig

# Track whether expert_activation was called
activation_called = [False]


def tracking_swiglu(x, *, gate_and_up_proj, down_proj, gate_up_proj_bias=None, down_proj_bias=None):
    """Tracking version of swiglu that sets activation_called[0] = True."""
    global activation_called
    activation_called[0] = True
    gate_and_up_out = x @ gate_and_up_proj
    if gate_up_proj_bias is not None:
        gate_and_up_out = gate_and_up_out + gate_up_proj_bias
    gate_out, up_out = torch.chunk(gate_and_up_out, 2, -1)
    inter = F.silu(gate_out) * up_out
    inter = inter @ down_proj
    if down_proj_bias is not None:
        inter = inter + down_proj_bias
    return inter


def main(device_str: str = "cuda:0") -> int:
    """
    Run the zero active experts gradient test.

    Args:
        device_str: Device to run on ("cuda:0" or "cpu")

    Returns:
        0 if test passed, 1 if test failed
    """
    # Use global activation_called to track across function boundaries
    global activation_called
    activation_called[0] = False  # Reset at start

    moe_config = MoEConfig(
        n_routed_experts=8,
        n_shared_experts=2,
        n_activated_experts=2,
        n_expert_groups=1,
        n_limited_groups=1,
        train_gate=True,
        gate_bias_update_factor=0.1,
        aux_loss_coeff=0.01,
        score_func="softmax",
        route_scale=1.0,
        dim=128,
        inter_dim=256,
        moe_inter_dim=256,
        norm_topk_prob=False,
        router_bias=False,
        expert_bias=False,
        expert_activation="swiglu",
        activation_alpha=1.702,
        activation_limit=7.0,
        dtype=torch.float32,
    )

    device = torch.device(device_str)
    experts = GroupedExperts(moe_config)
    experts.expert_activation = tracking_swiglu
    experts = experts.to(device)

    with torch.no_grad():
        experts.gate_and_up_projs.normal_(0, 0.02)
        experts.down_projs.normal_(0, 0.02)

    num_tokens = 8
    x = torch.randn(num_tokens, moe_config.dim, dtype=torch.float32, device=device)
    token_mask = torch.ones(num_tokens, dtype=torch.bool, device=device)
    weights = torch.rand(num_tokens, moe_config.n_activated_experts, dtype=torch.float32, device=device)

    # Set indices to non-existent expert (simulates all tokens routed elsewhere)
    indices = torch.full(
        (num_tokens, moe_config.n_activated_experts),
        fill_value=moe_config.n_routed_experts + 100,
        dtype=torch.long,
        device=device,
    )

    output = experts.forward(x, token_mask, weights, indices)

    if activation_called[0]:
        print("SUCCESS: expert_activation was called even when no tokens select any expert")
        return 0
    else:
        print("FAIL: expert_activation was NOT called - the zero active experts fix is missing or broken")
        return 1


if __name__ == "__main__":
    device = sys.argv[1] if len(sys.argv) > 1 else "cuda:0"
    sys.exit(main(device))
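
The script above only asserts that expert_activation ran; per its __main__ guard, the device is passed as the first command-line argument (for example "cpu" when no GPU is available), defaulting to "cuda:0". A possible extension, sketched below under the assumption that the returned output is a differentiable tensor and that gate_and_up_projs and down_projs are the module's parameters, would also verify that the dummy path populates their gradients; this check is not part of the commit.

# Hypothetical extension to main(), assuming `output` is differentiable and the
# projection tensors are parameters: verify the dummy computation keeps the
# expert weights in the autograd graph even though no token selected them.
output.sum().backward()
assert experts.gate_and_up_projs.grad is not None, "gate/up projections missing from the graph"
assert experts.down_projs.grad is not None, "down projections missing from the graph"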
