
Commit f80417d

nccl for HStack Op
1 parent: 7ae78a9

2 files changed: +46 −2 lines

pylops_mpi/basicoperators/HStack.py

Lines changed: 3 additions & 1 deletion
@@ -5,6 +5,7 @@
 from pylops.utils import DTypeLike
 
 from pylops_mpi import DistributedArray, MPILinearOperator
+from pylops_mpi.DistributedArray import NcclCommunicatorType
 from .VStack import MPIVStack
 
 
@@ -89,14 +90,15 @@ class MPIHStack(MPILinearOperator):
 
     def __init__(self, ops: Sequence[LinearOperator],
                  base_comm: MPI.Comm = MPI.COMM_WORLD,
+                 base_comm_nccl: NcclCommunicatorType = None,
                  dtype: Optional[DTypeLike] = None):
         self.ops = ops
         nops = [oper.shape[0] for oper in self.ops]
         nops = np.concatenate(base_comm.allgather(nops), axis=0)
         if len(set(nops)) > 1:
             raise ValueError("Operators have different number of rows")
         hops = [oper.H for oper in self.ops]
-        self.HStack = MPIVStack(ops=hops, base_comm=base_comm, dtype=dtype).H
+        self.HStack = MPIVStack(ops=hops, base_comm=base_comm, base_comm_nccl=base_comm_nccl, dtype=dtype).H
         super().__init__(shape=self.HStack.shape, dtype=self.HStack.dtype, base_comm=base_comm)
 
     def _matvec(self, x: DistributedArray) -> DistributedArray:
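With this change, MPIHStack accepts an optional NCCL communicator and simply forwards it to the underlying MPIVStack, whose adjoint implements the horizontal stacking, so the GPU collectives can run over NCCL rather than plain MPI. A minimal construction sketch, assuming a CUDA/CuPy environment and the initialize_nccl_comm helper that the NCCL test suite uses to obtain its module-level nccl_comm (an assumption about the test harness, not shown in this diff):

import cupy as cp
from mpi4py import MPI

import pylops
import pylops_mpi
# assumed helper: the tests below reference a module-level `nccl_comm`
from pylops_mpi.utils._nccl import initialize_nccl_comm

nccl_comm = initialize_nccl_comm()
rank = MPI.COMM_WORLD.Get_rank()
size = MPI.COMM_WORLD.Get_size()

# each rank contributes one column-block of the horizontally stacked operator
A = (rank + 1) * cp.ones((5, 10), dtype="float64")
HOp = pylops_mpi.MPIHStack(ops=[pylops.MatrixMult(A)], base_comm_nccl=nccl_comm)

# model vector scattered across ranks; the forward pass reduces the
# per-rank products over NCCL
x = pylops_mpi.DistributedArray(global_shape=size * 10,
                                base_comm_nccl=nccl_comm,
                                partition=pylops_mpi.Partition.SCATTER,
                                dtype="float64", engine="cupy")
x[:] = cp.ones(10, dtype="float64")
y = HOp @ x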

tests_nccl/test_stack_nccl.py

Lines changed: 43 additions & 1 deletion
@@ -119,4 +119,46 @@ def test_stacked_vstack_nccl(par):
         assert_allclose(y_rmat_mpi.get(), y_rmat_np, rtol=1e-14)
 
 
-# TODO: Test of HStack
+@pytest.mark.mpi(min_size=2)
+@pytest.mark.parametrize("par", [(par1), (par2)])
+def test_hstack(par):
+    """Test the MPIHStack operator with NCCL"""
+    size = MPI.COMM_WORLD.Get_size()
+    rank = MPI.COMM_WORLD.Get_rank()
+    A_gpu = cp.ones(shape=(par['ny'], par['nx'])) + par['imag'] * cp.ones(shape=(par['ny'], par['nx']))
+    Op = pylops.MatrixMult(A=((rank + 1) * A_gpu).astype(par['dtype']))
+    HStack_MPI = pylops_mpi.MPIHStack(ops=[Op, ], base_comm_nccl=nccl_comm)
+
+    # Scattered DistributedArray
+    x = pylops_mpi.DistributedArray(global_shape=size * par['nx'],
+                                    base_comm_nccl=nccl_comm,
+                                    partition=pylops_mpi.Partition.SCATTER,
+                                    dtype=par['dtype'],
+                                    engine="cupy")
+    x[:] = cp.ones(shape=par['nx'], dtype=par['dtype'])
+    x_global = x.asarray()
+
+    # Broadcasted DistributedArray (global_shape == local_shape)
+    y = pylops_mpi.DistributedArray(global_shape=par['ny'],
+                                    base_comm_nccl=nccl_comm,
+                                    partition=pylops_mpi.Partition.BROADCAST,
+                                    dtype=par['dtype'],
+                                    engine="cupy")
+    y[:] = cp.ones(shape=par['ny'], dtype=par['dtype'])
+    y_global = y.asarray()
+
+    x_mat = HStack_MPI @ x
+    y_rmat = HStack_MPI.H @ y
+    assert isinstance(x_mat, pylops_mpi.DistributedArray)
+    assert isinstance(y_rmat, pylops_mpi.DistributedArray)
+
+    x_mat_mpi = x_mat.asarray()
+    y_rmat_mpi = y_rmat.asarray()
+
+    if rank == 0:
+        ops = [pylops.MatrixMult(A=((i + 1) * A_gpu.get()).astype(par['dtype'])) for i in range(size)]
+        HStack = pylops.HStack(ops=ops)
+        x_mat_np = HStack @ x_global.get()
+        y_rmat_np = HStack.H @ y_global.get()
+        assert_allclose(x_mat_mpi.get(), x_mat_np, rtol=1e-14)
+        assert_allclose(y_rmat_mpi.get(), y_rmat_np, rtol=1e-14)
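Design-wise, note that the operator change is confined to the constructor: MPIHStack delegates all communication to MPIVStack through its adjoint (self.HStack = MPIVStack(...).H), so forwarding base_comm_nccl is all that is needed for NCCL support. Running the new test presumably follows the same pattern as the rest of tests_nccl, i.e. at least two ranks with one GPU each, along the lines of mpiexec -n 2 pytest tests_nccl/test_stack_nccl.py.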
