|  | 
|  | 1 | +r""" | 
|  | 2 | +Least-squares Migration with NCCL | 
|  | 3 | +================================= | 
|  | 4 | +This tutorial is an extension of the :ref:`sphx_glr_tutorials_lsm.py` | 
|  | 5 | +tutorial where PyLops-MPI is run in multi-GPU setting with GPUs communicating | 
|  | 6 | +via NCCL. | 
|  | 7 | +""" | 
|  | 8 | + | 
|  | 9 | +import warnings | 
|  | 10 | +warnings.filterwarnings('ignore') | 
|  | 11 | + | 
|  | 12 | +import numpy as np | 
|  | 13 | +import cupy as cp | 
|  | 14 | +from matplotlib import pyplot as plt | 
|  | 15 | +from mpi4py import MPI | 
|  | 16 | + | 
|  | 17 | +from pylops.utils.wavelets import ricker | 
|  | 18 | +from pylops.waveeqprocessing.lsm import LSM | 
|  | 19 | + | 
|  | 20 | +import pylops_mpi | 
|  | 21 | + | 
|  | 22 | +############################################################################### | 
|  | 23 | +# NCCL communication can be easily initialized with | 
|  | 24 | +# :py:func:`pylops_mpi.utils._nccl.initialize_nccl_comm` operator. | 
|  | 25 | +# One can think of this as GPU-counterpart of :code:`MPI.COMM_WORLD` | 
|  | 26 | + | 
|  | 27 | +np.random.seed(42) | 
|  | 28 | +plt.close("all") | 
|  | 29 | +nccl_comm = pylops_mpi.utils._nccl.initialize_nccl_comm() | 
|  | 30 | +rank = MPI.COMM_WORLD.Get_rank() | 
|  | 31 | + | 
|  | 32 | +############################################################################### | 
|  | 33 | +# Let's start by defining all the parameters required by the | 
|  | 34 | +# :py:class:`pylops.waveeqprocessing.LSM` operator. | 
|  | 35 | +# Note that this section is exactly the same as the one in the MPI example | 
|  | 36 | +# as we will keep using MPI for transfering metadata (i.e., shapes, dims, etc.) | 
|  | 37 | + | 
|  | 38 | +# Velocity Model | 
|  | 39 | +nx, nz = 81, 60 | 
|  | 40 | +dx, dz = 4, 4 | 
|  | 41 | +x, z = np.arange(nx) * dx, np.arange(nz) * dz | 
|  | 42 | +v0 = 1000  # initial velocity | 
|  | 43 | +kv = 0.0  # gradient | 
|  | 44 | +vel = np.outer(np.ones(nx), v0 + kv * z) | 
|  | 45 | + | 
|  | 46 | +# Reflectivity Model | 
|  | 47 | +refl = np.zeros((nx, nz)) | 
|  | 48 | +refl[:, 30] = -1 | 
|  | 49 | +refl[:, 50] = 0.5 | 
|  | 50 | + | 
|  | 51 | +# Receivers | 
|  | 52 | +nr = 11 | 
|  | 53 | +rx = np.linspace(10 * dx, (nx - 10) * dx, nr) | 
|  | 54 | +rz = 20 * np.ones(nr) | 
|  | 55 | +recs = np.vstack((rx, rz)) | 
|  | 56 | + | 
|  | 57 | +# Sources | 
|  | 58 | +ns = 10 | 
|  | 59 | +# Total number of sources at all ranks | 
|  | 60 | +nstot = MPI.COMM_WORLD.allreduce(ns, op=MPI.SUM) | 
|  | 61 | +sxtot = np.linspace(dx * 10, (nx - 10) * dx, nstot) | 
|  | 62 | +sx = sxtot[rank * ns: (rank + 1) * ns] | 
|  | 63 | +sztot = 10 * np.ones(nstot) | 
|  | 64 | +sz = 10 * np.ones(ns) | 
|  | 65 | +sources = np.vstack((sx, sz)) | 
|  | 66 | +sources_tot = np.vstack((sxtot, sztot)) | 
|  | 67 | + | 
|  | 68 | +if rank == 0: | 
|  | 69 | +    plt.figure(figsize=(10, 5)) | 
|  | 70 | +    im = plt.imshow(vel.T, cmap="summer", extent=(x[0], x[-1], z[-1], z[0])) | 
|  | 71 | +    plt.scatter(recs[0], recs[1], marker="v", s=150, c="b", edgecolors="k") | 
|  | 72 | +    plt.scatter(sources_tot[0], sources_tot[1], marker="*", s=150, c="r", edgecolors="k") | 
|  | 73 | +    cb = plt.colorbar(im) | 
|  | 74 | +    cb.set_label("[m/s]") | 
|  | 75 | +    plt.axis("tight") | 
|  | 76 | +    plt.xlabel("x [m]"), plt.ylabel("z [m]") | 
|  | 77 | +    plt.title("Velocity") | 
|  | 78 | +    plt.xlim(x[0], x[-1]) | 
|  | 79 | +    plt.tight_layout() | 
|  | 80 | + | 
|  | 81 | +    plt.figure(figsize=(10, 5)) | 
|  | 82 | +    im = plt.imshow(refl.T, cmap="gray", extent=(x[0], x[-1], z[-1], z[0])) | 
|  | 83 | +    plt.scatter(recs[0], recs[1], marker="v", s=150, c="b", edgecolors="k") | 
|  | 84 | +    plt.scatter(sources_tot[0], sources_tot[1], marker="*", s=150, c="r", edgecolors="k") | 
|  | 85 | +    plt.colorbar(im) | 
|  | 86 | +    plt.axis("tight") | 
|  | 87 | +    plt.xlabel("x [m]"), plt.ylabel("z [m]") | 
|  | 88 | +    plt.title("Reflectivity") | 
|  | 89 | +    plt.xlim(x[0], x[-1]) | 
|  | 90 | +    plt.tight_layout() | 
|  | 91 | + | 
|  | 92 | +############################################################################### | 
|  | 93 | +# We create a :py:class:`pylops.waveeqprocessing.LSM` at each rank and then push them | 
|  | 94 | +# into a :py:class:`pylops_mpi.basicoperators.MPIVStack` to perform a matrix-vector | 
|  | 95 | +# product with the broadcasted reflectivity at every location on the subsurface. | 
|  | 96 | +# Note that we must use :code:`engine="cuda"` and move the wavelet wav to the GPU prior to creating the operator. | 
|  | 97 | +# Moreover, we allocate the traveltime tables (:code:`lsm.Demop.trav_srcs`, and :code:`lsm.Demop.trav_recs`) | 
|  | 98 | +# to the GPU prior to applying the operator to avoid incurring in the penalty of performing | 
|  | 99 | +# host-to-device memory copies every time the operator is applied. Moreover, we must pass :code:`nccl_comm` | 
|  | 100 | +# to the DistributedArray constructor used to create :code:`refl_dist` in order to use NCCL for communications. | 
|  | 101 | + | 
|  | 102 | +# Wavelet | 
|  | 103 | +nt = 651 | 
|  | 104 | +dt = 0.004 | 
|  | 105 | +t = np.arange(nt) * dt | 
|  | 106 | +wav, wavt, wavc = ricker(t[:41], f0=20) | 
|  | 107 | + | 
|  | 108 | +lsm = LSM( | 
|  | 109 | +    z, | 
|  | 110 | +    x, | 
|  | 111 | +    t, | 
|  | 112 | +    sources, | 
|  | 113 | +    recs, | 
|  | 114 | +    v0, | 
|  | 115 | +    cp.asarray(wav.astype(np.float32)), | 
|  | 116 | +    wavc, | 
|  | 117 | +    mode="analytic", | 
|  | 118 | +    engine="cuda", | 
|  | 119 | +    dtype=np.float32 | 
|  | 120 | +) | 
|  | 121 | +lsm.Demop.trav_srcs = cp.asarray(lsm.Demop.trav_srcs.astype(np.float32)) | 
|  | 122 | +lsm.Demop.trav_recs = cp.asarray(lsm.Demop.trav_recs.astype(np.float32)) | 
|  | 123 | + | 
|  | 124 | +VStack = pylops_mpi.MPIVStack(ops=[lsm.Demop, ]) | 
|  | 125 | +refl_dist = pylops_mpi.DistributedArray(global_shape=nx * nz, | 
|  | 126 | +                                        partition=pylops_mpi.Partition.BROADCAST, | 
|  | 127 | +                                        base_comm_nccl=nccl_comm, | 
|  | 128 | +                                        engine="cupy") | 
|  | 129 | +refl_dist[:] = cp.asarray(refl.flatten()) | 
|  | 130 | +d_dist = VStack @ refl_dist | 
|  | 131 | +d = d_dist.asarray().reshape((nstot, nr, nt)) | 
|  | 132 | + | 
|  | 133 | +############################################################################### | 
|  | 134 | +# We calculate now the adjoint and model the data using the adjoint reflectivity | 
|  | 135 | +# as input. | 
|  | 136 | +madj_dist = VStack.H @ d_dist | 
|  | 137 | +madj = madj_dist.asarray().reshape((nx, nz)) | 
|  | 138 | +d_adj_dist = VStack @ madj_dist | 
|  | 139 | +d_adj = d_adj_dist.asarray().reshape((nstot, nr, nt)) | 
|  | 140 | + | 
|  | 141 | +############################################################################### | 
|  | 142 | +# We calculate the inverse using the :py:func:`pylops_mpi.optimization.basic.cgls` | 
|  | 143 | +# solver. Here, we pass the :code:`nccl_comm` to :code:`x0` to use NCCL as a communicator. | 
|  | 144 | +# In this particular case, the local computation will be done in GPU. | 
|  | 145 | +# Collective communication calls will be carried through NCCL GPU-to-GPU. | 
|  | 146 | + | 
|  | 147 | +# Inverse | 
|  | 148 | +# Initializing x0 to zeroes | 
|  | 149 | +x0 = pylops_mpi.DistributedArray(VStack.shape[1], | 
|  | 150 | +                                 partition=pylops_mpi.Partition.BROADCAST, | 
|  | 151 | +                                 base_comm_nccl=nccl_comm, | 
|  | 152 | +                                 engine="cupy") | 
|  | 153 | +x0[:] = 0 | 
|  | 154 | +minv_dist = pylops_mpi.cgls(VStack, d_dist, x0=x0, niter=100, show=True)[0] | 
|  | 155 | +minv = minv_dist.asarray().reshape((nx, nz)) | 
|  | 156 | +d_inv_dist = VStack @ minv_dist | 
|  | 157 | +d_inv = d_inv_dist.asarray().reshape(nstot, nr, nt) | 
|  | 158 | + | 
|  | 159 | +############################################################################## | 
|  | 160 | +# Finally we visualize the results. Note that the array must be copied back | 
|  | 161 | +# to the CPU by calling the :code:`get()` method on the CuPy arrays. | 
|  | 162 | + | 
|  | 163 | +if rank == 0: | 
|  | 164 | +    # Visualize | 
|  | 165 | +    fig1, axs = plt.subplots(1, 3, figsize=(10, 3)) | 
|  | 166 | +    axs[0].imshow(refl.T, cmap="gray", vmin=-1, vmax=1) | 
|  | 167 | +    axs[0].axis("tight") | 
|  | 168 | +    axs[0].set_title(r"$m$") | 
|  | 169 | +    axs[1].imshow(madj.T.get(), cmap="gray", vmin=-madj.max(), vmax=madj.max()) | 
|  | 170 | +    axs[1].set_title(r"$m_{adj}$") | 
|  | 171 | +    axs[1].axis("tight") | 
|  | 172 | +    axs[2].imshow(minv.T.get(), cmap="gray", vmin=-1, vmax=1) | 
|  | 173 | +    axs[2].axis("tight") | 
|  | 174 | +    axs[2].set_title(r"$m_{inv}$") | 
|  | 175 | +    plt.tight_layout() | 
|  | 176 | +    fig1.savefig("model.png") | 
|  | 177 | + | 
|  | 178 | +    fig2, axs = plt.subplots(1, 3, figsize=(10, 3)) | 
|  | 179 | +    axs[0].imshow(d[0, :, :300].T.get(), cmap="gray", vmin=-d.max(), vmax=d.max()) | 
|  | 180 | +    axs[0].set_title(r"$d$") | 
|  | 181 | +    axs[0].axis("tight") | 
|  | 182 | +    axs[1].imshow(d_adj[0, :, :300].T.get(), cmap="gray", vmin=-d_adj.max(), vmax=d_adj.max()) | 
|  | 183 | +    axs[1].set_title(r"$d_{adj}$") | 
|  | 184 | +    axs[1].axis("tight") | 
|  | 185 | +    axs[2].imshow(d_inv[0, :, :300].T.get(), cmap="gray", vmin=-d.max(), vmax=d.max()) | 
|  | 186 | +    axs[2].set_title(r"$d_{inv}$") | 
|  | 187 | +    axs[2].axis("tight") | 
|  | 188 | +    fig2.savefig("data1.png") | 
|  | 189 | + | 
|  | 190 | +    fig3, axs = plt.subplots(1, 3, figsize=(10, 3)) | 
|  | 191 | +    axs[0].imshow(d[nstot // 2, :, :300].T.get(), cmap="gray", vmin=-d.max(), vmax=d.max()) | 
|  | 192 | +    axs[0].set_title(r"$d$") | 
|  | 193 | +    axs[0].axis("tight") | 
|  | 194 | +    axs[1].imshow(d_adj[nstot // 2, :, :300].T.get(), cmap="gray", vmin=-d_adj.max(), vmax=d_adj.max()) | 
|  | 195 | +    axs[1].set_title(r"$d_{adj}$") | 
|  | 196 | +    axs[1].axis("tight") | 
|  | 197 | +    axs[2].imshow(d_inv[nstot // 2, :, :300].T.get(), cmap="gray", vmin=-d.max(), vmax=d.max()) | 
|  | 198 | +    axs[2].set_title(r"$d_{inv}$") | 
|  | 199 | +    axs[2].axis("tight") | 
|  | 200 | +    plt.tight_layout() | 
|  | 201 | +    fig3.savefig("data2.png") | 
0 commit comments