Skip to content

Commit fe05882

Browse files
committed
Change param name to buffer_chunk_shape
1 parent 75abc1e commit fe05882

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

pylops_mpi/utils/_common.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66

77

8-
def _unroll_allgather_recv(recv_buf, chunk_shape, send_buf_shapes, displs=None) -> list:
8+
def _unroll_allgather_recv(recv_buf, buffer_chunk_shape, send_buf_shapes, displs=None) -> list:
99
r"""Unroll recv_buf after Buffered Allgather (MPI and NCCL)
1010
1111
Depending on the provided parameters, the function:
@@ -18,14 +18,14 @@ def _unroll_allgather_recv(recv_buf, chunk_shape, send_buf_shapes, displs=None)
1818
Parameters
1919
----------
2020
recv_buf: :obj:`cupy.ndarray` or array-like
21-
The data buffer returned from nccl_allgather call
21+
The data buffer returned from the allgather call
2222
send_buf_shapes: :obj:`list`
2323
A list of original shapes of each rank's send_buf before any padding.
24-
chunk_shape : tuple
25-
Shape of each gathered chunk in recv_buf. This must match the shape
26-
used to construct the gathered buffer: use the padded send buffer shape
27-
when padding is required (e.g., nccl_allgather with padding), or the original send buffer
28-
shape when no padding is used.
24+
buffer_chunk_shape : tuple
25+
Shape of each rank’s data as stored in ``recv_buf``. This should match
26+
the layout used during allgather: use the padded send buffer shape when
27+
padding is applied (e.g., NCCL), or the original send buffer shape when
28+
no padding is used.
2929
displs : list, optional
3030
Starting offsets in recv_buf for each rank's data, used when chunks have
3131
variable sizes (e.g., mpi_allgather with displacements).
@@ -44,7 +44,8 @@ def _unroll_allgather_recv(recv_buf, chunk_shape, send_buf_shapes, displs=None)
4444
for i in range(ndev)
4545
]
4646
else:
47-
chunk_size = np.prod(chunk_shape)
47+
# extract an individual array from each device
48+
chunk_size = np.prod(buffer_chunk_shape)
4849
chunks = [
4950
recv_buf[i * chunk_size:(i + 1) * chunk_size]
5051
for i in range(ndev)
@@ -53,5 +54,5 @@ def _unroll_allgather_recv(recv_buf, chunk_shape, send_buf_shapes, displs=None)
5354
# in the middle of the flat array and thus the reshape and slicing for each dimension is required
5455
for i in range(ndev):
5556
slicing = tuple(slice(0, end) for end in send_buf_shapes[i])
56-
chunks[i] = chunks[i].reshape(chunk_shape)[slicing]
57+
chunks[i] = chunks[i].reshape(buffer_chunk_shape)[slicing]
5758
return chunks

0 commit comments

Comments
 (0)