|
16 | 16 | from typing import TYPE_CHECKING, Any, Dict, Optional |
17 | 17 |
|
18 | 18 | import torch |
| 19 | +import torch.distributed as dist |
| 20 | +import torch.distributed.nn.functional as dist_nn_f |
19 | 21 | import torch.nn as nn |
20 | 22 | import torch.nn.functional as F |
| 23 | +from torch.autograd import Function |
21 | 24 | from torch.distributed.device_mesh import DeviceMesh |
22 | | -from torch.distributed.tensor import DTensor, Partial, Shard |
| 25 | +from torch.distributed.tensor import DTensor |
23 | 26 |
|
24 | 27 | from nemo_automodel.components.moe.state_dict_utils import create_dtensor_from_local |
25 | 28 |
|
|
34 | 37 | ) |
35 | 38 | from nemo_automodel.components.moe.megatron.token_dispatcher import MoEFlexTokenDispatcher, TokenDispatcherConfig |
36 | 39 |
|
| 40 | +# ── EP variable-length collective helpers ── |
| 41 | + |
| 42 | + |
| 43 | +class _AllGatherConcatVarlenFn(Function): |
| 44 | + """All-gather with variable local lengths and autograd-safe backward. |
| 45 | +
|
| 46 | + Backward uses all-reduce + local narrow instead of reduce-scatter to avoid |
| 47 | + monitoredBarrier deadlocks observed with mixed FSDP/EP backward collective ordering. |
| 48 | + """ |
| 49 | + |
| 50 | + @staticmethod |
| 51 | + def forward(ctx, local_tensor: torch.Tensor, group: dist.ProcessGroup, gathered_lens: list[int], max_len: int): |
| 52 | + local_len = local_tensor.size(0) |
| 53 | + if local_len < max_len: |
| 54 | + pad_shape = (max_len - local_len,) + tuple(local_tensor.shape[1:]) |
| 55 | + pad = torch.zeros(pad_shape, dtype=local_tensor.dtype, device=local_tensor.device) |
| 56 | + local_padded = torch.cat([local_tensor, pad], dim=0) |
| 57 | + else: |
| 58 | + local_padded = local_tensor |
| 59 | + |
| 60 | + world_size = len(gathered_lens) |
| 61 | + gathered = [torch.empty_like(local_padded) for _ in range(world_size)] |
| 62 | + dist.all_gather(gathered, local_padded, group=group) |
| 63 | + gathered = [g[:n] for g, n in zip(gathered, gathered_lens)] |
| 64 | + |
| 65 | + ctx.group = group |
| 66 | + ctx.gathered_lens = gathered_lens |
| 67 | + ctx.rank = dist.get_rank(group) |
| 68 | + return torch.cat(gathered, dim=0) |
| 69 | + |
| 70 | + @staticmethod |
| 71 | + def backward(ctx, grad_output: torch.Tensor): |
| 72 | + grad_full = grad_output.contiguous() |
| 73 | + start = sum(ctx.gathered_lens[: ctx.rank]) |
| 74 | + local_len = ctx.gathered_lens[ctx.rank] |
| 75 | + dist.all_reduce(grad_full, op=dist.ReduceOp.SUM, group=ctx.group) |
| 76 | + grad_local = grad_full.narrow(0, start, local_len).contiguous() |
| 77 | + return grad_local, None, None, None |
| 78 | + |
| 79 | + |
37 | 80 | if TYPE_CHECKING: |
38 | 81 | from transformer_engine.pytorch import GroupedLinear |
39 | 82 |
|
@@ -253,18 +296,36 @@ def forward( |
253 | 296 | ) |
254 | 297 | down_projs = self.down_projs.to_local() if isinstance(self.down_projs, DTensor) else self.down_projs |
255 | 298 |
|
256 | | - # DTensor all-gather/reduce-scatter for expert parallelism |
| 299 | + # EP variable-length all-gather |
257 | 300 | if ep_size > 1: |
258 | | - # grad_placements=[Partial()] ensures backward does reduce-scatter |
259 | | - # (default Replicate would just slice, losing cross-rank gradient contributions) |
260 | | - x = DTensor.from_local(x, device_mesh=ep_mesh, placements=[Shard(0)]).full_tensor( |
261 | | - grad_placements=[Partial()] |
262 | | - ) |
263 | | - weights = DTensor.from_local(weights.float(), device_mesh=ep_mesh, placements=[Shard(0)]).full_tensor( |
264 | | - grad_placements=[Partial()] |
265 | | - ) |
266 | | - indices = DTensor.from_local(indices, device_mesh=ep_mesh, placements=[Shard(0)]).full_tensor() |
267 | | - token_mask = DTensor.from_local(token_mask, device_mesh=ep_mesh, placements=[Shard(0)]).full_tensor() |
| 301 | + ep_group = ep_mesh.get_group() |
| 302 | + local_num_tokens = x.size(0) |
| 303 | + |
| 304 | + # Exchange per-rank token counts |
| 305 | + local_len_t = torch.tensor([local_num_tokens], device=x.device, dtype=torch.int64) |
| 306 | + gathered_len_t = [torch.zeros_like(local_len_t) for _ in range(ep_size)] |
| 307 | + dist.all_gather(gathered_len_t, local_len_t, group=ep_group) |
| 308 | + gathered_lens = [int(t.item()) for t in gathered_len_t] |
| 309 | + max_len = max(gathered_lens) |
| 310 | + |
| 311 | + def _all_gather_dim0_var(local_tensor: torch.Tensor, *, differentiable: bool) -> torch.Tensor: |
| 312 | + if differentiable: |
| 313 | + return _AllGatherConcatVarlenFn.apply(local_tensor, ep_group, gathered_lens, max_len) |
| 314 | + if max_len > local_tensor.size(0): |
| 315 | + pad_shape = (max_len - local_tensor.size(0),) + tuple(local_tensor.shape[1:]) |
| 316 | + pad = torch.zeros(pad_shape, dtype=local_tensor.dtype, device=local_tensor.device) |
| 317 | + local_padded = torch.cat([local_tensor, pad], dim=0) |
| 318 | + else: |
| 319 | + local_padded = local_tensor |
| 320 | + gathered = [torch.empty_like(local_padded) for _ in range(ep_size)] |
| 321 | + dist.all_gather(gathered, local_padded, group=ep_group) |
| 322 | + gathered = [g[:n] for g, n in zip(gathered, gathered_lens)] |
| 323 | + return torch.cat(gathered, dim=0) |
| 324 | + |
| 325 | + x = _all_gather_dim0_var(x, differentiable=True) |
| 326 | + weights = _all_gather_dim0_var(weights.float(), differentiable=False) |
| 327 | + indices = _all_gather_dim0_var(indices, differentiable=False) |
| 328 | + token_mask = _all_gather_dim0_var(token_mask, differentiable=False) |
268 | 329 |
|
269 | 330 | n_local_experts = self.n_routed_experts // ep_size |
270 | 331 | experts_start_idx = ep_rank * n_local_experts |
@@ -294,9 +355,15 @@ def forward( |
294 | 355 | experts_end_idx, |
295 | 356 | ) |
296 | 357 |
|
| 358 | +    # Gradient anchor: x * 0.0 keeps the gathered x in y's autograd graph,
| 358 | +    # presumably so the EP all-gather backward collective always fires even
| 358 | +    # when no gathered token routes to a local expert — confirm with the
| 358 | +    # FSDP/EP collective-ordering issue this commit targets.
| 359 | + if ep_size > 1: |
| 360 | + y = y + (x * 0.0) |
| 361 | + |
| 362 | + # Variable-length reduce: all_reduce + narrow to original per-rank token boundaries |
297 | 363 | if ep_size > 1: |
298 | | - y = DTensor.from_local(y, device_mesh=ep_mesh, placements=[Partial()]) |
299 | | - y = y.redistribute(placements=[Shard(0)]).to_local() |
| 364 | + y = dist_nn_f.all_reduce(y, op=dist.ReduceOp.SUM, group=ep_group) |
| 365 | + start = sum(gathered_lens[:ep_rank]) |
| 366 | + y = y.narrow(0, start, local_num_tokens).contiguous() |
300 | 367 |
|
301 | 368 | return y.to(input_dtype) |
302 | 369 |
|
|
0 commit comments