Skip to content

Commit 2f78648

Browse files
committed
fix explicit copy, create tutorials_nccl/
1 parent 36cb8ba commit 2f78648

File tree

3 files changed

+211
-9
lines changed

3 files changed

+211
-9
lines changed

Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,7 @@ run_examples:
6868
# Run tutorials using mpi
6969
run_tutorials:
7070
sh mpi_examples.sh tutorials $(NUM_PROCESSES)
71+
72+
# Run tutorials using nccl
73+
run_tutorials_nccl:
74+
sh mpi_examples.sh tutorials_nccl $(NUM_PROCESSES)

pylops_mpi/optimization/cls_basic.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55

66
from pylops.optimization.basesolver import Solver
7-
from pylops.utils import NDArray, get_module
7+
from pylops.utils import NDArray
88

99
from pylops_mpi import DistributedArray, StackedDistributedArray
1010

@@ -98,10 +98,8 @@ def setup(
9898

9999
if show and self.rank == 0:
100100
if isinstance(x, StackedDistributedArray):
101-
# cupy iscomplexobj fallback to numpy iscomplexobject if passing the list
102-
# so it has to be made asarray first
103-
ncp = get_module(x.distarrays[0].engine)
104-
self._print_setup(ncp.iscomplexobj(ncp.asarray([x1.local_array for x1 in x.distarrays])))
101+
is_complex = any(np.iscomplexobj(x1.local_array) for x1 in x.distarrays)
102+
self._print_setup(is_complex)
105103
else:
106104
self._print_setup(np.iscomplexobj(x.local_array))
107105
return x
@@ -357,10 +355,8 @@ def setup(self,
357355
# print setup
358356
if show and self.rank == 0:
359357
if isinstance(x, StackedDistributedArray):
360-
# cupy iscomplexobj fallback to numpy iscomplexobject if passing the list
361-
# so it has to be made asarray first
362-
ncp = get_module(x.distarrays[0].engine)
363-
self._print_setup(ncp.iscomplexobj(ncp.asarray([x1.local_array for x1 in x.distarrays])))
358+
is_complex = any(np.iscomplexobj(x1.local_array) for x1 in x.distarrays)
359+
self._print_setup(is_complex)
364360
else:
365361
self._print_setup(np.iscomplexobj(x.local_array))
366362
return x

tutorials_nccl/poststack_nccl.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
r"""
2+
Post Stack Inversion - 3D with NCCL
3+
============================================
4+
This tutorial is an extension of the :ref:`sphx_glr_tutorials_poststack.py` tutorial where PyLops-MPI is run in multi-GPU setting with GPUs communicating via NCCL.
5+
"""
6+
7+
import numpy as np
8+
import cupy as cp
9+
from scipy.signal import filtfilt
10+
from matplotlib import pyplot as plt
11+
from mpi4py import MPI
12+
13+
from pylops.utils.wavelets import ricker
14+
from pylops.basicoperators import Transpose
15+
from pylops.avo.poststack import PoststackLinearModelling
16+
17+
import pylops_mpi
18+
import pylops_mpi.utils
19+
import pylops_mpi.utils._nccl
20+
21+
###############################################################################
22+
# NCCL communication can be easily initialized with
23+
# :py:func:`pylops_mpi.utils._nccl.initialize_nccl_comm` operator.
24+
# One can think of this as GPU-counterpart of :code:`MPI.COMM_WORLD`
25+
26+
plt.close("all")
27+
nccl_comm = pylops_mpi.utils._nccl.initialize_nccl_comm()
28+
rank = MPI.COMM_WORLD.Get_rank()
29+
30+
###############################################################################
31+
# Let's start by defining all the parameters required by the
32+
# :py:func:`pylops.avo.poststack.PoststackLinearModelling` operator.
33+
# Note that this section is exactly the same as the one in the MPI example as we will keep using MPI for transfering metadata (i.e., shapes, dims, etc.)
34+
35+
# Model
36+
model = np.load("../testdata/avo/poststack_model.npz")
37+
x, z, m = model['x'][::3], model['z'], np.log(model['model'])[:, ::3]
38+
39+
# Making m a 3D model
40+
ny_i = 20 # size of model in y direction for rank i
41+
y = np.arange(ny_i)
42+
m3d_i = np.tile(m[:, :, np.newaxis], (1, 1, ny_i)).transpose((2, 1, 0))
43+
ny_i, nx, nz = m3d_i.shape
44+
45+
# Size of y at all ranks
46+
ny = MPI.COMM_WORLD.allreduce(ny_i)
47+
48+
# Smooth model
49+
nsmoothy, nsmoothx, nsmoothz = 5, 30, 20
50+
mback3d_i = filtfilt(np.ones(nsmoothy) / float(nsmoothy), 1, m3d_i, axis=0)
51+
mback3d_i = filtfilt(np.ones(nsmoothx) / float(nsmoothx), 1, mback3d_i, axis=1)
52+
mback3d_i = filtfilt(np.ones(nsmoothz) / float(nsmoothz), 1, mback3d_i, axis=2)
53+
54+
# Wavelet
55+
dt = 0.004
56+
t0 = np.arange(nz) * dt
57+
ntwav = 41
58+
wav = ricker(t0[:ntwav // 2 + 1], 15)[0]
59+
60+
# Collecting all the m3d and mback3d at all ranks
61+
m3d = np.concatenate(MPI.COMM_WORLD.allgather(m3d_i))
62+
mback3d = np.concatenate(MPI.COMM_WORLD.allgather(mback3d_i))
63+
64+
###############################################################################
65+
# We are now ready to initialize various :py:class:`pylops_mpi.DistributedArray` objects.
66+
# Compared to the MPI tutorial, we need to make sure that we pass :code:`base_comm_nccl = nccl_comm` and set CuPy as the engine.
67+
68+
m3d_dist = pylops_mpi.DistributedArray(global_shape=ny * nx * nz, base_comm_nccl=nccl_comm, engine="cupy")
69+
m3d_dist[:] = cp.asarray(m3d_i.flatten())
70+
71+
# Do the same thing for smooth model
72+
mback3d_dist = pylops_mpi.DistributedArray(global_shape=ny * nx * nz, base_comm_nccl=nccl_comm, engine="cupy")
73+
mback3d_dist[:] = cp.asarray(mback3d_i.flatten())
74+
75+
###############################################################################
76+
# For PostStackLinearModelling, there is no change needed to have it run with NCCL.
77+
# This PyLops operator has GPU-support (https://pylops.readthedocs.io/en/stable/gpu.html)
78+
# so it can run with DistributedArray whose engine is Cupy
79+
80+
PPop = PoststackLinearModelling(wav, nt0=nz, spatdims=(ny_i, nx))
81+
Top = Transpose((ny_i, nx, nz), (2, 0, 1))
82+
BDiag = pylops_mpi.basicoperators.MPIBlockDiag(ops=[Top.H @ PPop @ Top, ])
83+
84+
###############################################################################
85+
# This computation will be done in GPU. The call :code:`asarray()` triggers the NCCL communication (gather result from each GPU).
86+
# But array :code:`d` and :code:`d_0` still live in GPU memory
87+
88+
d_dist = BDiag @ m3d_dist
89+
d_local = d_dist.local_array.reshape((ny_i, nx, nz))
90+
d = d_dist.asarray().reshape((ny, nx, nz))
91+
d_0_dist = BDiag @ mback3d_dist
92+
d_0 = d_dist.asarray().reshape((ny, nx, nz))
93+
94+
###############################################################################
95+
# Inversion using CGLS solver - There is no code change to have run on NCCL (it handles though MPI operator and DistributedArray)
96+
# In this particular case, the local computation will be done in GPU. Collective communication calls
97+
# will be carried through NCCL GPU-to-GPU.
98+
99+
# Inversion using CGLS solver
100+
minv3d_iter_dist = pylops_mpi.optimization.basic.cgls(BDiag, d_dist, x0=mback3d_dist, niter=10, show=True)[0]
101+
minv3d_iter = minv3d_iter_dist.asarray().reshape((ny, nx, nz))
102+
103+
###############################################################################
104+
105+
# Regularized inversion with normal equations
106+
epsR = 1e2
107+
LapOp = pylops_mpi.MPILaplacian(dims=(ny, nx, nz), axes=(0, 1, 2), weights=(1, 1, 1),
108+
sampling=(1, 1, 1), dtype=BDiag.dtype)
109+
NormEqOp = BDiag.H @ BDiag + epsR * LapOp.H @ LapOp
110+
dnorm_dist = BDiag.H @ d_dist
111+
minv3d_ne_dist = pylops_mpi.optimization.basic.cg(NormEqOp, dnorm_dist, x0=mback3d_dist, niter=10, show=True)[0]
112+
minv3d_ne = minv3d_ne_dist.asarray().reshape((ny, nx, nz))
113+
114+
###############################################################################
115+
116+
# Regularized inversion with regularized equations
117+
StackOp = pylops_mpi.MPIStackedVStack([BDiag, np.sqrt(epsR) * LapOp])
118+
d0_dist = pylops_mpi.DistributedArray(global_shape=ny * nx * nz, base_comm_nccl=nccl_comm, engine="cupy")
119+
d0_dist[:] = 0.
120+
dstack_dist = pylops_mpi.StackedDistributedArray([d_dist, d0_dist])
121+
122+
dnorm_dist = BDiag.H @ d_dist
123+
minv3d_reg_dist = pylops_mpi.optimization.basic.cgls(StackOp, dstack_dist, x0=mback3d_dist, niter=10, show=True)[0]
124+
minv3d_reg = minv3d_reg_dist.asarray().reshape((ny, nx, nz))
125+
126+
###############################################################################
127+
# To plot the inversion results, the array must be copied back to cpu via :code:`get()`
128+
129+
if rank == 0:
130+
# Check the distributed implementation gives the same result
131+
# as the one running only on rank0
132+
PPop0 = PoststackLinearModelling(wav, nt0=nz, spatdims=(ny, nx))
133+
d0 = (PPop0 @ m3d.transpose(2, 0, 1)).transpose(1, 2, 0)
134+
d0_0 = (PPop0 @ m3d.transpose(2, 0, 1)).transpose(1, 2, 0)
135+
136+
# Check the two distributed implementations give the same modelling results
137+
print('Distr == Local', np.allclose(d, d0))
138+
print('Smooth Distr == Local', np.allclose(d_0, d0_0))
139+
140+
# Visualize
141+
fig, axs = plt.subplots(nrows=6, ncols=3, figsize=(9, 14), constrained_layout=True)
142+
axs[0][0].imshow(m3d[5, :, :].T, cmap="gist_rainbow", vmin=m.min(), vmax=m.max())
143+
axs[0][0].set_title("Model x-z")
144+
axs[0][0].axis("tight")
145+
axs[0][1].imshow(m3d[:, 200, :].T, cmap="gist_rainbow", vmin=m.min(), vmax=m.max())
146+
axs[0][1].set_title("Model y-z")
147+
axs[0][1].axis("tight")
148+
axs[0][2].imshow(m3d[:, :, 220].T, cmap="gist_rainbow", vmin=m.min(), vmax=m.max())
149+
axs[0][2].set_title("Model y-z")
150+
axs[0][2].axis("tight")
151+
152+
axs[1][0].imshow(mback3d[5, :, :].T, cmap="gist_rainbow", vmin=m.min(), vmax=m.max())
153+
axs[1][0].set_title("Smooth Model x-z")
154+
axs[1][0].axis("tight")
155+
axs[1][1].imshow(mback3d[:, 200, :].T, cmap="gist_rainbow", vmin=m.min(), vmax=m.max())
156+
axs[1][1].set_title("Smooth Model y-z")
157+
axs[1][1].axis("tight")
158+
axs[1][2].imshow(mback3d[:, :, 220].T, cmap="gist_rainbow", vmin=m.min(), vmax=m.max())
159+
axs[1][2].set_title("Smooth Model y-z")
160+
axs[1][2].axis("tight")
161+
162+
axs[2][0].imshow(d[5, :, :].T.get(), cmap="gray", vmin=-1, vmax=1)
163+
axs[2][0].set_title("Data x-z")
164+
axs[2][0].axis("tight")
165+
axs[2][1].imshow(d[:, 200, :].T.get(), cmap='gray', vmin=-1, vmax=1)
166+
axs[2][1].set_title('Data y-z')
167+
axs[2][1].axis('tight')
168+
axs[2][2].imshow(d[:, :, 220].T.get(), cmap='gray', vmin=-1, vmax=1)
169+
axs[2][2].set_title('Data x-y')
170+
axs[2][2].axis('tight')
171+
172+
axs[3][0].imshow(minv3d_iter[5, :, :].T.get(), cmap="gist_rainbow", vmin=m.min(), vmax=m.max())
173+
axs[3][0].set_title("Inverted Model iter x-z")
174+
axs[3][0].axis("tight")
175+
axs[3][1].imshow(minv3d_iter[:, 200, :].T.get(), cmap='gist_rainbow', vmin=m.min(), vmax=m.max())
176+
axs[3][1].set_title('Inverted Model iter y-z')
177+
axs[3][1].axis('tight')
178+
axs[3][2].imshow(minv3d_iter[:, :, 220].T.get(), cmap='gist_rainbow', vmin=m.min(), vmax=m.max())
179+
axs[3][2].set_title('Inverted Model iter x-y')
180+
axs[3][2].axis('tight')
181+
182+
axs[4][0].imshow(minv3d_ne[5, :, :].T.get(), cmap="gist_rainbow", vmin=m.min(), vmax=m.max())
183+
axs[4][0].set_title("Normal Equations Inverted Model iter x-z")
184+
axs[4][0].axis("tight")
185+
axs[4][1].imshow(minv3d_ne[:, 200, :].T.get(), cmap='gist_rainbow', vmin=m.min(), vmax=m.max())
186+
axs[4][1].set_title('Normal Equations Inverted Model iter y-z')
187+
axs[4][1].axis('tight')
188+
axs[4][2].imshow(minv3d_ne[:, :, 220].T.get(), cmap='gist_rainbow', vmin=m.min(), vmax=m.max())
189+
axs[4][2].set_title('Normal Equations Inverted Model iter x-y')
190+
axs[4][2].axis('tight')
191+
192+
axs[5][0].imshow(minv3d_reg[5, :, :].T.get(), cmap="gist_rainbow", vmin=m.min(), vmax=m.max())
193+
axs[5][0].set_title("Regularized Inverted Model iter x-z")
194+
axs[5][0].axis("tight")
195+
axs[5][1].imshow(minv3d_reg[:, 200, :].T.get(), cmap='gist_rainbow', vmin=m.min(), vmax=m.max())
196+
axs[5][1].set_title('Regularized Inverted Model iter y-z')
197+
axs[5][1].axis('tight')
198+
axs[5][2].imshow(minv3d_reg[:, :, 220].T.get(), cmap='gist_rainbow', vmin=m.min(), vmax=m.max())
199+
axs[5][2].set_title('Regularized Inverted Model iter x-y')
200+
axs[5][2].axis('tight')
201+
202+
plt.savefig("./poststack_inv_nccl.png")

0 commit comments

Comments
 (0)