Skip to content

Commit 8f2b685

Browse files
authored
fix: EP collective deadlock with variable-length token counts (#1365)
* fix: EP collective deadlock with variable-length token counts DTensor.from_local(x, [Shard(0)]).full_tensor() assumes uniform token counts across ranks. Variable-length sequence packing breaks this assumption, causing NCCL deadlocks when ranks disagree on buffer sizes. Implemented fixes until successful training run: - Stage 1: pad+all_gather+trim for variable-length input gather - Stage 2: all_reduce+narrow for correct per-rank output boundaries - Stage 3: gradient anchor (y + x*0.0) so all ranks enter backward collectives regardless of routing Signed-off-by: David Yang <56007659+ShiftyBlock@users.noreply.github.com> * refactor: use ep_mesh.get_group() directly Verified training still works: https://pastebin.com/k630jqSM Signed-off-by: David Yang <56007659+ShiftyBlock@users.noreply.github.com> --------- Signed-off-by: David Yang <56007659+ShiftyBlock@users.noreply.github.com>
1 parent d50e20b commit 8f2b685

File tree

1 file changed

+81
-14
lines changed

1 file changed

+81
-14
lines changed

nemo_automodel/components/moe/experts.py

Lines changed: 81 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@
1616
from typing import TYPE_CHECKING, Any, Dict, Optional
1717

1818
import torch
19+
import torch.distributed as dist
20+
import torch.distributed.nn.functional as dist_nn_f
1921
import torch.nn as nn
2022
import torch.nn.functional as F
23+
from torch.autograd import Function
2124
from torch.distributed.device_mesh import DeviceMesh
22-
from torch.distributed.tensor import DTensor, Partial, Shard
25+
from torch.distributed.tensor import DTensor
2326

2427
from nemo_automodel.components.moe.state_dict_utils import create_dtensor_from_local
2528

@@ -34,6 +37,46 @@
3437
)
3538
from nemo_automodel.components.moe.megatron.token_dispatcher import MoEFlexTokenDispatcher, TokenDispatcherConfig
3639

40+
# ── EP variable-length collective helpers ──
41+
42+
43+
class _AllGatherConcatVarlenFn(Function):
    """All-gather along dim 0 with per-rank variable lengths, autograd-aware.

    Forward pads each rank's tensor to the group-wide maximum length,
    all-gathers the uniform buffers, then trims every slice back to its true
    length before concatenating. Backward uses all-reduce followed by a local
    narrow (instead of reduce-scatter) to avoid monitoredBarrier deadlocks
    observed with mixed FSDP/EP backward collective ordering.
    """

    @staticmethod
    def forward(ctx, local_tensor: torch.Tensor, group: dist.ProcessGroup, gathered_lens: list[int], max_len: int):
        # Pad the local shard up to max_len so every rank contributes a
        # uniformly shaped buffer to the collective.
        n_local = local_tensor.size(0)
        padded = local_tensor
        if n_local < max_len:
            tail_shape = (max_len - n_local,) + tuple(local_tensor.shape[1:])
            tail = torch.zeros(tail_shape, dtype=local_tensor.dtype, device=local_tensor.device)
            padded = torch.cat([local_tensor, tail], dim=0)

        buffers = [torch.empty_like(padded) for _ in gathered_lens]
        dist.all_gather(buffers, padded, group=group)
        # Trim each rank's slice back to its true token count before concat.
        trimmed = [buf[:length] for buf, length in zip(buffers, gathered_lens)]

        # Only plain Python state is needed in backward; no tensors are saved.
        ctx.group = group
        ctx.gathered_lens = gathered_lens
        ctx.rank = dist.get_rank(group)
        return torch.cat(trimmed, dim=0)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Every rank consumed the full gathered tensor, so each local slice
        # accumulates gradient contributions from all ranks: sum globally,
        # then carve out this rank's own segment.
        summed = grad_output.contiguous()
        dist.all_reduce(summed, op=dist.ReduceOp.SUM, group=ctx.group)
        offset = sum(ctx.gathered_lens[: ctx.rank])
        length = ctx.gathered_lens[ctx.rank]
        local_grad = summed.narrow(0, offset, length).contiguous()
        # Non-tensor forward args (group, gathered_lens, max_len) get no grad.
        return local_grad, None, None, None
