 # SPDX-License-Identifier: Apache-2.0
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/base_device_communicator.py

-from typing import Optional
+from typing import Any, Optional, Tuple

 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
+from torch import Tensor
+from torch.distributed import ProcessGroup, ReduceOp
+
+
+class DistributedAutograd:
+    """Collection of autograd functions for distributed operations.
+
+    This class provides custom autograd functions for distributed operations like all_reduce,
+    all_gather, and all_to_all. Each operation is implemented as a static inner class with
+    proper forward and backward implementations.
+    """
+
+    class AllReduce(torch.autograd.Function):
+        """Differentiable all_reduce operation.
+
+        The gradient of all_reduce is another all_reduce operation since the operation
+        combines values from all ranks equally.
+        """
+
+        @staticmethod
+        def forward(ctx: Any,
+                    group: ProcessGroup,
+                    input_: Tensor,
+                    op: Optional[dist.ReduceOp] = None) -> Tensor:
+            ctx.group = group
+            ctx.op = op
+            output = input_.clone()
+            dist.all_reduce(output, group=group, op=op)
+            return output
+
+        @staticmethod
+        def backward(ctx: Any,
+                     grad_output: Tensor) -> Tuple[None, Tensor, None]:
+            grad_output = grad_output.clone()
+            dist.all_reduce(grad_output, group=ctx.group, op=ctx.op)
+            return None, grad_output, None
+
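
(Editor's note, not part of the diff: a minimal usage sketch of the AllReduce wrapper above, assuming torch.distributed has already been initialized on every rank. It shows that the incoming gradient is itself all-reduced, so each rank ends up with x.grad filled with world_size.)

# Illustrative sketch only (assumes dist.init_process_group(...) has run on every rank).
def _demo_all_reduce(group):
    x = torch.ones(4, requires_grad=True)
    y = DistributedAutograd.AllReduce.apply(group, x, dist.ReduceOp.SUM)
    y.sum().backward()
    # grad_output is all ones; backward all-reduces it, so x.grad == world_size everywhere.
    assert torch.all(x.grad == dist.get_world_size(group))
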
+    class AllGather(torch.autograd.Function):
+        """Differentiable all_gather operation.
+
+        The operation gathers tensors from all ranks and concatenates them along a specified dimension.
+        The backward pass uses reduce_scatter to efficiently distribute gradients back to source ranks.
+        """
+
+        @staticmethod
+        def forward(ctx: Any, group: ProcessGroup, input_: Tensor,
+                    world_size: int, dim: int) -> Tensor:
+            ctx.group = group
+            ctx.world_size = world_size
+            ctx.dim = dim
+            ctx.input_shape = input_.shape
+
+            input_size = input_.size()
+            output_size = (input_size[0] * world_size, ) + input_size[1:]
+            output_tensor = torch.empty(output_size,
+                                        dtype=input_.dtype,
+                                        device=input_.device)
+
+            dist.all_gather_into_tensor(output_tensor, input_, group=group)
+
+            output_tensor = output_tensor.reshape((world_size, ) + input_size)
+            output_tensor = output_tensor.movedim(0, dim)
+            output_tensor = output_tensor.reshape(input_size[:dim] +
+                                                  (world_size *
+                                                   input_size[dim], ) +
+                                                  input_size[dim + 1:])
+            return output_tensor
+
+        @staticmethod
+        def backward(ctx: Any,
+                     grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
+            # Split the gradient tensor along the gathered dimension
+            dim_size = grad_output.size(ctx.dim) // ctx.world_size
+            grad_chunks = grad_output.reshape(grad_output.shape[:ctx.dim] +
+                                              (ctx.world_size, dim_size) +
+                                              grad_output.shape[ctx.dim + 1:])
+            grad_chunks = grad_chunks.movedim(ctx.dim, 0)
+
+            # Each rank only needs its corresponding gradient
+            grad_input = torch.empty(ctx.input_shape,
+                                     dtype=grad_output.dtype,
+                                     device=grad_output.device)
+            dist.reduce_scatter_tensor(grad_input,
+                                       grad_chunks.contiguous(),
+                                       group=ctx.group)
+
+            return None, grad_input, None, None
+
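
(Editor's note, not part of the diff: the reshape/movedim bookkeeping in AllGather.forward can be checked without any communication. The sketch below, with hypothetical shapes, fakes the concat-style output of all_gather_into_tensor on a single process and applies the same rearrangement.)

# Illustrative sketch only: emulate a 4-rank gather along dim=1 locally.
world_size, dim = 4, 1
chunks = [torch.full((2, 3), float(r)) for r in range(world_size)]  # one tensor per "rank"
input_size = chunks[0].size()
gathered = torch.cat(chunks, dim=0)  # what all_gather_into_tensor would produce: (8, 3)
out = gathered.reshape((world_size, ) + input_size).movedim(0, dim)
out = out.reshape(input_size[:dim] + (world_size * input_size[dim], ) +
                  input_size[dim + 1:])
assert out.shape == (2, 12)  # the gathered dimension grew by world_size
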
+    class AllToAll4D(torch.autograd.Function):
+        """Differentiable all_to_all operation specialized for 4D tensors.
+
+        This operation is particularly useful for attention operations where we need to
+        redistribute data across ranks for efficient parallel processing.
+
+        The operation supports two modes:
+        1. scatter_dim=2, gather_dim=1: Used for redistributing attention heads
+        2. scatter_dim=1, gather_dim=2: Used for redistributing sequence dimensions
+        """
+
+        @staticmethod
+        def forward(ctx: Any, group: ProcessGroup, input_: Tensor,
+                    world_size: int, scatter_dim: int,
+                    gather_dim: int) -> Tensor:
+            ctx.group = group
+            ctx.world_size = world_size
+            ctx.scatter_dim = scatter_dim
+            ctx.gather_dim = gather_dim
+
+            if world_size == 1:
+                return input_
+
+            assert input_.dim(
+            ) == 4, f"input must be 4D tensor, got {input_.dim()} and shape {input_.shape}"
+
+            if scatter_dim == 2 and gather_dim == 1:
+                bs, shard_seqlen, hc, hs = input_.shape
+                seqlen = shard_seqlen * world_size
+                shard_hc = hc // world_size
+
+                input_t = input_.reshape(bs, shard_seqlen, world_size, shard_hc,
+                                         hs).transpose(0, 2).contiguous()
+                output = torch.empty_like(input_t)
+
+                dist.all_to_all_single(output, input_t, group=group)
+
+                output = output.reshape(seqlen, bs, shard_hc,
+                                        hs).transpose(0, 1).contiguous()
+                output = output.reshape(bs, seqlen, shard_hc, hs)
+
+                return output
+            elif scatter_dim == 1 and gather_dim == 2:
+                bs, seqlen, shard_hc, hs = input_.shape
+                hc = shard_hc * world_size
+                shard_seqlen = seqlen // world_size
+
+                input_t = input_.reshape(bs, world_size, shard_seqlen, shard_hc,
+                                         hs)
+                input_t = input_t.transpose(0, 3).transpose(0, 1).contiguous()
+                input_t = input_t.reshape(world_size, shard_hc, shard_seqlen,
+                                          bs, hs)
+
+                output = torch.empty_like(input_t)
+                dist.all_to_all_single(output, input_t, group=group)
+
+                output = output.reshape(hc, shard_seqlen, bs, hs)
+                output = output.transpose(0, 2).contiguous()
+                output = output.reshape(bs, shard_seqlen, hc, hs)
+
+                return output
+            else:
+                raise RuntimeError(
+                    f"Invalid scatter_dim={scatter_dim}, gather_dim={gather_dim}. "
+                    f"Only (scatter_dim=2, gather_dim=1) and (scatter_dim=1, gather_dim=2) are supported."
+                )
+
+        @staticmethod
+        def backward(
+                ctx: Any,
+                grad_output: Tensor) -> Tuple[None, Tensor, None, None, None]:
+            if ctx.world_size == 1:
+                return None, grad_output, None, None, None
+
+            # For backward pass, we swap scatter_dim and gather_dim
+            output = DistributedAutograd.AllToAll4D.apply(
+                ctx.group, grad_output, ctx.world_size, ctx.gather_dim,
+                ctx.scatter_dim)
+            return None, output, None, None, None


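(Editor's note, not part of the diff: a shape walk-through with hypothetical sizes that makes the two modes in the AllToAll4D docstring concrete, and shows why backward can simply re-apply the op with scatter_dim and gather_dim swapped.)

# Illustrative sketch only: P = world_size, hc = attention heads, hs = head size.
bs, seqlen, hc, hs, P = 2, 16, 8, 64, 4
# scatter_dim=2, gather_dim=1: (bs, seqlen/P, hc, hs) -> (bs, seqlen, hc/P, hs)
mode1_in, mode1_out = (bs, seqlen // P, hc, hs), (bs, seqlen, hc // P, hs)
# scatter_dim=1, gather_dim=2: (bs, seqlen, hc/P, hs) -> (bs, seqlen/P, hc, hs)
mode2_in, mode2_out = (bs, seqlen, hc // P, hs), (bs, seqlen // P, hc, hs)
# The two modes are inverses of each other, so backward just swaps the dims.
assert mode1_in == mode2_out and mode1_out == mode2_in
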
 class DeviceCommunicatorBase:
     """
-    Base class for device-specific communicator.
+    Base class for device-specific communicator with autograd support.
     It can use the `cpu_group` to initialize the communicator.
     If the device has PyTorch integration (PyTorch can recognize its
     communication backend), the `device_group` will also be given.
@@ -33,35 +199,28 @@ def __init__(self,
         self.rank_in_group = dist.get_group_rank(self.cpu_group,
                                                  self.global_rank)

-    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
-        dist.all_reduce(input_, group=self.device_group)
-        return input_
+    def all_reduce(self,
+                   input_: torch.Tensor,
+                   op: Optional[dist.ReduceOp] = ReduceOp.SUM) -> torch.Tensor:
+        """Performs an all_reduce operation with gradient support."""
+        return DistributedAutograd.AllReduce.apply(self.device_group, input_,
+                                                   op)

     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        """Performs an all_gather operation with gradient support."""
         if dim < 0:
-            # Convert negative dim to positive.
             dim += input_.dim()
-        input_size = input_.size()
-        # NOTE: we have to use concat-style all-gather here,
-        # stack-style all-gather has compatibility issues with
-        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
-        output_size = (input_size[0] * self.world_size, ) + input_size[1:]
-        # Allocate output tensor.
-        output_tensor = torch.empty(output_size,
-                                    dtype=input_.dtype,
-                                    device=input_.device)
-        # All-gather.
-        dist.all_gather_into_tensor(output_tensor,
-                                    input_,
-                                    group=self.device_group)
-        # Reshape
-        output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
-        output_tensor = output_tensor.movedim(0, dim)
-        output_tensor = output_tensor.reshape(input_size[:dim] +
-                                              (self.world_size *
-                                               input_size[dim], ) +
-                                              input_size[dim + 1:])
-        return output_tensor
+        return DistributedAutograd.AllGather.apply(self.device_group, input_,
+                                                   self.world_size, dim)
+
+    def all_to_all_4D(self,
+                      input_: torch.Tensor,
+                      scatter_dim: int = 2,
+                      gather_dim: int = 1) -> torch.Tensor:
+        """Performs a 4D all-to-all operation with gradient support."""
+        return DistributedAutograd.AllToAll4D.apply(self.device_group, input_,
+                                                    self.world_size,
+                                                    scatter_dim, gather_dim)

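(Editor's note, not part of the diff: a hedged usage sketch of the new wrappers in a sequence-parallel attention step. Here `communicator` is any initialized DeviceCommunicatorBase subclass and `attention_fn` is a hypothetical stand-in for the real attention call; gradients flow through both all_to_all_4D calls.)

# Illustrative sketch only.
def _sp_attention(communicator, qkv, attention_fn):
    # qkv: (bs, seqlen/P, hc, hs) with the sequence dimension sharded across ranks.
    full_seq = communicator.all_to_all_4D(qkv, scatter_dim=2, gather_dim=1)
    out = attention_fn(full_seq)  # (bs, seqlen, hc/P, hs)
    # Restore the original sharded layout: scatter sequence, gather heads.
    return communicator.all_to_all_4D(out, scatter_dim=1, gather_dim=2)
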
     def gather(self,
                input_: torch.Tensor,
@@ -95,81 +254,6 @@ def gather(self,
             output_tensor = None
         return output_tensor

-    def all_to_all_4D(self,
-                      input_: torch.Tensor,
-                      scatter_dim: int = 2,
-                      gather_dim: int = 1) -> torch.Tensor:
-        """Specialized all-to-all operation for 4D tensors (e.g., for QKV matrices).
-
-        Args:
-            input_ (torch.Tensor): 4D input tensor to be scattered and gathered.
-            scatter_dim (int, optional): Dimension along which to scatter. Defaults to 2.
-            gather_dim (int, optional): Dimension along which to gather. Defaults to 1.
-
-        Returns:
-            torch.Tensor: Output tensor after all-to-all operation.
-        """
-        # Bypass the function if we are using only 1 GPU.
-        if self.world_size == 1:
-            return input_
-
-        assert input_.dim(
-        ) == 4, f"input must be 4D tensor, got {input_.dim()} and shape {input_.shape}"
-
-        if scatter_dim == 2 and gather_dim == 1:
-            # input: (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs)
-            bs, shard_seqlen, hc, hs = input_.shape
-            seqlen = shard_seqlen * self.world_size
-            shard_hc = hc // self.world_size
-
-            # Reshape and transpose for scattering
-            input_t = (input_.reshape(bs, shard_seqlen, self.world_size,
-                                      shard_hc, hs).transpose(0,
-                                                              2).contiguous())
-
-            output = torch.empty_like(input_t)
-
-            torch.distributed.all_to_all_single(output,
-                                                input_t,
-                                                group=self.device_group)
-            torch.cuda.synchronize()
-
-            # Reshape and transpose back
-            output = output.reshape(seqlen, bs, shard_hc,
-                                    hs).transpose(0, 1).contiguous().reshape(
-                                        bs, seqlen, shard_hc, hs)
-
-            return output
-
-        elif scatter_dim == 1 and gather_dim == 2:
-            # input: (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs)
-            bs, seqlen, shard_hc, hs = input_.shape
-            hc = shard_hc * self.world_size
-            shard_seqlen = seqlen // self.world_size
-
-            # Reshape and transpose for scattering
-            input_t = (input_.reshape(bs, self.world_size, shard_seqlen,
-                                      shard_hc, hs).transpose(0, 3).transpose(
-                                          0, 1).contiguous().reshape(
-                                              self.world_size, shard_hc,
-                                              shard_seqlen, bs, hs))
-            output = torch.empty_like(input_t)
-
-            torch.distributed.all_to_all_single(output,
-                                                input_t,
-                                                group=self.device_group)
-            torch.cuda.synchronize()
-
-            # Reshape and transpose back
-            output = output.reshape(hc, shard_seqlen, bs,
-                                    hs).transpose(0, 2).contiguous().reshape(
-                                        bs, shard_seqlen, hc, hs)
-
-            return output
-        else:
-            raise RuntimeError(
-                "scatter_dim must be 1 or 2 and gather_dim must be 1 or 2")
-
     def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
         """Sends a tensor to the destination rank in a non-blocking way"""
         """NOTE: `dst` is the local rank of the destination rank."""