@@ -190,12 +190,14 @@ def compute_lora(x, lora_A, lora_B, expert_id):
 
         y = torch.zeros_like(x)
 
+        active_local_experts = 0
         for i in range(experts_start_idx, experts_end_idx):
            indices_mask = torch.logical_and(indices == i, token_mask.unsqueeze(-1))
            idx, top = torch.where(indices_mask)
 
            if idx.numel() == 0:
                continue
+           active_local_experts += 1
 
            gate_and_up_proj = get_local_proj(self.gate_and_up_projs, i)
            down_proj = get_local_proj(self.down_projs, i)
@@ -244,7 +246,44 @@ def compute_lora(x, lora_A, lora_B, expert_id):
 
            expert_out = expert_out_val * weights[idx, top, None]
 
-           y = torch.scatter_add(y, dim=0, index=idx_b, src=expert_out.to(x.dtype))
+           y.scatter_add_(dim=0, index=idx_b, src=expert_out.to(x.dtype))
+
+       # Handle the edge case where no tokens are routed to any local experts.
+       # This ensures gradient flow through local expert parameters during backprop
+       # and proper participation in collective operations (reduce-scatter).
+       if active_local_experts == 0:
+           gate_and_up_proj = get_local_proj(self.gate_and_up_projs, experts_start_idx)
+           down_proj = get_local_proj(self.down_projs, experts_start_idx)
+
+           dummy_x = torch.zeros_like(x[0]).unsqueeze(0)
+
+           # Gate + Up with LoRA
+           gate_and_up_out = dummy_x @ gate_and_up_proj
+           gate_and_up_out = gate_and_up_out + compute_lora(
+               dummy_x, self.lora_gate_and_up_A, self.lora_gate_and_up_B, experts_start_idx
+           )
+
+           # Activation
+           if self.config.expert_activation == "swiglu":
+               gate_out, up_out = torch.chunk(gate_and_up_out, 2, -1)
+               inter = torch.nn.functional.silu(gate_out) * up_out
+           elif self.config.expert_activation == "quick_geglu":
+               gate_out, up_out = gate_and_up_out[..., ::2], gate_and_up_out[..., 1::2]
+               limit = self.config.activation_limit
+               alpha = self.config.activation_alpha
+               gate_out = gate_out.clamp(min=None, max=limit)
+               up_out = up_out.clamp(min=-limit, max=limit)
+               out_glu = gate_out * torch.sigmoid(alpha * gate_out)
+               inter = out_glu * (up_out + 1)
+
+           # Down with LoRA
+           expert_out_val = inter @ down_proj
+           expert_out_val = expert_out_val + compute_lora(
+               inter, self.lora_down_A, self.lora_down_B, experts_start_idx
+           )
+
+           expert_out = expert_out_val * weights[0, 0, None]
+           y[0] += expert_out[0]
 
        if ep_size > 1:
            y = DTensor.from_local(y, device_mesh=ep_mesh, placements=[Partial()])
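A note on why the `active_local_experts == 0` branch routes a zero-valued dummy token through the first local expert: parameters that never enter the forward graph receive no gradient at all, which can desynchronize expert-parallel ranks that expect every local expert parameter to contribute to the gradient reduce-scatter. The sketch below is not from this commit; it is a minimal single-process illustration of that autograd behavior, with hypothetical tensors standing in for expert weights.

```python
import torch

# Two "experts": only one sees any tokens this step.
w_used = torch.randn(8, 8, requires_grad=True)
w_unused = torch.randn(8, 8, requires_grad=True)

x = torch.randn(4, 8)
loss = (x @ w_used).sum()  # w_unused never participates in the forward pass
loss.backward()

print(w_used.grad is not None)    # True
print(w_unused.grad is not None)  # False: no grad is produced for the untouched expert

# Routing a zero-valued dummy token through w_unused keeps it in the graph
# without changing the output; the all-zero gradient still participates in
# whatever collective reduction follows, which is what the edge-case branch
# above arranges for a rank whose experts received no tokens.
dummy = torch.zeros(1, 8)
(dummy @ w_unused).sum().backward()
print(w_unused.grad is not None)  # True (an all-zero gradient)
```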
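The `quick_geglu` branch reproduced in the edge-case path uses an interleaved gate/up channel layout with a Quick-GELU gate. The standalone sketch below mirrors that activation for reference; `limit` and `alpha` correspond to `config.activation_limit` and `config.activation_alpha`, and the numeric values in the usage example are illustrative only.

```python
import torch

def quick_geglu(gate_and_up_out: torch.Tensor, limit: float, alpha: float) -> torch.Tensor:
    """Interleaved quick-GEGLU, mirroring the activation branch in the diff above."""
    # Even channels act as the gate, odd channels as the "up" projection.
    gate_out, up_out = gate_and_up_out[..., ::2], gate_and_up_out[..., 1::2]
    # Clamp to bound the sigmoid argument and the linear branch.
    gate_out = gate_out.clamp(max=limit)
    up_out = up_out.clamp(min=-limit, max=limit)
    # Quick-GELU gate: x * sigmoid(alpha * x), modulating the shifted up branch.
    return (gate_out * torch.sigmoid(alpha * gate_out)) * (up_out + 1)

# Example with illustrative values; the real limit/alpha come from the model config.
out = quick_geglu(torch.randn(2, 16), limit=7.0, alpha=1.702)
print(out.shape)  # torch.Size([2, 8])
```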