Skip to content

Commit b7cdbab

Browse files
committed
test: added cuda tests for Kirchhoff
1 parent b944dcb commit b7cdbab

File tree

1 file changed

+82
-46
lines changed

1 file changed

+82
-46
lines changed

pytests/test_kirchhoff.py

Lines changed: 82 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,19 @@
11
import os
22

33
import numpy as np
4+
5+
if int(os.environ.get("TEST_CUPY_PYLOPS", 0)):
6+
import cupy as np
7+
from cupy.testing import assert_array_almost_equal
8+
9+
backend = "cuda"
10+
else:
11+
import numpy as np
12+
from numpy.testing import assert_array_almost_equal
13+
14+
backend = "numpy"
15+
import numpy as npp
416
import pytest
5-
from numpy.testing import assert_array_almost_equal
617

718
from pylops.utils import dottest
819
from pylops.utils.wavelets import ricker
@@ -34,22 +45,22 @@
3445
skfmm_enabled = False
3546

3647
v0 = 500
37-
y = np.arange(PAR["ny"]) * PAR["dy"]
38-
x = np.arange(PAR["nx"]) * PAR["dx"]
39-
z = np.arange(PAR["nz"]) * PAR["dz"]
40-
t = np.arange(PAR["nt"]) * PAR["dt"]
41-
42-
sy = np.linspace(y.min(), y.max(), PAR["nsy"])
43-
sx = np.linspace(x.min(), x.max(), PAR["nsx"])
44-
syy, sxx = np.meshgrid(sy, sx, indexing="ij")
45-
s2d = np.vstack((sx, 2 * np.ones(PAR["nsx"])))
46-
s3d = np.vstack((syy.ravel(), sxx.ravel(), 2 * np.ones(PAR["nsx"] * PAR["nsy"])))
47-
48-
ry = np.linspace(y.min(), y.max(), PAR["nry"])
49-
rx = np.linspace(x.min(), x.max(), PAR["nrx"])
50-
ryy, rxx = np.meshgrid(ry, rx, indexing="ij")
51-
r2d = np.vstack((rx, 2 * np.ones(PAR["nrx"])))
52-
r3d = np.vstack((ryy.ravel(), rxx.ravel(), 2 * np.ones(PAR["nrx"] * PAR["nry"])))
48+
y = npp.arange(PAR["ny"]) * PAR["dy"]
49+
x = npp.arange(PAR["nx"]) * PAR["dx"]
50+
z = npp.arange(PAR["nz"]) * PAR["dz"]
51+
t = npp.arange(PAR["nt"]) * PAR["dt"]
52+
53+
sy = npp.linspace(y.min(), y.max(), PAR["nsy"])
54+
sx = npp.linspace(x.min(), x.max(), PAR["nsx"])
55+
syy, sxx = npp.meshgrid(sy, sx, indexing="ij")
56+
s2d = npp.vstack((sx, 2 * npp.ones(PAR["nsx"])))
57+
s3d = npp.vstack((syy.ravel(), sxx.ravel(), 2 * npp.ones(PAR["nsx"] * PAR["nsy"])))
58+
59+
ry = npp.linspace(y.min(), y.max(), PAR["nry"])
60+
rx = npp.linspace(x.min(), x.max(), PAR["nrx"])
61+
ryy, rxx = npp.meshgrid(ry, rx, indexing="ij")
62+
r2d = npp.vstack((rx, 2 * npp.ones(PAR["nrx"])))
63+
r3d = npp.vstack((ryy.ravel(), rxx.ravel(), 2 * npp.ones(PAR["nrx"] * PAR["nry"])))
5364

5465
wav, _, wavc = ricker(t[:21], f0=40)
5566

