__all__ = [
+    "_prepare_nccl_allgather_inputs",
+    "_unroll_nccl_allgather_recv",
    "initialize_nccl_comm",
    "nccl_split",
    "nccl_allgather",
    "nccl_allreduce",
    "nccl_bcast",
    "nccl_asarray",
    "nccl_send",
-    "nccl_recv"
+    "nccl_recv",
]

from enum import IntEnum
+from typing import Tuple
from mpi4py import MPI
import os
import numpy as np
import cupy as cp
import cupy.cuda.nccl as nccl

+
cupy_to_nccl_dtype = {
    "float32": nccl.NCCL_FLOAT32,
    "float64": nccl.NCCL_FLOAT64,
    "int8": nccl.NCCL_INT8,
    "uint32": nccl.NCCL_UINT32,
    "uint64": nccl.NCCL_UINT64,
+    # complex arrays are sent as floats with twice the element count
+    "complex64": nccl.NCCL_FLOAT32,
+    "complex128": nccl.NCCL_FLOAT64,
}


@@ -35,6 +42,106 @@ class NcclOp(IntEnum):
    MIN = nccl.NCCL_MIN


+def _nccl_buf_size(buf, count=None):
+    """ Get an appropriate buffer size (element count) according to the dtype of buf
+
+    Parameters
+    ----------
+    buf : :obj:`cupy.ndarray` or array-like
+        The data buffer from the local GPU to be sent.
+    count : :obj:`int`, optional
+        Number of elements to send from `buf` if not sending every element in `buf`.
+
+    Returns
+    -------
+    :obj:`int`
+        An appropriate number of elements to send from `buf` for NCCL communication.
+    """
+    if buf.dtype in ['complex64', 'complex128']:
+        return 2 * count if count else 2 * buf.size
+    else:
+        return count if count else buf.size
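+
+# Minimal usage sketch of the element-count logic above (illustrative values only,
+# not taken from the library's tests): a complex64 array is transferred as twice
+# as many NCCL_FLOAT32 elements, while real dtypes keep their element count.
+#
+#   x = cp.ones(4, dtype="complex64")
+#   _nccl_buf_size(x)             # -> 8 (4 complex64 values sent as 8 floats)
+#   _nccl_buf_size(x, count=2)    # -> 4
+#   _nccl_buf_size(cp.ones(4, dtype="float32"))   # -> 4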
+
+
+def _prepare_nccl_allgather_inputs(send_buf, send_buf_shapes) -> Tuple[cp.ndarray, cp.ndarray]:
+    r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather)
+
+    NCCL's allGather requires the sending buffer to have the same size for every device.
+    Therefore, padding is required when the array is not evenly partitioned across
+    all the ranks. The padding is applied such that each dimension of the sending buffers
+    is equal to the maximum size of that dimension across all ranks.
+
+    Similarly, the receiving buffer (recv_buf) of each rank is created with a size equal to
+    :math:`n_{rank} \cdot send\_buf.size`
+
+    Parameters
+    ----------
+    send_buf : :obj:`cupy.ndarray` or array-like
+        The data buffer from the local GPU to be sent for allgather.
+    send_buf_shapes : :obj:`list`
+        A list of the shapes of each GPU's send_buf (used to calculate the padding size)
+
+    Returns
+    -------
+    send_buf : :obj:`cupy.ndarray`
+        A buffer containing the data and padded elements to be sent by this rank.
+    recv_buf : :obj:`cupy.ndarray`
+        An empty, padded buffer to gather data from all GPUs.
+    """
+    sizes_each_dim = list(zip(*send_buf_shapes))
+    send_shape = tuple(map(max, sizes_each_dim))
+    pad_size = [
+        (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, send_buf.shape)
+    ]
+
+    send_buf = cp.pad(
+        send_buf, pad_size, mode="constant", constant_values=0
+    )
+
+    # NCCL recommends one MPI process per GPU, so the size of the receiving buffer
+    # can be inferred from the number of ranks (one send_buf-sized slot per GPU)
+    ndev = len(send_buf_shapes)
+    recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype)
+
+    return send_buf, recv_buf
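+
+# Sketch of the padding behaviour (hypothetical shapes, no communicator needed):
+# with per-rank shapes [(3, 2), (2, 2)], every rank pads its local array up to the
+# per-dimension maximum (3, 2), and recv_buf holds ndev * send_buf.size elements.
+#
+#   local = cp.arange(4, dtype="float32").reshape(2, 2)   # this rank's array
+#   send_buf, recv_buf = _prepare_nccl_allgather_inputs(local, [(3, 2), (2, 2)])
+#   send_buf.shape   # -> (3, 2), last row zero-padded
+#   recv_buf.size    # -> 12 == 2 * 6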
+
+
+def _unroll_nccl_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> list:
+    """Unroll recv_buf after NCCL allgather (nccl_allgather)
+
+    Remove the padded elements in recv_buf, extract the individual array sent by each
+    device, and return them as a list of arrays. Each GPU may send an array with a
+    different shape, so the return type has to be a list of arrays rather than a single
+    concatenated array.
+
+    Parameters
+    ----------
+    recv_buf : :obj:`cupy.ndarray` or array-like
+        The data buffer returned from the nccl_allgather call
+    padded_send_buf_shape : :obj:`tuple`
+        The shape of send_buf after the padding applied in nccl_allgather
+    send_buf_shapes : :obj:`list`
+        A list of the original shapes of each GPU's send_buf prior to padding
+
+    Returns
+    -------
+    chunks : :obj:`list`
+        A list of :obj:`cupy.ndarray` from each GPU with the padded elements removed
+    """
+
+    ndev = len(send_buf_shapes)
+    # extract an individual array from each device
+    chunk_size = np.prod(padded_send_buf_shape)
+    chunks = [
+        recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev)
+    ]
+
+    # Remove padding from each array: the padded values may appear somewhere in
+    # the middle of the flat array, so a reshape and a slice along each dimension
+    # are required
+    for i in range(ndev):
+        slicing = tuple(slice(0, end) for end in send_buf_shapes[i])
+        chunks[i] = chunks[i].reshape(padded_send_buf_shape)[slicing]
+
+    return chunks
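+
+# Sketch of the unroll step on a hand-built recv_buf (hypothetical values, no NCCL
+# call involved): two ranks with original shapes (3, 2) and (2, 2), both padded to
+# (3, 2) before allgather, are recovered with the padding stripped.
+#
+#   a = cp.arange(6, dtype="float32").reshape(3, 2)                    # rank 0, already max shape
+#   b = cp.pad(cp.ones((2, 2), dtype="float32"), [(0, 1), (0, 0)])     # rank 1, zero-padded
+#   recv_buf = cp.concatenate([a.ravel(), b.ravel()])
+#   chunks = _unroll_nccl_allgather_recv(recv_buf, (3, 2), [(3, 2), (2, 2)])
+#   [c.shape for c in chunks]   # -> [(3, 2), (2, 2)]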
+
+
def mpi_op_to_nccl(mpi_op) -> NcclOp:
    """ Map MPI reduction operation to NCCL equivalent

@@ -155,7 +262,7 @@ def nccl_allgather(nccl_comm, send_buf, recv_buf=None) -> cp.ndarray:
    nccl_comm.allGather(
        send_buf.data.ptr,
        recv_buf.data.ptr,
-        send_buf.size,
+        _nccl_buf_size(send_buf),
        cupy_to_nccl_dtype[str(send_buf.dtype)],
        cp.cuda.Stream.null.ptr,
    )
@@ -193,7 +300,7 @@ def nccl_allreduce(nccl_comm, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM) ->
    nccl_comm.allReduce(
        send_buf.data.ptr,
        recv_buf.data.ptr,
-        send_buf.size,
+        _nccl_buf_size(send_buf),
        cupy_to_nccl_dtype[str(send_buf.dtype)],
        mpi_op_to_nccl(op),
        cp.cuda.Stream.null.ptr,
@@ -220,7 +327,7 @@ def nccl_bcast(nccl_comm, local_array, index, value) -> None:
        local_array[index] = value
    nccl_comm.bcast(
        local_array[index].data.ptr,
-        local_array[index].size,
+        _nccl_buf_size(local_array[index]),
        cupy_to_nccl_dtype[str(local_array[index].dtype)],
        0,
        cp.cuda.Stream.null.ptr,
@@ -247,41 +354,12 @@ def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray:
    -------
    final_array : :obj:`cupy.ndarray`
        Global array gathered from all GPUs and concatenated along `axis`.
-
-    Notes
-    -----
-    NCCL's allGather requires the sending buffer to have the same size for every device.
-    Therefore, the padding is required when the array is not evenly partitioned across
-    all the ranks. The padding is applied such that the sending buffer has the size of
-    each dimension corresponding to the max possible size of that dimension.
    """
-    sizes_each_dim = list(zip(*local_shapes))
-
-    send_shape = tuple(map(max, sizes_each_dim))
-    pad_size = [
-        (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, local_array.shape)
-    ]

-    send_buf = cp.pad(
-        local_array, pad_size, mode="constant", constant_values=0
-    )
-
-    # NCCL recommends to use one MPI Process per GPU and so size of receiving buffer can be inferred
-    ndev = len(local_shapes)
-    recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype)
+    send_buf, recv_buf = _prepare_nccl_allgather_inputs(local_array, local_shapes)
    nccl_allgather(nccl_comm, send_buf, recv_buf)
+    chunks = _unroll_nccl_allgather_recv(recv_buf, send_buf.shape, local_shapes)

-    # extract an individual array from each device
-    chunk_size = np.prod(send_shape)
-    chunks = [
-        recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev)
-    ]
-
-    # Remove padding from each array: the padded value may appear somewhere
-    # in the middle of the flat array and thus the reshape and slicing for each dimension is required
-    for i in range(ndev):
-        slicing = tuple(slice(0, end) for end in local_shapes[i])
-        chunks[i] = chunks[i].reshape(send_shape)[slicing]
    # combine back to single global array
    return cp.concatenate(chunks, axis=axis)

@@ -302,7 +380,7 @@ def nccl_send(nccl_comm, send_buf, dest, count):
        Number of elements to send from `send_buf`.
    """
    nccl_comm.send(send_buf.data.ptr,
-                   count,
+                   _nccl_buf_size(send_buf, count),
                   cupy_to_nccl_dtype[str(send_buf.dtype)],
                   dest,
                   cp.cuda.Stream.null.ptr
@@ -325,7 +403,7 @@ def nccl_recv(nccl_comm, recv_buf, source, count=None):
        Number of elements to receive.
    """
    nccl_comm.recv(recv_buf.data.ptr,
-                   count,
+                   _nccl_buf_size(recv_buf, count),
                   cupy_to_nccl_dtype[str(recv_buf.dtype)],
                   source,
                   cp.cuda.Stream.null.ptr