
Commit 006ed1d

committed: tutorials for LSM and MDD using NCCL
1 parent 19e873a commit 006ed1d

File tree: 2 files changed, +445 −0 lines changed


tutorials_nccl/lsm_nccl.py

Lines changed: 200 additions & 0 deletions
@@ -0,0 +1,200 @@
r"""
Least-squares Migration with NCCL
=================================
This tutorial is an extension of the :ref:`sphx_glr_tutorials_lsm.py`
tutorial, where PyLops-MPI is run in a multi-GPU setting with the GPUs
communicating via NCCL.
"""

import warnings
warnings.filterwarnings('ignore')

import numpy as np
import cupy as cp
from matplotlib import pyplot as plt
from mpi4py import MPI

from pylops.utils.wavelets import ricker
from pylops.waveeqprocessing.lsm import LSM

import pylops_mpi

###############################################################################
# NCCL communication can be easily initialized with the
# :py:func:`pylops_mpi.utils._nccl.initialize_nccl_comm` routine.
# One can think of the resulting communicator as the GPU counterpart of
# :code:`MPI.COMM_WORLD`.

np.random.seed(42)
plt.close("all")
nccl_comm = pylops_mpi.utils._nccl.initialize_nccl_comm()
rank = MPI.COMM_WORLD.Get_rank()
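
###############################################################################
# As a usage note (an addition to this tutorial): NCCL-based examples are
# typically launched with one MPI rank per GPU, e.g. with a standard MPI
# launcher such as
#
#     mpirun -n 2 python lsm_nccl.py
#
# The number of ranks and the launcher flags shown here are only indicative
# and depend on your system.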

###############################################################################
# Let's start by defining all the parameters required by the
# :py:class:`pylops.waveeqprocessing.LSM` operator.
# Note that this section is exactly the same as the one in the MPI example,
# as we keep using MPI to transfer metadata (i.e., shapes, dims, etc.)

# Velocity Model
nx, nz = 81, 60
dx, dz = 4, 4
x, z = np.arange(nx) * dx, np.arange(nz) * dz
v0 = 1000  # initial velocity
kv = 0.0  # gradient
vel = np.outer(np.ones(nx), v0 + kv * z)

# Reflectivity Model
refl = np.zeros((nx, nz))
refl[:, 30] = -1
refl[:, 50] = 0.5

# Receivers
nr = 11
rx = np.linspace(10 * dx, (nx - 10) * dx, nr)
rz = 20 * np.ones(nr)
recs = np.vstack((rx, rz))

# Sources
ns = 10
# Total number of sources at all ranks
nstot = MPI.COMM_WORLD.allreduce(ns, op=MPI.SUM)
sxtot = np.linspace(dx * 10, (nx - 10) * dx, nstot)
sx = sxtot[rank * ns: (rank + 1) * ns]
sztot = 10 * np.ones(nstot)
sz = 10 * np.ones(ns)
sources = np.vstack((sx, sz))
sources_tot = np.vstack((sxtot, sztot))

if rank == 0:
    plt.figure(figsize=(10, 5))
    im = plt.imshow(vel.T, cmap="summer", extent=(x[0], x[-1], z[-1], z[0]))
    plt.scatter(recs[0], recs[1], marker="v", s=150, c="b", edgecolors="k")
    plt.scatter(sources_tot[0], sources_tot[1], marker="*", s=150, c="r", edgecolors="k")
    cb = plt.colorbar(im)
    cb.set_label("[m/s]")
    plt.axis("tight")
    plt.xlabel("x [m]"), plt.ylabel("z [m]")
    plt.title("Velocity")
    plt.xlim(x[0], x[-1])
    plt.tight_layout()

    plt.figure(figsize=(10, 5))
    im = plt.imshow(refl.T, cmap="gray", extent=(x[0], x[-1], z[-1], z[0]))
    plt.scatter(recs[0], recs[1], marker="v", s=150, c="b", edgecolors="k")
    plt.scatter(sources_tot[0], sources_tot[1], marker="*", s=150, c="r", edgecolors="k")
    plt.colorbar(im)
    plt.axis("tight")
    plt.xlabel("x [m]"), plt.ylabel("z [m]")
    plt.title("Reflectivity")
    plt.xlim(x[0], x[-1])
    plt.tight_layout()

###############################################################################
# We create a :py:class:`pylops.waveeqprocessing.LSM` at each rank and then push
# them into a :py:class:`pylops_mpi.basicoperators.MPIVStack` to perform a
# matrix-vector product with the broadcasted reflectivity at every location
# on the subsurface. Also, we must pass `nccl_comm` to the
# :py:class:`pylops_mpi.DistributedArray` holding the reflectivity (via the
# `base_comm_nccl` argument) in order to use NCCL for communications.
# Note that we allocate some arrays (wav, lsm.Demop.trav_srcs, and
# lsm.Demop.trav_recs) on the GPU upfront: since we want a fair performance
# comparison, we avoid having LSM copy these arrays internally.

# Wavelet
nt = 651
dt = 0.004
t = np.arange(nt) * dt
wav, wavt, wavc = ricker(t[:41], f0=20)

lsm = LSM(
    z,
    x,
    t,
    sources,
    recs,
    v0,
    cp.asarray(wav.astype(np.float32)),
    wavc,
    mode="analytic",
    engine="cuda",
    dtype=np.float32
)
lsm.Demop.trav_srcs = cp.asarray(lsm.Demop.trav_srcs.astype(np.float32))
lsm.Demop.trav_recs = cp.asarray(lsm.Demop.trav_recs.astype(np.float32))

VStack = pylops_mpi.MPIVStack(ops=[lsm.Demop, ])
refl_dist = pylops_mpi.DistributedArray(global_shape=nx * nz,
                                        partition=pylops_mpi.Partition.BROADCAST,
                                        base_comm_nccl=nccl_comm,
                                        engine="cupy")
refl_dist[:] = cp.asarray(refl.flatten())
d_dist = VStack @ refl_dist
d = d_dist.asarray().reshape((nstot, nr, nt))

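###############################################################################
# As a small optional check (an addition to the original tutorial, relying on
# the `local_array` attribute of :py:class:`pylops_mpi.DistributedArray`),
# we can confirm that every rank holds the entire broadcasted reflectivity on
# its own GPU:

# each rank owns a CuPy copy of the full reflectivity vector
assert isinstance(refl_dist.local_array, cp.ndarray)
assert refl_dist.local_array.size == nx * nz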

###############################################################################
# We now calculate the adjoint and model the data using the adjoint
# reflectivity as input.
madj_dist = VStack.H @ d_dist
madj = madj_dist.asarray().reshape((nx, nz))
d_adj_dist = VStack @ madj_dist
d_adj = d_adj_dist.asarray().reshape((nstot, nr, nt))

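###############################################################################
# As an optional extra (not part of the original tutorial), the adjoint can be
# verified with a quick dot test built from the vectors computed above: up to
# floating-point error, <d, A m_adj> should equal <A^H d, m_adj> = <m_adj, m_adj>.

lhs = cp.dot(d_dist.asarray(), d_adj_dist.asarray())    # <d, A m_adj>
rhs = cp.dot(madj_dist.asarray(), madj_dist.asarray())  # <m_adj, m_adj>
if rank == 0:
    print(f"Dot test: {float(lhs):.6e} ~ {float(rhs):.6e}")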

###############################################################################
# We calculate the inverse using the :py:func:`pylops_mpi.optimization.basic.cgls`
# solver. Here, we pass `nccl_comm` to the starting guess `x0` (again via the
# `base_comm_nccl` argument) to use NCCL as the communicator.
# In this particular case, the local computation is performed on the GPU,
# while the collective communication calls are carried out GPU-to-GPU through
# NCCL.

# Inverse
# Initializing x0 to zeroes
x0 = pylops_mpi.DistributedArray(VStack.shape[1],
                                 partition=pylops_mpi.Partition.BROADCAST,
                                 base_comm_nccl=nccl_comm,
                                 engine="cupy")
x0[:] = 0
minv_dist = pylops_mpi.cgls(VStack, d_dist, x0=x0, niter=100, show=True)[0]
minv = minv_dist.asarray().reshape((nx, nz))
d_inv_dist = VStack @ minv_dist
d_inv = d_inv_dist.asarray().reshape(nstot, nr, nt)

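###############################################################################
# As another optional addition (not in the original tutorial), we can compare
# how well the adjoint and the inverted reflectivities explain the data by
# looking at the data residuals (all arrays here are CuPy arrays):

if rank == 0:
    res_adj = float(cp.linalg.norm(d - d_adj))
    res_inv = float(cp.linalg.norm(d - d_inv))
    print(f"Adjoint data residual:  {res_adj:.3e}")
    print(f"Inverted data residual: {res_inv:.3e}")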

##############################################################################
# Finally, we visualize the results. Note that the arrays must be copied back
# to the CPU by calling the :code:`get()` method on the CuPy arrays.

if rank == 0:
    # Visualize
    fig1, axs = plt.subplots(1, 3, figsize=(10, 3))
    axs[0].imshow(refl.T, cmap="gray", vmin=-1, vmax=1)
    axs[0].axis("tight")
    axs[0].set_title(r"$m$")
    axs[1].imshow(madj.T.get(), cmap="gray", vmin=-madj.max(), vmax=madj.max())
    axs[1].set_title(r"$m_{adj}$")
    axs[1].axis("tight")
    axs[2].imshow(minv.T.get(), cmap="gray", vmin=-1, vmax=1)
    axs[2].axis("tight")
    axs[2].set_title(r"$m_{inv}$")
    plt.tight_layout()
    fig1.savefig("model.png")

    fig2, axs = plt.subplots(1, 3, figsize=(10, 3))
    axs[0].imshow(d[0, :, :300].T.get(), cmap="gray", vmin=-d.max(), vmax=d.max())
    axs[0].set_title(r"$d$")
    axs[0].axis("tight")
    axs[1].imshow(d_adj[0, :, :300].T.get(), cmap="gray", vmin=-d_adj.max(), vmax=d_adj.max())
    axs[1].set_title(r"$d_{adj}$")
    axs[1].axis("tight")
    axs[2].imshow(d_inv[0, :, :300].T.get(), cmap="gray", vmin=-d.max(), vmax=d.max())
    axs[2].set_title(r"$d_{inv}$")
    axs[2].axis("tight")
    fig2.savefig("data1.png")

    fig3, axs = plt.subplots(1, 3, figsize=(10, 3))
    axs[0].imshow(d[nstot // 2, :, :300].T.get(), cmap="gray", vmin=-d.max(), vmax=d.max())
    axs[0].set_title(r"$d$")
    axs[0].axis("tight")
    axs[1].imshow(d_adj[nstot // 2, :, :300].T.get(), cmap="gray", vmin=-d_adj.max(), vmax=d_adj.max())
    axs[1].set_title(r"$d_{adj}$")
    axs[1].axis("tight")
    axs[2].imshow(d_inv[nstot // 2, :, :300].T.get(), cmap="gray", vmin=-d.max(), vmax=d.max())
    axs[2].set_title(r"$d_{inv}$")
    axs[2].axis("tight")
    plt.tight_layout()
    fig3.savefig("data2.png")
