Commit 8588de5

Added empty expert case handling in GroupedExpertsLoRA
Signed-off-by: Yuhe Zhang <[email protected]>
1 parent bcad3a9 commit 8588de5

File tree

1 file changed: 40 additions, 1 deletion


nemo_automodel/components/_peft/lora_moe.py

Lines changed: 40 additions & 1 deletion
@@ -190,12 +190,14 @@ def compute_lora(x, lora_A, lora_B, expert_id):

         y = torch.zeros_like(x)

+        active_local_experts = 0
         for i in range(experts_start_idx, experts_end_idx):
             indices_mask = torch.logical_and(indices == i, token_mask.unsqueeze(-1))
             idx, top = torch.where(indices_mask)

             if idx.numel() == 0:
                 continue
+            active_local_experts += 1

             gate_and_up_proj = get_local_proj(self.gate_and_up_projs, i)
             down_proj = get_local_proj(self.down_projs, i)
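The new `active_local_experts` counter in the hunk above only records whether any expert owned by this rank received at least one routed token. A minimal standalone sketch of the masking logic it piggybacks on; the tensor values, shapes, and the expert index range below are illustrative assumptions, not the module's real inputs:

```python
import torch

# Assumed routing state for 4 tokens with top-2 routing (illustrative only).
indices = torch.tensor([[0, 2], [1, 2], [3, 0], [2, 1]])  # chosen expert ids per token
token_mask = torch.tensor([True, True, False, True])      # e.g. padding tokens masked out

experts_start_idx, experts_end_idx = 0, 2                 # experts assumed owned by this EP rank
active_local_experts = 0
for i in range(experts_start_idx, experts_end_idx):
    indices_mask = torch.logical_and(indices == i, token_mask.unsqueeze(-1))
    idx, top = torch.where(indices_mask)  # token row and top-k slot that hit expert i
    if idx.numel() == 0:
        continue
    active_local_experts += 1
    print(f"expert {i}: tokens {idx.tolist()} via top-k slots {top.tolist()}")

print("active_local_experts =", active_local_experts)  # 0 would trigger the new fallback below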
@@ -244,7 +246,44 @@ def compute_lora(x, lora_A, lora_B, expert_id):

             expert_out = expert_out_val * weights[idx, top, None]

-            y = torch.scatter_add(y, dim=0, index=idx_b, src=expert_out.to(x.dtype))
+            y.scatter_add_(dim=0, index=idx_b, src=expert_out.to(x.dtype))
+
+        # Handle the edge case where no tokens are routed to any local experts.
+        # This ensures gradient flow through local expert parameters during backprop
+        # and proper participation in collective operations (reduce-scatter).
+        if active_local_experts == 0:
+            gate_and_up_proj = get_local_proj(self.gate_and_up_projs, experts_start_idx)
+            down_proj = get_local_proj(self.down_projs, experts_start_idx)
+
+            dummy_x = torch.zeros_like(x[0]).unsqueeze(0)
+
+            # Gate + Up with LoRA
+            gate_and_up_out = dummy_x @ gate_and_up_proj
+            gate_and_up_out = gate_and_up_out + compute_lora(
+                dummy_x, self.lora_gate_and_up_A, self.lora_gate_and_up_B, experts_start_idx
+            )
+
+            # Activation
+            if self.config.expert_activation == "swiglu":
+                gate_out, up_out = torch.chunk(gate_and_up_out, 2, -1)
+                inter = torch.nn.functional.silu(gate_out) * up_out
+            elif self.config.expert_activation == "quick_geglu":
+                gate_out, up_out = gate_and_up_out[..., ::2], gate_and_up_out[..., 1::2]
+                limit = self.config.activation_limit
+                alpha = self.config.activation_alpha
+                gate_out = gate_out.clamp(min=None, max=limit)
+                up_out = up_out.clamp(min=-limit, max=limit)
+                out_glu = gate_out * torch.sigmoid(alpha * gate_out)
+                inter = out_glu * (up_out + 1)
+
+            # Down with LoRA
+            expert_out_val = inter @ down_proj
+            expert_out_val = expert_out_val + compute_lora(
+                inter, self.lora_down_A, self.lora_down_B, experts_start_idx
+            )
+
+            expert_out = expert_out_val * weights[0, 0, None]
+            y[0] += expert_out[0]

         if ep_size > 1:
             y = DTensor.from_local(y, device_mesh=ep_mesh, placements=[Partial()])