
Commit 85b8717

[Training] [2/n] add bwd for all2all and all_gather (#439)
1 parent 657fd74 commit 85b8717

4 files changed: +208 -112 lines


fastvideo/v1/distributed/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -5,7 +5,8 @@
     cleanup_dist_env_and_memory, get_sequence_model_parallel_rank,
     get_sequence_model_parallel_world_size, get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size, get_world_group,
-    init_distributed_environment, initialize_model_parallel)
+    init_distributed_environment, initialize_model_parallel,
+    model_parallel_is_initialized)
 from fastvideo.v1.distributed.utils import *

 __all__ = [
@@ -17,4 +18,5 @@
     "get_tensor_model_parallel_world_size",
     "cleanup_dist_env_and_memory",
     "get_world_group",
+    "model_parallel_is_initialized",
 ]
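
The newly exported model_parallel_is_initialized lets callers guard parallel-state setup. A minimal usage sketch, not part of the commit: it assumes the no-argument defaults of the vLLM code this module is adapted from carry over to fastvideo, and that the process is launched with torchrun-style environment variables.

# Hypothetical usage sketch; defaults and launch environment are assumptions.
from fastvideo.v1.distributed import (init_distributed_environment,
                                      initialize_model_parallel,
                                      model_parallel_is_initialized)

init_distributed_environment()            # assumes MASTER_ADDR/RANK/WORLD_SIZE are set
if not model_parallel_is_initialized():   # skip if another component already did this
    initialize_model_parallel()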

fastvideo/v1/distributed/device_communicators/base_device_communicator.py

Lines changed: 187 additions & 103 deletions
@@ -1,16 +1,182 @@
 # SPDX-License-Identifier: Apache-2.0
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/base_device_communicator.py

-from typing import Optional
+from typing import Any, Optional, Tuple

 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
+from torch import Tensor
+from torch.distributed import ProcessGroup, ReduceOp
+
+
+class DistributedAutograd:
+    """Collection of autograd functions for distributed operations.
+
+    This class provides custom autograd functions for distributed operations like all_reduce,
+    all_gather, and all_to_all. Each operation is implemented as a static inner class with
+    proper forward and backward implementations.
+    """
+
+    class AllReduce(torch.autograd.Function):
+        """Differentiable all_reduce operation.
+
+        The gradient of all_reduce is another all_reduce operation since the operation
+        combines values from all ranks equally.
+        """
+
+        @staticmethod
+        def forward(ctx: Any,
+                    group: ProcessGroup,
+                    input_: Tensor,
+                    op: Optional[dist.ReduceOp] = None) -> Tensor:
+            ctx.group = group
+            ctx.op = op
+            output = input_.clone()
+            dist.all_reduce(output, group=group, op=op)
+            return output
+
+        @staticmethod
+        def backward(ctx: Any,
+                     grad_output: Tensor) -> Tuple[None, Tensor, None]:
+            grad_output = grad_output.clone()
+            dist.all_reduce(grad_output, group=ctx.group, op=ctx.op)
+            return None, grad_output, None
+
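Why the backward of AllReduce is itself an all_reduce: with a SUM reduction, every rank's input contributes once to every rank's output, so the gradient of any one input is the sum of the upstream gradients from all ranks. A single-process sanity check in plain PyTorch (no process group; the per-rank tensors are simulated with a list):

import torch

world_size = 4
# Simulated per-rank inputs.
xs = [torch.randn(3, requires_grad=True) for _ in range(world_size)]

# SUM all_reduce: every "rank" ends up holding the same reduced tensor.
reduced = torch.stack(xs).sum(dim=0)

# Each rank feeds its copy into a different downstream computation.
ws = [torch.randn(3) for _ in range(world_size)]
loss = sum((w * reduced).sum() for w in ws)
loss.backward()

# d(loss)/d(x_j) is the sum of all ranks' upstream gradients (the w's),
# i.e. an all_reduce of the per-rank grad_outputs.
expected = torch.stack(ws).sum(dim=0)
for x in xs:
    torch.testing.assert_close(x.grad, expected)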
+    class AllGather(torch.autograd.Function):
+        """Differentiable all_gather operation.
+
+        The operation gathers tensors from all ranks and concatenates them along a specified dimension.
+        The backward pass uses reduce_scatter to efficiently distribute gradients back to source ranks.
+        """
+
+        @staticmethod
+        def forward(ctx: Any, group: ProcessGroup, input_: Tensor,
+                    world_size: int, dim: int) -> Tensor:
+            ctx.group = group
+            ctx.world_size = world_size
+            ctx.dim = dim
+            ctx.input_shape = input_.shape
+
+            input_size = input_.size()
+            output_size = (input_size[0] * world_size, ) + input_size[1:]
+            output_tensor = torch.empty(output_size,
+                                        dtype=input_.dtype,
+                                        device=input_.device)
+
+            dist.all_gather_into_tensor(output_tensor, input_, group=group)
+
+            output_tensor = output_tensor.reshape((world_size, ) + input_size)
+            output_tensor = output_tensor.movedim(0, dim)
+            output_tensor = output_tensor.reshape(input_size[:dim] +
+                                                  (world_size *
+                                                   input_size[dim], ) +
+                                                  input_size[dim + 1:])
+            return output_tensor
+
+        @staticmethod
+        def backward(ctx: Any,
+                     grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
+            # Split the gradient tensor along the gathered dimension
+            dim_size = grad_output.size(ctx.dim) // ctx.world_size
+            grad_chunks = grad_output.reshape(grad_output.shape[:ctx.dim] +
+                                              (ctx.world_size, dim_size) +
+                                              grad_output.shape[ctx.dim + 1:])
+            grad_chunks = grad_chunks.movedim(ctx.dim, 0)
+
+            # Each rank only needs its corresponding gradient
+            grad_input = torch.empty(ctx.input_shape,
+                                     dtype=grad_output.dtype,
+                                     device=grad_output.device)
+            dist.reduce_scatter_tensor(grad_input,
+                                       grad_chunks.contiguous(),
+                                       group=ctx.group)
+
+            return None, grad_input, None, None
+
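The backward here mirrors the forward: all_gather concatenates the shards, so the gradient of a rank's shard is the sum, over all ranks, of the matching slice of their upstream gradients, which is exactly what reduce_scatter computes. A single-process sanity check in plain PyTorch for dim=0 (ranks again simulated as list entries):

import torch

world_size, shard, width = 2, 3, 4
# Simulated per-rank shards.
xs = [torch.randn(shard, width, requires_grad=True) for _ in range(world_size)]

# all_gather along dim 0: every "rank" sees the concatenation of all shards.
gathered = torch.cat(xs, dim=0)

# Each rank applies its own downstream computation to the gathered tensor.
ws = [torch.randn(shard * world_size, width) for _ in range(world_size)]
loss = sum((w * gathered).sum() for w in ws)
loss.backward()

# Gradient of rank j's shard = sum over ranks of slice j of each upstream grad:
# the "reduce" is the sum, the "scatter" is the slicing.
summed = torch.stack(ws).sum(dim=0)
for j, x in enumerate(xs):
    torch.testing.assert_close(x.grad, summed[j * shard:(j + 1) * shard])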
+    class AllToAll4D(torch.autograd.Function):
+        """Differentiable all_to_all operation specialized for 4D tensors.
+
+        This operation is particularly useful for attention operations where we need to
+        redistribute data across ranks for efficient parallel processing.
+
+        The operation supports two modes:
+        1. scatter_dim=2, gather_dim=1: Used for redistributing attention heads
+        2. scatter_dim=1, gather_dim=2: Used for redistributing sequence dimensions
+        """
+
+        @staticmethod
+        def forward(ctx: Any, group: ProcessGroup, input_: Tensor,
+                    world_size: int, scatter_dim: int,
+                    gather_dim: int) -> Tensor:
+            ctx.group = group
+            ctx.world_size = world_size
+            ctx.scatter_dim = scatter_dim
+            ctx.gather_dim = gather_dim
+
+            if world_size == 1:
+                return input_
+
+            assert input_.dim(
+            ) == 4, f"input must be 4D tensor, got {input_.dim()} and shape {input_.shape}"
+
+            if scatter_dim == 2 and gather_dim == 1:
+                bs, shard_seqlen, hc, hs = input_.shape
+                seqlen = shard_seqlen * world_size
+                shard_hc = hc // world_size
+
+                input_t = input_.reshape(bs, shard_seqlen, world_size, shard_hc,
+                                         hs).transpose(0, 2).contiguous()
+                output = torch.empty_like(input_t)
+
+                dist.all_to_all_single(output, input_t, group=group)
+
+                output = output.reshape(seqlen, bs, shard_hc,
+                                        hs).transpose(0, 1).contiguous()
+                output = output.reshape(bs, seqlen, shard_hc, hs)
+
+                return output
+            elif scatter_dim == 1 and gather_dim == 2:
+                bs, seqlen, shard_hc, hs = input_.shape
+                hc = shard_hc * world_size
+                shard_seqlen = seqlen // world_size
+
+                input_t = input_.reshape(bs, world_size, shard_seqlen, shard_hc,
+                                         hs)
+                input_t = input_t.transpose(0, 3).transpose(0, 1).contiguous()
+                input_t = input_t.reshape(world_size, shard_hc, shard_seqlen,
+                                          bs, hs)
+
+                output = torch.empty_like(input_t)
+                dist.all_to_all_single(output, input_t, group=group)
+
+                output = output.reshape(hc, shard_seqlen, bs, hs)
+                output = output.transpose(0, 2).contiguous()
+                output = output.reshape(bs, shard_seqlen, hc, hs)
+
+                return output
+            else:
+                raise RuntimeError(
+                    f"Invalid scatter_dim={scatter_dim}, gather_dim={gather_dim}. "
+                    f"Only (scatter_dim=2, gather_dim=1) and (scatter_dim=1, gather_dim=2) are supported."
+                )
+
+        @staticmethod
+        def backward(
+                ctx: Any,
+                grad_output: Tensor) -> Tuple[None, Tensor, None, None, None]:
+            if ctx.world_size == 1:
+                return None, grad_output, None, None, None
+
+            # For backward pass, we swap scatter_dim and gather_dim
+            output = DistributedAutograd.AllToAll4D.apply(
+                ctx.group, grad_output, ctx.world_size, ctx.gather_dim,
+                ctx.scatter_dim)
+            return None, output, None, None, None
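
The reshape and transpose bookkeeping in the (scatter_dim=2, gather_dim=1) branch is easiest to follow on a single process. The sketch below is plain PyTorch: the all_to_all_single exchange is simulated by indexing, no process group is created, and the dimension names follow the code (bs, seqlen, hc = head count, hs = head size). It checks that sequence-sharded inputs come out head-sharded over the full sequence.

import torch

world_size, bs, shard_seqlen, hc, hs = 2, 1, 3, 4, 5
seqlen, shard_hc = shard_seqlen * world_size, hc // world_size

# The full (unsharded) activation, used only to verify the result.
full = torch.arange(bs * seqlen * hc * hs, dtype=torch.float32)
full = full.reshape(bs, seqlen, hc, hs)

# Rank r starts with its slice of the sequence: (bs, seqlen/P, hc, hs).
local_inputs = [full[:, r * shard_seqlen:(r + 1) * shard_seqlen]
                for r in range(world_size)]

# Pre-collective layout from the forward pass: (P, seqlen/P, bs, hc/P, hs).
staged = [
    x.reshape(bs, shard_seqlen, world_size, shard_hc, hs).transpose(0, 2).contiguous()
    for x in local_inputs
]

# all_to_all_single: chunk j of rank r is exchanged with chunk r of rank j.
received = [
    torch.stack([staged[j][r] for j in range(world_size)])
    for r in range(world_size)
]

# Post-collective reshape from the forward pass: (bs, seqlen, hc/P, hs).
outputs = [
    recv.reshape(seqlen, bs, shard_hc, hs).transpose(0, 1).contiguous()
    for recv in received
]

# Each rank now holds the full sequence for its own slice of the heads.
for r, out in enumerate(outputs):
    torch.testing.assert_close(out, full[:, :, r * shard_hc:(r + 1) * shard_hc])

Applying the (scatter_dim=1, gather_dim=2) branch to these outputs performs the inverse redistribution, which is why the backward pass can simply re-apply AllToAll4D with the two dimensions swapped.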


 class DeviceCommunicatorBase:
     """
-    Base class for device-specific communicator.
+    Base class for device-specific communicator with autograd support.
     It can use the `cpu_group` to initialize the communicator.
     If the device has PyTorch integration (PyTorch can recognize its
     communication backend), the `device_group` will also be given.
@@ -33,35 +199,28 @@ def __init__(self,
         self.rank_in_group = dist.get_group_rank(self.cpu_group,
                                                  self.global_rank)

-    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
-        dist.all_reduce(input_, group=self.device_group)
-        return input_
+    def all_reduce(self,
+                   input_: torch.Tensor,
+                   op: Optional[dist.ReduceOp] = ReduceOp.SUM) -> torch.Tensor:
+        """Performs an all_reduce operation with gradient support."""
+        return DistributedAutograd.AllReduce.apply(self.device_group, input_,
+                                                   op)

     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        """Performs an all_gather operation with gradient support."""
         if dim < 0:
-            # Convert negative dim to positive.
             dim += input_.dim()
-        input_size = input_.size()
-        # NOTE: we have to use concat-style all-gather here,
-        # stack-style all-gather has compatibility issues with
-        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
-        output_size = (input_size[0] * self.world_size, ) + input_size[1:]
-        # Allocate output tensor.
-        output_tensor = torch.empty(output_size,
-                                    dtype=input_.dtype,
-                                    device=input_.device)
-        # All-gather.
-        dist.all_gather_into_tensor(output_tensor,
-                                    input_,
-                                    group=self.device_group)
-        # Reshape
-        output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
-        output_tensor = output_tensor.movedim(0, dim)
-        output_tensor = output_tensor.reshape(input_size[:dim] +
-                                              (self.world_size *
-                                               input_size[dim], ) +
-                                              input_size[dim + 1:])
-        return output_tensor
+        return DistributedAutograd.AllGather.apply(self.device_group, input_,
+                                                   self.world_size, dim)
+
+    def all_to_all_4D(self,
+                      input_: torch.Tensor,
+                      scatter_dim: int = 2,
+                      gather_dim: int = 1) -> torch.Tensor:
+        """Performs a 4D all-to-all operation with gradient support."""
+        return DistributedAutograd.AllToAll4D.apply(self.device_group, input_,
+                                                    self.world_size,
+                                                    scatter_dim, gather_dim)

     def gather(self,
                input_: torch.Tensor,
@@ -95,81 +254,6 @@ def gather(self,
             output_tensor = None
         return output_tensor

-    def all_to_all_4D(self,
-                      input_: torch.Tensor,
-                      scatter_dim: int = 2,
-                      gather_dim: int = 1) -> torch.Tensor:
-        """Specialized all-to-all operation for 4D tensors (e.g., for QKV matrices).
-
-        Args:
-            input_ (torch.Tensor): 4D input tensor to be scattered and gathered.
-            scatter_dim (int, optional): Dimension along which to scatter. Defaults to 2.
-            gather_dim (int, optional): Dimension along which to gather. Defaults to 1.
-
-        Returns:
-            torch.Tensor: Output tensor after all-to-all operation.
-        """
-        # Bypass the function if we are using only 1 GPU.
-        if self.world_size == 1:
-            return input_
-
-        assert input_.dim(
-        ) == 4, f"input must be 4D tensor, got {input_.dim()} and shape {input_.shape}"
-
-        if scatter_dim == 2 and gather_dim == 1:
-            # input: (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs)
-            bs, shard_seqlen, hc, hs = input_.shape
-            seqlen = shard_seqlen * self.world_size
-            shard_hc = hc // self.world_size
-
-            # Reshape and transpose for scattering
-            input_t = (input_.reshape(bs, shard_seqlen, self.world_size,
-                                      shard_hc, hs).transpose(0,
-                                                              2).contiguous())
-
-            output = torch.empty_like(input_t)
-
-            torch.distributed.all_to_all_single(output,
-                                                input_t,
-                                                group=self.device_group)
-            torch.cuda.synchronize()
-
-            # Reshape and transpose back
-            output = output.reshape(seqlen, bs, shard_hc,
-                                    hs).transpose(0, 1).contiguous().reshape(
-                                        bs, seqlen, shard_hc, hs)
-
-            return output
-
-        elif scatter_dim == 1 and gather_dim == 2:
-            # input: (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs)
-            bs, seqlen, shard_hc, hs = input_.shape
-            hc = shard_hc * self.world_size
-            shard_seqlen = seqlen // self.world_size
-
-            # Reshape and transpose for scattering
-            input_t = (input_.reshape(bs, self.world_size, shard_seqlen,
-                                      shard_hc, hs).transpose(0, 3).transpose(
-                                          0, 1).contiguous().reshape(
-                                              self.world_size, shard_hc,
-                                              shard_seqlen, bs, hs))
-            output = torch.empty_like(input_t)
-
-            torch.distributed.all_to_all_single(output,
-                                                input_t,
-                                                group=self.device_group)
-            torch.cuda.synchronize()
-
-            # Reshape and transpose back
-            output = output.reshape(hc, shard_seqlen, bs,
-                                    hs).transpose(0, 2).contiguous().reshape(
-                                        bs, shard_seqlen, hc, hs)
-
-            return output
-        else:
-            raise RuntimeError(
-                "scatter_dim must be 1 or 2 and gather_dim must be 1 or 2")
-
     def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
         """Sends a tensor to the destination rank in a non-blocking way"""
         """NOTE: `dst` is the local rank of the destination rank."""

fastvideo/v1/distributed/device_communicators/cuda_communicator.py

Lines changed: 5 additions & 3 deletions
@@ -29,17 +29,19 @@ def __init__(self,
             device=self.device,
         )

-    def all_reduce(self, input_):
+    def all_reduce(self,
+                   input_,
+                   op: Optional[torch.distributed.ReduceOp] = None):
         pynccl_comm = self.pynccl_comm
         assert pynccl_comm is not None
-        out = pynccl_comm.all_reduce(input_)
+        out = pynccl_comm.all_reduce(input_, op=op)
         if out is None:
             # fall back to the default all-reduce using PyTorch.
             # this usually happens during testing.
             # when we run the model, allreduce only happens for the TP
             # group, where we always have either custom allreduce or pynccl.
             out = input_.clone()
-            torch.distributed.all_reduce(out, group=self.device_group)
+            torch.distributed.all_reduce(out, group=self.device_group, op=op)
         return out

     def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