@@ -157,7 +168,7 @@ def test_traveltime_table():
157168
trav_ana,
158169
trav_srcs_ana,
159170
trav_recs_ana,
160-
dist_ana,
171+
_,
161172
_,
162173
_,
163174
) = Kirchhoff._traveltime_table(z, x, s2d, r2d, v0, mode="analytic")
@@ -166,7 +177,7 @@ def test_traveltime_table():
166177
trav_eik,
167178
trav_srcs_eik,
168179
trav_recs_eik,
169-
dist_eik,
180+
_,
170181
_,
171182
_,
172183
) = Kirchhoff._traveltime_table(
@@ -181,20 +192,13 @@ def test_traveltime_table():
181192
(
182193
trav_srcs_ana,
183194
trav_recs_ana,
184-
dist_srcs_ana,
185-
dist_recs_ana,
186195
_,
187196
_,
188-
) = Kirchhoff._traveltime_table(z, x, s3d, r3d, v0, y=y, mode="analytic")
189-
190-
(
191-
trav_srcs_eik,
192-
trav_recs_eik,
193-
dist_srcs_eik,
194-
dist_recs_eik,
195197
_,
196198
_,
197-
) = Kirchhoff._traveltime_table(
199+
) = Kirchhoff._traveltime_table(z, x, s3d, r3d, v0, y=y, mode="analytic")
200+
201+
(trav_srcs_eik, trav_recs_eik, _, _, _, _,) = Kirchhoff._traveltime_table(
198202
z,
199203
x,
200204
s3d,
@@ -206,12 +210,8 @@ def test_traveltime_table():
206210

207211
assert_array_almost_equal(trav_srcs_ana, trav_srcs_eik, decimal=2)
208212
assert_array_almost_equal(trav_recs_ana, trav_recs_eik, decimal=2)
209-
assert_array_almost_equal(trav_ana, trav_eik, decimal=2)
210213

211214

212-
@pytest.mark.skipif(
213-
int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled"
214-
)
215215
@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par1d), (par2d), (par3d)])
216216
def test_kirchhoff2d(par):
217217
"""Dot-test for Kirchhoff operator"""
@@ -226,8 +226,6 @@ def test_kirchhoff2d(par):
226226
) + trav_recs.reshape(PAR["nx"] * PAR["nz"], 1, PAR["nrx"])
227227
trav = trav.reshape(PAR["nx"] * PAR["nz"], PAR["nsx"] * PAR["nrx"])
228228
amp = None
229-
# amp = 1 / (dist + 1e-2 * dist.max())
230-
231229
else:
232230
trav = None
233231
amp = None
@@ -240,19 +238,31 @@ def test_kirchhoff2d(par):
240238
s2d,
241239
r2d,
242240
vel if par["mode"] == "eikonal" else v0,
243-
wav,
241+
np.asarray(wav),
244242
wavc,
245243
y=None,
246244
trav=trav,
247245
amp=amp,
248246
mode=par["mode"],
247+
engine=backend,
248+
)
249+
if par["mode"] == "byot":
250+
Dop.trav = np.asarray(Dop.trav)
251+
else:
252+
Dop.trav_srcs = np.asarray(Dop.trav_srcs)
253+
Dop.trav_recs = np.asarray(Dop.trav_recs)
254+
if par["mode"] == "dynamic":
255+
Dop.amp_srcs = np.asarray(Dop.amp_srcs)
256+
Dop.amp_recs = np.asarray(Dop.amp_recs)
257+
258+
assert dottest(
259+
Dop,
260+
PAR["nsx"] * PAR["nrx"] * PAR["nt"],
261+
PAR["nz"] * PAR["nx"],
262+
backend=backend,
249263
)
250-
assert dottest(Dop, PAR["nsx"] * PAR["nrx"] * PAR["nt"], PAR["nz"] * PAR["nx"])
251264

252265

