Skip to content

Commit b666567

Browse files
committed
test: fix test_kirchhoff for cuda engine
1 parent b7cdbab commit b666567

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

pytests/test_kirchhoff.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import cupy as np
77
from cupy.testing import assert_array_almost_equal
88

9-
backend = "cuda"
9+
backend = "cupy"
1010
else:
1111
import numpy as np
1212
from numpy.testing import assert_array_almost_equal
@@ -215,6 +215,8 @@ def test_traveltime_table():
215215
@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par1d), (par2d), (par3d)])
216216
def test_kirchhoff2d(par):
217217
"""Dot-test for Kirchhoff operator"""
218+
if backend == "cupy" and par["mode"] == "byot":
219+
pytest.skip("cuda engine not available for single trav table")
218220
vel = v0 * np.ones((PAR["nx"], PAR["nz"]))
219221

220222
if par["mode"] == "byot":
@@ -244,7 +246,7 @@ def test_kirchhoff2d(par):
244246
trav=trav,
245247
amp=amp,
246248
mode=par["mode"],
247-
engine=backend,
249+
engine="numpy" if backend == "numpy" else "cuda",
248250
)
249251
if par["mode"] == "byot":
250252
Dop.trav = np.asarray(Dop.trav)
@@ -266,6 +268,8 @@ def test_kirchhoff2d(par):
266268
@pytest.mark.parametrize("par", [(par1), (par2), (par3)])
267269
def test_kirchhoff3d(par):
268270
"""Dot-test for Kirchhoff operator"""
271+
if backend == "cupy" and par["mode"] == "byot":
272+
pytest.skip("cuda engine not available for single trav table")
269273
vel = v0 * np.ones((PAR["ny"], PAR["nx"], PAR["nz"]))
270274

271275
if par["mode"] == "byot":
@@ -297,7 +301,7 @@ def test_kirchhoff3d(par):
297301
y=y,
298302
trav=trav,
299303
mode=par["mode"],
300-
engine=backend,
304+
engine="numpy" if backend == "numpy" else "cuda",
301305
)
302306
if par["mode"] == "byot":
303307
Dop.trav = np.asarray(Dop.trav)
@@ -338,7 +342,7 @@ def test_kirchhoff2d_trav_vs_travsrcrec(par):
338342
mode=par["mode"],
339343
dynamic=par["dynamic"],
340344
angleaperture=None,
341-
engine=backend,
345+
engine="numpy" if backend == "numpy" else "cuda",
342346
)
343347
Dop.trav_srcs = np.asarray(Dop.trav_srcs)
344348
Dop.trav_recs = np.asarray(Dop.trav_recs)
@@ -369,7 +373,7 @@ def test_kirchhoff2d_trav_vs_travsrcrec(par):
369373
mode=par["mode"],
370374
dynamic=par["dynamic"],
371375
angleaperture=None,
372-
engine=backend,
376+
engine="numpy" if backend == "numpy" else "cuda",
373377
)
374378
D1op.trav_srcs = np.asarray(D1op.trav_srcs)
375379
D1op.trav_recs = np.asarray(D1op.trav_recs)
@@ -408,7 +412,7 @@ def test_kirchhoff3d_trav_vs_travsrcrec(par):
408412
wavc,
409413
y=y,
410414
mode=par["mode"],
411-
engine=backend,
415+
engine="numpy" if backend == "numpy" else "cuda",
412416
)
413417
Dop.trav_srcs = np.asarray(Dop.trav_srcs)
414418
Dop.trav_recs = np.asarray(Dop.trav_recs)
@@ -439,7 +443,7 @@ def test_kirchhoff3d_trav_vs_travsrcrec(par):
439443
y=y,
440444
trav=trav,
441445
mode=par["mode"],
442-
engine=backend,
446+
engine="numpy" if backend == "numpy" else "cuda",
443447
)
444448
D1op.trav_srcs = np.asarray(D1op.trav_srcs)
445449
D1op.trav_recs = np.asarray(D1op.trav_recs)

0 commit comments

Comments
 (0)