55import numpy as np
66
77
8- def _unroll_allgather_recv (recv_buf , chunk_shape , send_buf_shapes , displs = None ) -> list :
8+ def _unroll_allgather_recv (recv_buf , buffer_chunk_shape , send_buf_shapes , displs = None ) -> list :
99 r"""Unroll recv_buf after Buffered Allgather (MPI and NCCL)
1010
1111 Depending on the provided parameters, the function:
@@ -18,14 +18,14 @@ def _unroll_allgather_recv(recv_buf, chunk_shape, send_buf_shapes, displs=None)
1818 Parameters
1919 ----------
2020 recv_buf: :obj:`cupy.ndarray` or array-like
21- The data buffer returned from nccl_allgather call
21+ The data buffer returned from the allgather call
2222 send_buf_shapes: :obj:`list`
2323 A list of original shapes of each rank's send_buf before any padding.
24- chunk_shape : tuple
25- Shape of each gathered chunk in recv_buf. This must match the shape
26- used to construct the gathered buffer: use the padded send buffer shape
27- when padding is required (e.g., nccl_allgather with padding), or the original send buffer
28- shape when no padding is used.
24+ buffer_chunk_shape : tuple
25+ Shape of each rank's data as stored in ``recv_buf``. This should match
26+ the layout used during allgather: use the padded send buffer shape when
27+ padding is applied (e.g., NCCL), or the original send buffer shape when
28+ no padding is used.
2929 displs : list, optional
3030 Starting offsets in recv_buf for each rank's data, used when chunks have
3131 variable sizes (e.g., mpi_allgather with displacements).
@@ -44,7 +44,8 @@ def _unroll_allgather_recv(recv_buf, chunk_shape, send_buf_shapes, displs=None)
4444 for i in range (ndev )
4545 ]
4646 else :
47- chunk_size = np .prod (chunk_shape )
47+ # extract an individual array from each device
48+ chunk_size = np .prod (buffer_chunk_shape )
4849 chunks = [
4950 recv_buf [i * chunk_size :(i + 1 ) * chunk_size ]
5051 for i in range (ndev )
@@ -53,5 +54,5 @@ def _unroll_allgather_recv(recv_buf, chunk_shape, send_buf_shapes, displs=None)
5354 # in the middle of the flat array and thus the reshape and slicing for each dimension is required
5455 for i in range (ndev ):
5556 slicing = tuple (slice (0 , end ) for end in send_buf_shapes [i ])
56- chunks [i ] = chunks [i ].reshape (chunk_shape )[slicing ]
57+ chunks [i ] = chunks [i ].reshape (buffer_chunk_shape )[slicing ]
5758 return chunks
0 commit comments