253-
@pytest.mark.skipif(
254-
int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled"
255-
)
256266
@pytest.mark.parametrize("par", [(par1), (par2), (par3)])
257267
def test_kirchhoff3d(par):
258268
"""Dot-test for Kirchhoff operator"""
@@ -287,17 +297,22 @@ def test_kirchhoff3d(par):
287297
y=y,
288298
trav=trav,
289299
mode=par["mode"],
300+
engine=backend,
290301
)
302+
if par["mode"] == "byot":
303+
Dop.trav = np.asarray(Dop.trav)
304+
else:
305+
Dop.trav_srcs = np.asarray(Dop.trav_srcs)
306+
Dop.trav_recs = np.asarray(Dop.trav_recs)
307+
291308
assert dottest(
292309
Dop,
293310
PAR["nsx"] * PAR["nrx"] * PAR["nsy"] * PAR["nry"] * PAR["nt"],
294311
PAR["nz"] * PAR["nx"] * PAR["ny"],
312+
backend=backend,
295313
)
296314

297315

298-
@pytest.mark.skipif(
299-
int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled"
300-
)
301316
@pytest.mark.parametrize(
302317
"par",
303318
[
@@ -323,7 +338,13 @@ def test_kirchhoff2d_trav_vs_travsrcrec(par):
323338
mode=par["mode"],
324339
dynamic=par["dynamic"],
325340
angleaperture=None,
341+
engine=backend,
326342
)
343+
Dop.trav_srcs = np.asarray(Dop.trav_srcs)
344+
Dop.trav_recs = np.asarray(Dop.trav_recs)
345+
if par["dynamic"]:
346+
Dop.amp_srcs = np.asarray(Dop.amp_srcs)
347+
Dop.amp_recs = np.asarray(Dop.amp_recs)
327348

328349
# old behaviour
329350
trav = Dop.trav_srcs.reshape(
@@ -348,7 +369,13 @@ def test_kirchhoff2d_trav_vs_travsrcrec(par):
348369
mode=par["mode"],
349370
dynamic=par["dynamic"],
350371
angleaperture=None,
372+
engine=backend,
351373
)
374+
D1op.trav_srcs = np.asarray(D1op.trav_srcs)
375+
D1op.trav_recs = np.asarray(D1op.trav_recs)
376+
if par["dynamic"]:
377+
D1op.amp_srcs = np.asarray(D1op.amp_srcs)
378+
D1op.amp_recs = np.asarray(D1op.amp_recs)
352379

353380
# forward
354381
xx = np.random.normal(0, 1, PAR["nx"] * PAR["nz"])
@@ -359,9 +386,6 @@ def test_kirchhoff2d_trav_vs_travsrcrec(par):
359386
assert_array_almost_equal(Dop.H @ yy, D1op.H @ yy, decimal=2)
360387

361388

362-
@pytest.mark.skipif(
363-
int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled"
364-
)
365389
@pytest.mark.parametrize(
366390
"par",
367391
[
@@ -384,7 +408,13 @@ def test_kirchhoff3d_trav_vs_travsrcrec(par):
384408
wavc,
385409
y=y,
386410
mode=par["mode"],
411+
engine=backend,
387412
)
413+
Dop.trav_srcs = np.asarray(Dop.trav_srcs)
414+
Dop.trav_recs = np.asarray(Dop.trav_recs)
415+
if par["dynamic"]:
416+
Dop.amp_srcs = np.asarray(Dop.amp_srcs)
417+
Dop.amp_recs = np.asarray(Dop.amp_recs)
388418

389419
# old behaviour
390420
trav = Dop.trav_srcs.reshape(
@@ -409,7 +439,13 @@ def test_kirchhoff3d_trav_vs_travsrcrec(par):
409439
y=y,
410440
trav=trav,
411441
mode=par["mode"],
442+
engine=backend,
412443
)
444+
D1op.trav_srcs = np.asarray(D1op.trav_srcs)
445+
D1op.trav_recs = np.asarray(D1op.trav_recs)
446+
if par["dynamic"]:
447+
D1op.amp_srcs = np.asarray(D1op.amp_srcs)
448+
D1op.amp_recs = np.asarray(D1op.amp_recs)
413449

414450
# forward
415451
xx = np.random.normal(0, 1, PAR["ny"] * PAR["nx"] * PAR["nz"])

0 commit comments

Comments
 (0)