78+
79+
3780
if TYPE_CHECKING:
3881
from transformer_engine.pytorch import GroupedLinear
3982

@@ -253,18 +296,36 @@ def forward(
253296
)
254297
down_projs = self.down_projs.to_local() if isinstance(self.down_projs, DTensor) else self.down_projs
255298

256-
# DTensor all-gather/reduce-scatter for expert parallelism
299+
# EP variable-length all-gather
257300
if ep_size > 1:
258-
# grad_placements=[Partial()] ensures backward does reduce-scatter
259-
# (default Replicate would just slice, losing cross-rank gradient contributions)
260-
x = DTensor.from_local(x, device_mesh=ep_mesh, placements=[Shard(0)]).full_tensor(
261-
grad_placements=[Partial()]
262-
)
263-
weights = DTensor.from_local(weights.float(), device_mesh=ep_mesh, placements=[Shard(0)]).full_tensor(
264-
grad_placements=[Partial()]
265-
)
266-
indices = DTensor.from_local(indices, device_mesh=ep_mesh, placements=[Shard(0)]).full_tensor()
267-
token_mask = DTensor.from_local(token_mask, device_mesh=ep_mesh, placements=[Shard(0)]).full_tensor()
301+
ep_group = ep_mesh.get_group()
302+
local_num_tokens = x.size(0)
303+
304+
# Exchange per-rank token counts
305+
local_len_t = torch.tensor([local_num_tokens], device=x.device, dtype=torch.int64)
306+
gathered_len_t = [torch.zeros_like(local_len_t) for _ in range(ep_size)]
307+
dist.all_gather(gathered_len_t, local_len_t, group=ep_group)
308+
gathered_lens = [int(t.item()) for t in gathered_len_t]
309+
max_len = max(gathered_lens)
310+
311+
def _all_gather_dim0_var(local_tensor: torch.Tensor, *, differentiable: bool) -> torch.Tensor:
312+
if differentiable:
313+
return _AllGatherConcatVarlenFn.apply(local_tensor, ep_group, gathered_lens, max_len)
314+
if max_len > local_tensor.size(0):
315+
pad_shape = (max_len - local_tensor.size(0),) + tuple(local_tensor.shape[1:])
316+
pad = torch.zeros(pad_shape, dtype=local_tensor.dtype, device=local_tensor.device)
317+
local_padded = torch.cat([local_tensor, pad], dim=0)
318+
else:
319+
local_padded = local_tensor
320+
gathered = [torch.empty_like(local_padded) for _ in range(ep_size)]
321+
dist.all_gather(gathered, local_padded, group=ep_group)
322+
gathered = [g[:n] for g, n in zip(gathered, gathered_lens)]
323+
return torch.cat(gathered, dim=0)
324+
325+
x = _all_gather_dim0_var(x, differentiable=True)
326+
weights = _all_gather_dim0_var(weights.float(), differentiable=False)
327+
indices = _all_gather_dim0_var(indices, differentiable=False)
328+
token_mask = _all_gather_dim0_var(token_mask, differentiable=False)
268329

269330
n_local_experts = self.n_routed_experts // ep_size
270331
experts_start_idx = ep_rank * n_local_experts
@@ -294,9 +355,15 @@ def forward(
294355
experts_end_idx,
295356
)
296357

358+
# Gradient anchor
359+
if ep_size > 1:
360+
y = y + (x * 0.0)
361+
362+
# Variable-length reduce: all_reduce + narrow to original per-rank token boundaries
297363
if ep_size > 1:
298-
y = DTensor.from_local(y, device_mesh=ep_mesh, placements=[Partial()])
299-
y = y.redistribute(placements=[Shard(0)]).to_local()
364+
y = dist_nn_f.all_reduce(y, op=dist.ReduceOp.SUM, group=ep_group)
365+
start = sum(gathered_lens[:ep_rank])
366+
y = y.narrow(0, start, local_num_tokens).contiguous()
300367

301368
return y.to(input_dtype)
302369

0 commit comments

Comments
 (0)