Skip to content

Commit 3609220

Browse files
committed
Fredholm NCCL, Broken Fredholm MPI to fix
1 parent 501e581 commit 3609220

File tree

2 files changed

+184
-4
lines changed

2 files changed

+184
-4
lines changed

pylops_mpi/signalprocessing/Fredholm1.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,10 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
111111
if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
112112
raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}"
113113
f"Got {x.partition} instead...")
114-
y = DistributedArray(global_shape=self.shape[0], partition=x.partition,
114+
y = DistributedArray(global_shape=self.shape[0],
115+
base_comm=x.base_comm,
116+
base_comm_nccl=x.base_comm_nccl,
117+
partition=x.partition,
115118
engine=x.engine, dtype=self.dtype)
116119
x = x.local_array.reshape(self.dims).squeeze()
117120
x = x[self.islstart[self.rank]:self.islend[self.rank]]
@@ -125,15 +128,25 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
125128
for isl in range(self.nsls[self.rank]):
126129
y1[isl] = ncp.dot(self.G[isl], x[isl])
127130
# gather results
128-
y[:] = np.vstack(self.base_comm.allgather(y1)).ravel()
131+
# TODO: _allgather is supposed to be private to DistributedArray
132+
# but so far, we do not take base_comm_nccl as an argument to Op.
133+
# For consistency, y._allgather has to be call here.
134+
# we can do if else for x.base_comm_nccl, but that means
135+
# we have to call function from _nccl.py
136+
# y[:] = np.vstack(y._allgather(y1)).ravel()
137+
recv = y._allgather(y1)
138+
y[:] = recv.ravel()
129139
return y
130140

131141
def _rmatvec(self, x: NDArray) -> NDArray:
132142
ncp = get_module(x.engine)
133143
if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
134144
raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}"
135145
f"Got {x.partition} instead...")
136-
y = DistributedArray(global_shape=self.shape[1], partition=x.partition,
146+
y = DistributedArray(global_shape=self.shape[1],
147+
base_comm=x.base_comm,
148+
base_comm_nccl=x.base_comm_nccl,
149+
partition=x.partition,
137150
engine=x.engine, dtype=self.dtype)
138151
x = x.local_array.reshape(self.dimsd).squeeze()
139152
x = x[self.islstart[self.rank]:self.islend[self.rank]]
@@ -159,5 +172,11 @@ def _rmatvec(self, x: NDArray) -> NDArray:
159172
y1[isl] = ncp.dot(x[isl].T.conj(), self.G[isl]).T.conj()
160173

161174
# gather results
162-
y[:] = np.vstack(self.base_comm.allgather(y1)).ravel()
175+
recv = y._allgather(y1)
176+
if self.usematmul:
177+
# unrolling like DistributedArray asarray()
178+
chunk_size = self.ny * self.nz
179+
recv = ncp.vstack([recv[i*chunk_size: (i+1)*chunk_size].reshape(self.nz, self.ny).T for i in range((len(recv)+chunk_size-1)//chunk_size)])
180+
181+
y[:] = recv.ravel()
163182
return y

tests_nccl/test_fredholm_nccl.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
"""Test the MPIFredholm1 class
2+
Designed to run with n GPUs (with 1 MPI process per GPU)
3+
$ mpiexec -n 3 pytest test_fredholm_nccl.py --with-mpi
4+
5+
This file employs the same test sets as test_fredholm under NCCL environment
6+
"""
7+
import numpy as np
8+
import cupy as cp
9+
from numpy.testing import assert_allclose
10+
from mpi4py import MPI
11+
import pytest
12+
13+
import pylops
14+
import pylops_mpi
15+
16+
from pylops_mpi import DistributedArray
17+
from pylops_mpi.DistributedArray import local_split, Partition
18+
from pylops_mpi.signalprocessing import MPIFredholm1
19+
from pylops_mpi.utils.dottest import dottest
20+
from pylops_mpi.utils._nccl import initialize_nccl_comm
21+
22+
np.random.seed(42)
23+
rank = MPI.COMM_WORLD.Get_rank()
24+
size = MPI.COMM_WORLD.Get_size()
25+
26+
nccl_comm = initialize_nccl_comm()
27+
28+
par1 = {
29+
"nsl": 12,
30+
"ny": 6,
31+
"nx": 4,
32+
"nz": 5,
33+
"usematmul": False,
34+
"saveGt": True,
35+
"imag": 0,
36+
"dtype": "float32",
37+
} # real, saved Gt
38+
par2 = {
39+
"nsl": 12,
40+
"ny": 6,
41+
"nx": 4,
42+
"nz": 5,
43+
"usematmul": True,
44+
"saveGt": False,
45+
"imag": 0,
46+
"dtype": "float32",
47+
} # real, unsaved Gt
48+
par3 = {
49+
"nsl": 12,
50+
"ny": 6,
51+
"nx": 4,
52+
"nz": 5,
53+
"usematmul": False,
54+
"saveGt": True,
55+
"imag": 1j,
56+
"dtype": "complex64",
57+
} # complex, saved Gt
58+
par4 = {
59+
"nsl": 12,
60+
"ny": 6,
61+
"nx": 4,
62+
"nz": 5,
63+
"saveGt": False,
64+
"usematmul": False,
65+
"imag": 1j,
66+
"dtype": "complex64",
67+
} # complex, unsaved Gt
68+
par5 = {
69+
"nsl": 12,
70+
"ny": 6,
71+
"nx": 4,
72+
"nz": 1,
73+
"usematmul": True,
74+
"saveGt": True,
75+
"imag": 0,
76+
"dtype": "float32",
77+
} # real, saved Gt, nz=1
78+
par6 = {
79+
"nsl": 12,
80+
"ny": 6,
81+
"nx": 4,
82+
"nz": 1,
83+
"usematmul": True,
84+
"saveGt": False,
85+
"imag": 0,
86+
"dtype": "float32",
87+
} # real, unsaved Gt, nz=1
88+
89+
90+
"""Seems to stop next tests from running
91+
@pytest.mark.mpi(min_size=2)
92+
@pytest.mark.parametrize("par", [(par1)])
93+
def test_Gsize1(par):
94+
#Check error is raised if G has size 1 in any of the ranks
95+
with pytest.raises(NotImplementedError):
96+
_ = MPIFredholm1(
97+
np.ones((1, par["nx"], par["ny"])),
98+
nz=par["nz"],
99+
saveGt=par["saveGt"],
100+
usematmul=par["usematmul"],
101+
dtype=par["dtype"],
102+
)
103+
"""
104+
105+
106+
@pytest.mark.mpi(min_size=2)
107+
@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)])
108+
def test_Fredholm1_nccl(par):
109+
"""Fredholm1 operator"""
110+
np.random.seed(42)
111+
112+
_F = cp.arange(par["nsl"] * par["nx"] * par["ny"]).reshape(
113+
par["nsl"], par["nx"], par["ny"]
114+
).astype(par["dtype"])
115+
F = _F - par["imag"] * _F
116+
117+
# split across ranks
118+
nsl_rank = local_split((par["nsl"], ), MPI.COMM_WORLD, Partition.SCATTER, 0)
119+
nsl_ranks = np.concatenate(MPI.COMM_WORLD.allgather(nsl_rank))
120+
islin_rank = np.insert(np.cumsum(nsl_ranks)[:-1] , 0, 0)[rank]
121+
islend_rank = np.cumsum(nsl_ranks)[rank]
122+
Frank = F[islin_rank:islend_rank]
123+
124+
Fop_MPI = MPIFredholm1(
125+
Frank,
126+
nz=par["nz"],
127+
saveGt=par["saveGt"],
128+
usematmul=par["usematmul"],
129+
dtype=par["dtype"],
130+
)
131+
132+
x = DistributedArray(global_shape=par["nsl"] * par["ny"] * par["nz"],
133+
base_comm_nccl=nccl_comm,
134+
partition=pylops_mpi.Partition.BROADCAST,
135+
dtype=par["dtype"],
136+
engine="cupy")
137+
x[:] = 1. + par["imag"] * 1.
138+
x_global = x.asarray()
139+
# Forward
140+
y_dist = Fop_MPI @ x
141+
y = y_dist.asarray()
142+
# Adjoint
143+
y_adj_dist = Fop_MPI.H @ y_dist
144+
y_adj = y_adj_dist.asarray()
145+
# Dot test
146+
dottest(Fop_MPI, x, y_dist, par["nsl"] * par["nx"] * par["nz"], par["nsl"] * par["ny"] * par["nz"])
147+
148+
if rank == 0:
149+
Fop = pylops.signalprocessing.Fredholm1(
150+
F.get(),
151+
nz=par["nz"],
152+
saveGt=par["saveGt"],
153+
usematmul=par["usematmul"],
154+
dtype=par["dtype"],
155+
)
156+
157+
assert Fop_MPI.shape == Fop.shape
158+
y_np = Fop @ x_global.get()
159+
y_adj_np = Fop.H @ y_np
160+
assert_allclose(y.get(), y_np, rtol=1e-14)
161+
assert_allclose(y_adj.get(), y_adj_np, rtol=1e-14)

0 commit comments

Comments
 (0)