Skip to content

Commit 7c31951

Browse files
committed
kirchhoff example for benchmark decorator
1 parent 626006f commit 7c31951

File tree

2 files changed

+112
-2
lines changed

2 files changed

+112
-2
lines changed
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import warnings
2+
warnings.filterwarnings('ignore')
3+
4+
import numpy as np
5+
import cupy as cp
6+
from pylops_mpi.utils.benchmark import benchmark, mark
7+
from mpi4py import MPI
8+
9+
from pylops.utils.wavelets import ricker
10+
from pylops.waveeqprocessing.lsm import LSM
11+
12+
import pylops_mpi
13+
np.random.seed(42)
rank = MPI.COMM_WORLD.Get_rank()

# Trade-off: (nx, nz) sets the model size, which affects data transfer,
# while (nr, ns) sets the data size, which affects the compute.
# All three configurations share the same problem size and differ only in
# the backend flags.
_size = {"nx": 30, "nz": 50, "ns": 9, "nr": 18}
par1 = {**_size, "use_cupy": False, "use_nccl": False}
par2 = {**_size, "use_cupy": True, "use_nccl": False}
par3 = {**_size, "use_cupy": True, "use_nccl": True}
21+
22+
23+
def prepare_kirchhoff_op(par):
    """Build the LSM operator for the sources assigned to this rank.

    Parameters
    ----------
    par : dict
        Configuration with keys ``"nx"``, ``"nz"``, ``"ns"``, ``"nr"`` and
        ``"use_cupy"``.

    Returns
    -------
    lsm_op
        The ``LSM`` instance. When ``par["use_cupy"]`` is true, the
        ``trav_srcs`` and ``trav_recs`` attributes of its ``Demop`` are
        replaced by float32 CuPy arrays.
    """
    use_gpu = par["use_cupy"]
    nx, nz, ns, nr = par["nx"], par["nz"], par["ns"], par["nr"]

    # Model grid
    v0 = 1500
    dx, dz = 12.5, 4
    x = np.arange(nx) * dx
    z = np.arange(nz) * dz

    # Receivers: nr positions spread over the same lateral span as the
    # sources, at a constant second coordinate of 20
    span_start, span_end = 10 * dx, (nx - 10) * dx
    recs = np.vstack((np.linspace(span_start, span_end, nr),
                      20 * np.ones(nr)))

    # Sources: every rank contributes ns sources; the global line of sources
    # is built once and each rank keeps its own contiguous slice
    ns_global = MPI.COMM_WORLD.allreduce(ns, op=MPI.SUM)
    sx_global = np.linspace(span_start, span_end, ns_global)
    first = rank * ns
    sources = np.vstack((sx_global[first:first + ns], 10 * np.ones(ns)))

    # Wavelet
    nt, dt = 651, 0.004
    t = np.arange(nt) * dt
    wav, _, wavc = ricker(t[:41], f0=20)
    if use_gpu:
        wav = cp.asarray(wav.astype(np.float32))

    lsm_op = LSM(
        z,
        x,
        t,
        sources,
        recs,
        v0,
        wav,
        wavc,
        mode="analytic",
        engine="cuda" if use_gpu else "numba",
        dtype=np.float32,
    )

    if use_gpu:
        # Move the travel-time tables of the modelling operator to the GPU
        demop = lsm_op.Demop
        for table in ("trav_srcs", "trav_recs"):
            host_values = getattr(demop, table)
            setattr(demop, table, cp.asarray(host_values.astype(np.float32)))

    return lsm_op
