Commit 135bc45

fix: workaround duplicated AllGather for EP+FSDP2 (#173)
### What does this PR do?

Compute the shared expert first to work around the duplicated all-gather issue in EP+FSDP2, which appears to be a bug in PyTorch FSDP2's `fully_shard`.

Before: <img width="2230" height="302" alt="76591" src="https://github.com/user-attachments/assets/f9f4e553-5678-4fa8-9fcf-77750ad165bf" />

After: <img width="1756" height="188" alt="4480" src="https://github.com/user-attachments/assets/749b572c-214b-4121-bfad-5ec2cea7f191" />
1 parent e020de6 commit 135bc45

File tree

1 file changed: +14 additions, -6 deletions

veomni/models/transformers/deepseek_v3/modeling_deepseek.py

Lines changed: 14 additions & 6 deletions
```diff
@@ -370,6 +370,8 @@ def forward(self, hidden_states):
         denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
         topk_weight = topk_weight / denominator
         topk_weight = topk_weight * self.routed_scaling_factor  # must multiply the scaling factor
+        # Ensure routing weights keep the same dtype as hidden states so checkpoint recomputations stay consistent.
+        topk_weight = topk_weight.to(hidden_states.dtype)
 
         return topk_idx, topk_weight
 
@@ -469,13 +471,19 @@ def forward(self, hidden_states):
         topk_idx, topk_weight = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
 
-        final_hidden_states = torch.zeros(
-            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
-        )
-        final_hidden_states = self.experts(hidden_states, routing_weights=topk_weight, selected_experts=topk_idx)
-
+        # we compute shared expert first to workaround the duplicated AllGather issue when using EP+FSDP2
+        # which seems to be a bug of fsdp2 fully_shard
         if self.config.n_shared_experts is not None:
-            final_hidden_states = final_hidden_states + self.shared_experts(identity)
+            shared_expert_outputs = self.shared_experts(identity)
+        else:
+            shared_expert_outputs = None
+
+        moe_outputs = self.experts(hidden_states, routing_weights=topk_weight, selected_experts=topk_idx)
+
+        if shared_expert_outputs is not None:
+            final_hidden_states = moe_outputs + shared_expert_outputs
+        else:
+            final_hidden_states = moe_outputs
 
         return final_hidden_states
```
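The second hunk is purely an ordering change: the shared expert runs before the routed experts so that FSDP2's `fully_shard` gathers each wrapped module's parameters only once per forward. A minimal, torch-free sketch of that control flow, using hypothetical stub callables (`shared_experts`, `routed_experts`, and `moe_forward` are illustrative names, not this repository's API):

```python
# Sketch of the reordered MoE forward from this commit, with plain-Python
# stubs standing in for the real PyTorch modules (hypothetical names).
call_order = []

def shared_experts(x):
    # Stand-in for the dense shared-expert MLP.
    call_order.append("shared")
    return [v * 0.5 for v in x]

def routed_experts(x):
    # Stand-in for the sparse routed-expert mixture.
    call_order.append("routed")
    return [v * 2.0 for v in x]

def moe_forward(hidden_states, n_shared_experts=2):
    # Shared expert first: mirrors the workaround in this commit.
    shared_out = shared_experts(hidden_states) if n_shared_experts is not None else None
    moe_out = routed_experts(hidden_states)
    if shared_out is None:
        return moe_out
    # Same combination as before the patch, just computed in the new order.
    return [m + s for m, s in zip(moe_out, shared_out)]

out = moe_forward([1.0, 2.0])  # shared expert runs first, then routed experts
```

In the real model the ordering matters because FSDP2 all-gathers a wrapped module's sharded parameters around its forward call; per the PR description, interleaving the shared expert after the routed experts triggered a second, redundant AllGather under EP+FSDP2.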
