3333from torch .distributed import ProcessGroup , ReduceOp , TCPStore
3434
3535from lightllm .distributed .pynccl_wrapper import (
36- NCCLLibrary , buffer_type , cudaStream_t , ncclComm_t , ncclDataTypeEnum ,
37- ncclRedOpTypeEnum , ncclUniqueId )
36+ NCCLLibrary ,
37+ buffer_type ,
38+ cudaStream_t ,
39+ ncclComm_t ,
40+ ncclDataTypeEnum ,
41+ ncclRedOpTypeEnum ,
42+ ncclUniqueId ,
43+ )
3844
3945logger = logging .getLogger (__name__ )
4046
4147_current_stream = None
4248
49+
4350def current_stream () -> torch .cuda .Stream :
4451 global _current_stream
4552 if _current_stream is None :
4653 _current_stream = torch .cuda .current_stream ()
4754 return _current_stream
4855
56+
4957@dataclasses .dataclass
5058class StatelessP2PProcessGroup :
5159 """A dataclass to hold a metadata store, and the rank, world_size of the
@@ -94,18 +102,13 @@ def expire_data(self):
94102
95103 def recv_obj (self ) -> Any :
96104 """Receive an object from a source rank."""
97- obj = pickle .loads (
98- self .store .get (
99- f"send_to/{ self .dest_id } /{ self .recv_src_counter } " ))
105+ obj = pickle .loads (self .store .get (f"send_to/{ self .dest_id } /{ self .recv_src_counter } " ))
100106 self .recv_src_counter += 1
101107 return obj
102108
103109 @staticmethod
104110 def create (
105- src_id : int ,
106- dest_id : int ,
107- is_server : bool ,
108- store : torch ._C ._distributed_c10d .Store
111+ src_id : int , dest_id : int , is_server : bool , store : torch ._C ._distributed_c10d .Store
109112 ) -> "StatelessP2PProcessGroup" :
110113 """A replacement for `torch.distributed.init_process_group` that does not
111114 pollute the global state.
@@ -121,12 +124,11 @@ def create(
121124 used for exchanging metadata. With this function, process A and process B
122125 can call `StatelessProcessGroup.create` to form a group, and then process A, B,
123126 C, and D can call `StatelessProcessGroup.create` to form another group.
124- """ # noqa
127+ """ # noqa
125128 return StatelessP2PProcessGroup (src_id = src_id , dest_id = dest_id , is_server = is_server , store = store )
126129
127130
128131class PyNcclCommunicator :
129-
130132 def __init__ (
131133 self ,
132134 group : Union [ProcessGroup , StatelessP2PProcessGroup ],
@@ -146,8 +148,9 @@ def __init__(
146148 """
147149 if not isinstance (group , StatelessP2PProcessGroup ):
148150 assert dist .is_initialized ()
149- assert dist .get_backend (group ) != dist .Backend .NCCL , (
150- "PyNcclCommunicator should be attached to a non-NCCL group." )
151+ assert (
152+ dist .get_backend (group ) != dist .Backend .NCCL
153+ ), "PyNcclCommunicator should be attached to a non-NCCL group."
151154 # note: this rank is the rank in the group
152155 self .rank = dist .get_rank (group )
153156 self .world_size = dist .get_world_size (group )
@@ -207,8 +210,7 @@ def __init__(
207210 # `torch.cuda.device` is a context manager that changes the
208211 # current cuda device to the specified one
209212 with torch .cuda .device (device ):
210- self .comm : ncclComm_t = self .nccl .ncclCommInitRank (
211- self .world_size , self .unique_id , self .rank )
213+ self .comm : ncclComm_t = self .nccl .ncclCommInitRank (self .world_size , self .unique_id , self .rank )
212214
213215 stream = current_stream ()
214216 # A small all_reduce for warmup.
@@ -220,103 +222,120 @@ def __init__(
220222 def destroy (self ):
221223 self .nccl .ncclCommDestroy (self .comm )
222224
223- def all_reduce (self ,
224- in_tensor : torch .Tensor ,
225- op : ReduceOp = ReduceOp .SUM ,
226- stream = None ) -> torch .Tensor :
225+ def all_reduce (self , in_tensor : torch .Tensor , op : ReduceOp = ReduceOp .SUM , stream = None ) -> torch .Tensor :
227226 if self .disabled :
228227 return None
229228 # nccl communicator created on a specific device
230229 # will only work on tensors on the same device
231230 # otherwise it will cause "illegal memory access"
232231 assert in_tensor .device == self .device , (
233232 f"this nccl communicator is created to work on { self .device } , "
234- f"but the input tensor is on { in_tensor .device } " )
233+ f"but the input tensor is on { in_tensor .device } "
234+ )
235235
236236 out_tensor = torch .empty_like (in_tensor )
237237
238238 if stream is None :
239239 stream = current_stream ()
240- self .nccl .ncclAllReduce (buffer_type (in_tensor .data_ptr ()),
241- buffer_type (out_tensor .data_ptr ()),
242- in_tensor .numel (),
243- ncclDataTypeEnum .from_torch (in_tensor .dtype ),
244- ncclRedOpTypeEnum .from_torch (op ), self .comm ,
245- cudaStream_t (stream .cuda_stream ))
240+ self .nccl .ncclAllReduce (
241+ buffer_type (in_tensor .data_ptr ()),
242+ buffer_type (out_tensor .data_ptr ()),
243+ in_tensor .numel (),
244+ ncclDataTypeEnum .from_torch (in_tensor .dtype ),
245+ ncclRedOpTypeEnum .from_torch (op ),
246+ self .comm ,
247+ cudaStream_t (stream .cuda_stream ),
248+ )
246249 return out_tensor
247250
248- def all_gather (self ,
249- output_tensor : torch .Tensor ,
250- input_tensor : torch .Tensor ,
251- stream = None ):
251+ def all_gather (self , output_tensor : torch .Tensor , input_tensor : torch .Tensor , stream = None ):
252252 if self .disabled :
253253 return
254254 # nccl communicator created on a specific device
255255 # will only work on tensors on the same device
256256 # otherwise it will cause "illegal memory access"
257257 assert input_tensor .device == self .device , (
258258 f"this nccl communicator is created to work on { self .device } , "
259- f"but the input tensor is on { input_tensor .device } " )
259+ f"but the input tensor is on { input_tensor .device } "
260+ )
260261 if stream is None :
261262 stream = current_stream ()
262263 self .nccl .ncclAllGather (
263264 buffer_type (input_tensor .data_ptr ()),
264- buffer_type (output_tensor .data_ptr ()), input_tensor .numel (),
265- ncclDataTypeEnum .from_torch (input_tensor .dtype ), self .comm ,
266- cudaStream_t (stream .cuda_stream ))
267-
268- def reduce_scatter (self ,
269- output_tensor : torch .Tensor ,
270- input_tensor : torch .Tensor ,
271- op : ReduceOp = ReduceOp .SUM ,
272- stream = None ):
265+ buffer_type (output_tensor .data_ptr ()),
266+ input_tensor .numel (),
267+ ncclDataTypeEnum .from_torch (input_tensor .dtype ),
268+ self .comm ,
269+ cudaStream_t (stream .cuda_stream ),
270+ )
271+
272+ def reduce_scatter (
273+ self , output_tensor : torch .Tensor , input_tensor : torch .Tensor , op : ReduceOp = ReduceOp .SUM , stream = None
274+ ):
273275 if self .disabled :
274276 return
275277 # nccl communicator created on a specific device
276278 # will only work on tensors on the same device
277279 # otherwise it will cause "illegal memory access"
278280 assert input_tensor .device == self .device , (
279281 f"this nccl communicator is created to work on { self .device } , "
280- f"but the input tensor is on { input_tensor .device } " )
282+ f"but the input tensor is on { input_tensor .device } "
283+ )
281284 if stream is None :
282285 stream = current_stream ()
283286 self .nccl .ncclReduceScatter (
284287 buffer_type (input_tensor .data_ptr ()),
285- buffer_type (output_tensor .data_ptr ()), output_tensor .numel (),
288+ buffer_type (output_tensor .data_ptr ()),
289+ output_tensor .numel (),
286290 ncclDataTypeEnum .from_torch (input_tensor .dtype ),
287- ncclRedOpTypeEnum .from_torch (op ), self .comm ,
288- cudaStream_t (stream .cuda_stream ))
291+ ncclRedOpTypeEnum .from_torch (op ),
292+ self .comm ,
293+ cudaStream_t (stream .cuda_stream ),
294+ )
289295
290296 def send (self , tensor : torch .Tensor , dst : int , stream = None ):
291297 if self .disabled :
292298 return
293299 assert tensor .device == self .device , (
294300 f"this nccl communicator is created to work on { self .device } , "
295- f"but the input tensor is on { tensor .device } " )
301+ f"but the input tensor is on { tensor .device } "
302+ )
296303 if stream is None :
297304 stream = current_stream ()
298- self .nccl .ncclSend (buffer_type (tensor .data_ptr ()), tensor .numel (),
299- ncclDataTypeEnum .from_torch (tensor .dtype ), dst ,
300- self .comm , cudaStream_t (stream .cuda_stream ))
305+ self .nccl .ncclSend (
306+ buffer_type (tensor .data_ptr ()),
307+ tensor .numel (),
308+ ncclDataTypeEnum .from_torch (tensor .dtype ),
309+ dst ,
310+ self .comm ,
311+ cudaStream_t (stream .cuda_stream ),
312+ )
301313
302314 def recv (self , tensor : torch .Tensor , src : int , stream = None ):
303315 if self .disabled :
304316 return
305317 assert tensor .device == self .device , (
306318 f"this nccl communicator is created to work on { self .device } , "
307- f"but the input tensor is on { tensor .device } " )
319+ f"but the input tensor is on { tensor .device } "
320+ )
308321 if stream is None :
309322 stream = current_stream ()
310- self .nccl .ncclRecv (buffer_type (tensor .data_ptr ()), tensor .numel (),
311- ncclDataTypeEnum .from_torch (tensor .dtype ), src ,
312- self .comm , cudaStream_t (stream .cuda_stream ))
323+ self .nccl .ncclRecv (
324+ buffer_type (tensor .data_ptr ()),
325+ tensor .numel (),
326+ ncclDataTypeEnum .from_torch (tensor .dtype ),
327+ src ,
328+ self .comm ,
329+ cudaStream_t (stream .cuda_stream ),
330+ )
313331
314332 def broadcast (self , tensor : torch .Tensor , src : int , stream = None ):
315333 if self .disabled :
316334 return
317335 assert tensor .device == self .device , (
318336 f"this nccl communicator is created to work on { self .device } , "
319- f"but the input tensor is on { tensor .device } " )
337+ f"but the input tensor is on { tensor .device } "
338+ )
320339 if stream is None :
321340 stream = current_stream ()
322341 if src == self .rank :
@@ -326,7 +345,12 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
326345 else :
327346 sendbuff = buffer_type ()
328347 recvbuff = buffer_type (tensor .data_ptr ())
329- self .nccl .ncclBroadcast (sendbuff , recvbuff , tensor .numel (),
330- ncclDataTypeEnum .from_torch (tensor .dtype ), src ,
331- self .comm , cudaStream_t (stream .cuda_stream ))
332-
348+ self .nccl .ncclBroadcast (
349+ sendbuff ,
350+ recvbuff ,
351+ tensor .numel (),
352+ ncclDataTypeEnum .from_torch (tensor .dtype ),
353+ src ,
354+ self .comm ,
355+ cudaStream_t (stream .cuda_stream ),
356+ )
0 commit comments