|  | 
|  | 1 | +import warnings | 
|  | 2 | +warnings.filterwarnings('ignore') | 
|  | 3 | + | 
|  | 4 | +import numpy as np | 
|  | 5 | +import cupy as cp | 
|  | 6 | +from pylops_mpi.utils.benchmark import benchmark, mark | 
|  | 7 | +from mpi4py import MPI | 
|  | 8 | + | 
|  | 9 | +from pylops.utils.wavelets import ricker | 
|  | 10 | +from pylops.waveeqprocessing.lsm import LSM | 
|  | 11 | + | 
|  | 12 | +import pylops_mpi | 
# Fixed seed so repeated runs are reproducible; `rank` is this process'
# index in the world communicator.
np.random.seed(42)
rank = MPI.COMM_WORLD.Get_rank()

# Sizing trade-off:
#   * (nx, nz) set the model size, which affects the data transfer
#   * (nr, ns) set the data size, which affects the compute
# The three configurations share the same sizes and differ only in the
# array backend (use_cupy) and communication backend (use_nccl) flags.
par1 = dict(nx=30, nz=50, ns=9, nr=18, use_cupy=False, use_nccl=False)
par2 = dict(nx=30, nz=50, ns=9, nr=18, use_cupy=True, use_nccl=False)
par3 = dict(nx=30, nz=50, ns=9, nr=18, use_cupy=True, use_nccl=True)
|  | 21 | + | 
|  | 22 | + | 
def prepare_kirchhoff_op(par):
    """Build the LSM operator for this rank's share of the sources.

    Parameters
    ----------
    par : dict
        Benchmark configuration. Reads ``"nx"``, ``"nz"`` (grid size),
        ``"nr"`` (number of receivers), ``"ns"`` (number of sources owned
        by each rank) and ``"use_cupy"`` (GPU vs CPU backend).

    Returns
    -------
    lsm_op
        The ``LSM`` object; its ``Demop`` attribute is what the callers in
        this file use. When ``par["use_cupy"]`` is true, the wavelet and
        the ``trav_srcs`` / ``trav_recs`` tables of ``Demop`` are float32
        CuPy arrays.

    Notes
    -----
    Performs a collective ``allreduce`` on ``MPI.COMM_WORLD``, so every
    rank has to call this function.
    """
    nx, nz = par["nx"], par["nz"]

    # Constant velocity and regular grid
    velocity = 1500
    dx, dz = 12.5, 4
    x = np.arange(nx) * dx
    z = np.arange(nz) * dz

    # Receivers: evenly spread in x, 10 cells in from either edge
    nr = par["nr"]
    x_first, x_last = 10 * dx, (nx - 10) * dx
    recs = np.vstack((np.linspace(x_first, x_last, nr), 20 * np.ones(nr)))

    # Sources: the global line spans the same x range; each rank keeps
    # its own contiguous chunk of `ns` positions
    ns = par["ns"]
    nstot = MPI.COMM_WORLD.allreduce(ns, op=MPI.SUM)
    all_sx = np.linspace(x_first, x_last, nstot)
    start = rank * ns
    sources = np.vstack((all_sx[start:start + ns], 10 * np.ones(ns)))

    # Time axis and Ricker wavelet (built on the first 41 samples)
    nt, dt = 651, 0.004
    t = np.arange(nt) * dt
    wav, _, wavc = ricker(t[:41], f0=20)

    # Pick wavelet container and engine according to the backend
    on_gpu = par["use_cupy"]
    if on_gpu:
        wav_in, engine = cp.asarray(wav.astype(np.float32)), "cuda"
    else:
        wav_in, engine = wav, "numba"

    lsm_op = LSM(
        z, x, t,
        sources, recs,
        velocity,
        wav_in, wavc,
        mode="analytic",
        engine=engine,
        dtype=np.float32,
    )

    if on_gpu:
        # Move the travel-time tables onto the GPU as float32
        for table in ("trav_srcs", "trav_recs"):
            host_table = getattr(lsm_op.Demop, table)
            setattr(lsm_op.Demop, table,
                    cp.asarray(host_table.astype(np.float32)))

    return lsm_op