64+
65+
66+
def prepare_distributed_data(par, lsm_op, nccl_comm):
    """Model distributed data from a two-reflector reflectivity model.

    Parameters
    ----------
    par : dict
        Configuration with keys ``"nx"``, ``"nz"`` and ``"use_cupy"``.
    lsm_op
        Object whose ``Demop`` attribute is the modelling operator.
    nccl_comm
        Value forwarded as ``base_comm_nccl`` to the distributed array
        (``None`` when NCCL is not used).

    Returns
    -------
    d_dist
        Result of applying ``MPIVStack([lsm_op.Demop])`` to the broadcast
        reflectivity.
    """
    nx, nz = par["nx"], par["nz"]
    on_gpu = par["use_cupy"]

    # Reflectivity model: zero everywhere except two constant-index columns
    refl = np.zeros((nx, nz))
    refl[:, nz // 4] = -1
    refl[:, nz // 2] = 0.5
    refl_flat = refl.flatten()

    refl_dist = pylops_mpi.DistributedArray(
        global_shape=nx * nz,
        partition=pylops_mpi.Partition.BROADCAST,
        base_comm_nccl=nccl_comm,
        engine="cupy" if on_gpu else "numpy",
    )
    refl_dist[:] = cp.asarray(refl_flat) if on_gpu else refl_flat

    modelling_op = pylops_mpi.MPIVStack(ops=[lsm_op.Demop])
    return modelling_op @ refl_dist
80+
81+
82+
@benchmark
def run_bench(par):
    """Run one Kirchhoff adjoint benchmark for the given configuration.

    Builds the operator and the distributed data, applies the adjoint of the
    stacked operator, and records ``mark()`` labels at the start, after the
    preparation, and after the adjoint.

    Parameters
    ----------
    par : dict
        Configuration with keys ``"nx"``, ``"nz"``, ``"ns"``, ``"nr"``,
        ``"use_cupy"`` and ``"use_nccl"``.
    """
    # if run with MPI, NCCL should not be initialized at all to avoid hang
    if par["use_nccl"]:
        nccl_comm = pylops_mpi.utils._nccl.initialize_nccl_comm()
    else:
        nccl_comm = None

    # Outer single quotes: reusing the same quote character inside an
    # f-string replacement field is a SyntaxError before Python 3.12.
    # The emitted label text is unchanged.
    mark(f'begin {par["use_cupy"]=}, {par["use_nccl"]=}')
    lsm_op = prepare_kirchhoff_op(par)
    d_dist = prepare_distributed_data(par, lsm_op, nccl_comm)
    VStack = pylops_mpi.MPIVStack(ops=[lsm_op.Demop, ])
    mark("after prepare")
    # TODO (tharitt): In the actual benchmark, we probably have to decorate
    # the matvec() and rmatvec() to separate computation from communication time
    madj_dist = VStack.H @ d_dist
    mark("after adjoint")
    # Gather the result and check it reshapes to the model grid; the value
    # itself is not used
    _ = madj_dist.asarray().reshape((par["nx"], par["nz"]))
100+
101+
102+
if __name__ == "__main__":
    # Run the three configurations back to back, printing a divider line
    # between consecutive runs
    for idx, par in enumerate((par1, par2, par3)):
        if idx > 0:
            print("========")
        run_bench(par)

pylops_mpi/utils/benchmark.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,26 @@
44
# TODO (tharitt): later move to env file or something
55
ENABLE_BENCHMARK = True
66

7+
78
def mark(label):
    """Record a benchmark mark with the given label, if a hook is installed.

    Calls to this function are meant to be placed throughout the function
    being measured. When the module-level ``_current_mark_func`` hook is
    ``None``, the call does nothing.
    """
    hook = _current_mark_func
    if hook is None:
        return
    hook(label)
1112

13+
1214
# Global hook - this will be re-assigned (points to)
1315
# the function defined in benchmark wrapper
1416
_current_mark_func = None
1517

18+
1619
def benchmark(func):
1720
"""A wrapper for code injection for time measurement.
1821
1922
This wrapper allows users to put a call to mark()
2023
anywhere inside the wrapped function. The function mark()
2124
is defined in the global scope to be a placeholder for the targeted
2225
function to import. This wrapper will make it points to local_mark() defined
23-
in this function. Therefore, the wrapped function will be able call
26+
in this function. Therefore, the wrapped function will be able call
2427
local_mark(). All the context for local_mark() like mark list can be
2528
hidden from users and thus provide clean interface.
2629
@@ -50,7 +53,7 @@ def local_mark(label):
5053
# clean up to original state
5154
_current_mark_func = None
5255

53-
# TODO (tharitt): maybe changing to saving results to file instead
56+
# TODO (tharitt): maybe changing to saving results to file instead
5457
if marks:
5558
prev_label, prev_t = marks[0]
5659
print(f"[BENCH] {prev_label}: 0.000000s")

0 commit comments

Comments
 (0)