Skip to content

Commit 95e32bf

Browse files
committed
fix flake8, docstring on _nccl functions and gpu.rst
1 parent 0c7136f commit 95e32bf

File tree

3 files changed

+22
-13
lines changed

3 files changed

+22
-13
lines changed

docs/source/gpu.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,8 @@ In the following, we provide a list of modules (i.e., operators and solvers) whe
153153
* - :class:`pylops_mpi.optimization.basic.cgls`
154154
- ✅
155155
* - :class:`pylops_mpi.signalprocessing.Fredholm1`
156-
- Planned ⏳
157-
* - ISTA Solver
158-
- Planned ⏳
156+
- ✅
159157
* - Complex Numeric Data Type for NCCL
160-
- Planned ⏳
158+
- ✅
159+
* - ISTA Solver
160+
- Planned ⏳

pylops_mpi/signalprocessing/Fredholm1.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
111111
if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
112112
raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}"
113113
f"Got {x.partition} instead...")
114-
y = DistributedArray(global_shape=self.shape[0],
114+
y = DistributedArray(global_shape=self.shape[0],
115115
base_comm=x.base_comm,
116116
base_comm_nccl=x.base_comm_nccl,
117117
partition=x.partition,
@@ -128,11 +128,6 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
128128
for isl in range(self.nsls[self.rank]):
129129
y1[isl] = ncp.dot(self.G[isl], x[isl])
130130
# gather results
131-
# TODO: _allgather is supposed to be private to DistributedArray
132-
# but so far, we do not take base_comm_nccl as an argument to Op.
133-
# For consistency, y._allgather has to be called here.
134-
# Alternatively, we can also do if-else checking x.base_comm_nccl, but that means
135-
# we have to call function from _nccl.py
136131
y[:] = ncp.vstack(y._allgather(y1)).ravel()
137132
return y
138133

@@ -141,7 +136,7 @@ def _rmatvec(self, x: NDArray) -> NDArray:
141136
if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
142137
raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}"
143138
f"Got {x.partition} instead...")
144-
y = DistributedArray(global_shape=self.shape[1],
139+
y = DistributedArray(global_shape=self.shape[1],
145140
base_comm=x.base_comm,
146141
base_comm_nccl=x.base_comm_nccl,
147142
partition=x.partition,
@@ -176,8 +171,8 @@ def _rmatvec(self, x: NDArray) -> NDArray:
176171
if self.usematmul and isinstance(recv, ncp.ndarray) :
177172
# unrolling
178173
chunk_size = self.ny * self.nz
179-
num_partition = (len(recv)+chunk_size-1)//chunk_size
180-
recv = ncp.vstack([recv[i*chunk_size: (i+1)*chunk_size].reshape(self.nz, self.ny).T for i in range(num_partition)])
174+
num_partition = (len(recv) + chunk_size - 1) // chunk_size
175+
recv = ncp.vstack([recv[i * chunk_size: (i + 1) * chunk_size].reshape(self.nz, self.ny).T for i in range(num_partition)])
181176
else:
182177
recv = ncp.vstack(recv)
183178
y[:] = recv.ravel()

pylops_mpi/utils/_nccl.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,20 @@ class NcclOp(IntEnum):
3939

4040

4141
def _nccl_buf_size(buf, count=None):
42+
""" Get an appropriate buffer size according to the dtype of buf
43+
44+
Parameters
45+
----------
46+
buf : :obj:`cupy.ndarray` or array-like
47+
The data buffer from the local GPU to be sent.
48+
49+
count : :obj:`int`, optional
50+
Number of elements to send from `buf`, if not sending every element in `buf`.
51+
Returns
52+
-------
53+
:obj:`int`
54+
An appropriate number of elements to send from `buf` for NCCL communication.
55+
"""
4256
if buf.dtype in ['complex64', 'complex128']:
4357
return 2 * count if count else 2 * buf.size
4458
else:

0 commit comments

Comments
 (0)