|  | 64 | + | 
|  | 65 | + | 
def prepare_distributed_data(par, lsm_op, nccl_comm):
    """Model the distributed data for a two-reflector reflectivity model.

    Parameters
    ----------
    par : dict
        Benchmark configuration. Reads ``"nx"``, ``"nz"`` and
        ``"use_cupy"``.
    lsm_op
        Object whose ``Demop`` attribute is the operator to apply.
    nccl_comm
        Passed through as ``base_comm_nccl`` of the distributed model
        array (``None`` is what the caller in this file passes when NCCL
        is not used).

    Returns
    -------
    d_dist
        Result of applying ``MPIVStack([lsm_op.Demop])`` to the broadcast
        reflectivity model.
    """
    nx, nz = par["nx"], par["nz"]
    on_gpu = par["use_cupy"]

    # Reflectivity model: two reflectors at fixed z indices, all x
    model = np.zeros((nx, nz))
    model[:, nz // 4] = -1
    model[:, nz // 2] = 0.5

    # Broadcast-partitioned container for the flattened model
    refl_dist = pylops_mpi.DistributedArray(
        global_shape=nx * nz,
        partition=pylops_mpi.Partition.BROADCAST,
        base_comm_nccl=nccl_comm,
        engine="cupy" if on_gpu else "numpy",
    )
    flat_model = model.flatten()
    if on_gpu:
        refl_dist[:] = cp.asarray(flat_model)
    else:
        refl_dist[:] = flat_model

    # Forward modelling
    forward_op = pylops_mpi.MPIVStack(ops=[lsm_op.Demop])
    return forward_op @ refl_dist
|  | 80 | + | 
|  | 81 | + | 
@benchmark
def run_bench(par):
    """Benchmark one adjoint Kirchhoff application for configuration ``par``.

    Builds the operator and the distributed data, applies the adjoint of
    the stacked operator, and gathers the result. ``mark`` is called at
    the start, after preparation and after the adjoint.

    Parameters
    ----------
    par : dict
        Benchmark configuration. Reads ``"use_nccl"`` and ``"use_cupy"``
        here, and ``"nx"`` / ``"nz"`` for the final reshape; the whole
        dict is forwarded to the ``prepare_*`` helpers.
    """
    # if run with MPI, NCCL should not be initialized at all to avoid hang
    if par["use_nccl"]:
        nccl_comm = pylops_mpi.utils._nccl.initialize_nccl_comm()
    else:
        nccl_comm = None

    # Reusing double quotes inside the replacement fields of a
    # double-quoted f-string only parses on Python >= 3.12 (PEP 701) and
    # is a SyntaxError before that. Spelling the label out in
    # single-quoted f-strings gives the same text as the `{expr=}` form
    # on every supported version.
    mark(f'begin par["use_cupy"]={par["use_cupy"]!r}, '
         f'par["use_nccl"]={par["use_nccl"]!r}')
    lsm_op = prepare_kirchhoff_op(par)
    d_dist = prepare_distributed_data(par, lsm_op, nccl_comm)
    VStack = pylops_mpi.MPIVStack(ops=[lsm_op.Demop, ])
    mark("after prepare")
    # TODO (tharitt): In the actual benchmark, we probably have to decorate
    # the matvec() and rmatvec() to separate computation from communication time
    madj_dist = VStack.H @ d_dist
    mark("after adjoint")
    # Gather the adjoint result and view it on the (nx, nz) grid; the value
    # itself is discarded.
    _ = madj_dist.asarray().reshape((par["nx"], par["nz"]))
|  | 100 | + | 
|  | 101 | + | 
if __name__ == "__main__":
    # Run the three configurations in order, printing a separator line
    # between consecutive runs.
    for _idx, _par in enumerate((par1, par2, par3)):
        if _idx > 0:
            print("========")
        run_bench(_par)
0 commit comments