11import os
22
33import numpy as np
4+
5+ if int (os .environ .get ("TEST_CUPY_PYLOPS" , 0 )):
6+ import cupy as np
7+ from cupy .testing import assert_array_almost_equal
8+
9+ backend = "cuda"
10+ else :
11+ import numpy as np
12+ from numpy .testing import assert_array_almost_equal
13+
14+ backend = "numpy"
15+ import numpy as npp
416import pytest
5- from numpy .testing import assert_array_almost_equal
617
718from pylops .utils import dottest
819from pylops .utils .wavelets import ricker
3445 skfmm_enabled = False
3546
3647v0 = 500
37- y = np .arange (PAR ["ny" ]) * PAR ["dy" ]
38- x = np .arange (PAR ["nx" ]) * PAR ["dx" ]
39- z = np .arange (PAR ["nz" ]) * PAR ["dz" ]
40- t = np .arange (PAR ["nt" ]) * PAR ["dt" ]
41-
42- sy = np .linspace (y .min (), y .max (), PAR ["nsy" ])
43- sx = np .linspace (x .min (), x .max (), PAR ["nsx" ])
44- syy , sxx = np .meshgrid (sy , sx , indexing = "ij" )
45- s2d = np .vstack ((sx , 2 * np .ones (PAR ["nsx" ])))
46- s3d = np .vstack ((syy .ravel (), sxx .ravel (), 2 * np .ones (PAR ["nsx" ] * PAR ["nsy" ])))
47-
48- ry = np .linspace (y .min (), y .max (), PAR ["nry" ])
49- rx = np .linspace (x .min (), x .max (), PAR ["nrx" ])
50- ryy , rxx = np .meshgrid (ry , rx , indexing = "ij" )
51- r2d = np .vstack ((rx , 2 * np .ones (PAR ["nrx" ])))
52- r3d = np .vstack ((ryy .ravel (), rxx .ravel (), 2 * np .ones (PAR ["nrx" ] * PAR ["nry" ])))
48+ y = npp .arange (PAR ["ny" ]) * PAR ["dy" ]
49+ x = npp .arange (PAR ["nx" ]) * PAR ["dx" ]
50+ z = npp .arange (PAR ["nz" ]) * PAR ["dz" ]
51+ t = npp .arange (PAR ["nt" ]) * PAR ["dt" ]
52+
53+ sy = npp .linspace (y .min (), y .max (), PAR ["nsy" ])
54+ sx = npp .linspace (x .min (), x .max (), PAR ["nsx" ])
55+ syy , sxx = npp .meshgrid (sy , sx , indexing = "ij" )
56+ s2d = npp .vstack ((sx , 2 * npp .ones (PAR ["nsx" ])))
57+ s3d = npp .vstack ((syy .ravel (), sxx .ravel (), 2 * npp .ones (PAR ["nsx" ] * PAR ["nsy" ])))
58+
59+ ry = npp .linspace (y .min (), y .max (), PAR ["nry" ])
60+ rx = npp .linspace (x .min (), x .max (), PAR ["nrx" ])
61+ ryy , rxx = npp .meshgrid (ry , rx , indexing = "ij" )
62+ r2d = npp .vstack ((rx , 2 * npp .ones (PAR ["nrx" ])))
63+ r3d = npp .vstack ((ryy .ravel (), rxx .ravel (), 2 * npp .ones (PAR ["nrx" ] * PAR ["nry" ])))
5364
5465wav , _ , wavc = ricker (t [:21 ], f0 = 40 )
5566
@@ -157,7 +168,7 @@ def test_traveltime_table():
157168 trav_ana ,
158169 trav_srcs_ana ,
159170 trav_recs_ana ,
160- dist_ana ,
171+ _ ,
161172 _ ,
162173 _ ,
163174 ) = Kirchhoff ._traveltime_table (z , x , s2d , r2d , v0 , mode = "analytic" )
@@ -166,7 +177,7 @@ def test_traveltime_table():
166177 trav_eik ,
167178 trav_srcs_eik ,
168179 trav_recs_eik ,
169- dist_eik ,
180+ _ ,
170181 _ ,
171182 _ ,
172183 ) = Kirchhoff ._traveltime_table (
@@ -181,20 +192,13 @@ def test_traveltime_table():
181192 (
182193 trav_srcs_ana ,
183194 trav_recs_ana ,
184- dist_srcs_ana ,
185- dist_recs_ana ,
186195 _ ,
187196 _ ,
188- ) = Kirchhoff ._traveltime_table (z , x , s3d , r3d , v0 , y = y , mode = "analytic" )
189-
190- (
191- trav_srcs_eik ,
192- trav_recs_eik ,
193- dist_srcs_eik ,
194- dist_recs_eik ,
195197 _ ,
196198 _ ,
197- ) = Kirchhoff ._traveltime_table (
199+ ) = Kirchhoff ._traveltime_table (z , x , s3d , r3d , v0 , y = y , mode = "analytic" )
200+
201+ (trav_srcs_eik , trav_recs_eik , _ , _ , _ , _ ,) = Kirchhoff ._traveltime_table (
198202 z ,
199203 x ,
200204 s3d ,
@@ -206,12 +210,8 @@ def test_traveltime_table():
206210
207211 assert_array_almost_equal (trav_srcs_ana , trav_srcs_eik , decimal = 2 )
208212 assert_array_almost_equal (trav_recs_ana , trav_recs_eik , decimal = 2 )
209- assert_array_almost_equal (trav_ana , trav_eik , decimal = 2 )
210213
211214
212- @pytest .mark .skipif (
213- int (os .environ .get ("TEST_CUPY_PYLOPS" , 0 )) == 1 , reason = "Not CuPy enabled"
214- )
215215@pytest .mark .parametrize ("par" , [(par1 ), (par2 ), (par3 ), (par1d ), (par2d ), (par3d )])
216216def test_kirchhoff2d (par ):
217217 """Dot-test for Kirchhoff operator"""
@@ -226,8 +226,6 @@ def test_kirchhoff2d(par):
226226 ) + trav_recs .reshape (PAR ["nx" ] * PAR ["nz" ], 1 , PAR ["nrx" ])
227227 trav = trav .reshape (PAR ["nx" ] * PAR ["nz" ], PAR ["nsx" ] * PAR ["nrx" ])
228228 amp = None
229- # amp = 1 / (dist + 1e-2 * dist.max())
230-
231229 else :
232230 trav = None
233231 amp = None
@@ -240,19 +238,31 @@ def test_kirchhoff2d(par):
240238 s2d ,
241239 r2d ,
242240 vel if par ["mode" ] == "eikonal" else v0 ,
243- wav ,
241+ np . asarray ( wav ) ,
244242 wavc ,
245243 y = None ,
246244 trav = trav ,
247245 amp = amp ,
248246 mode = par ["mode" ],
247+ engine = backend ,
248+ )
249+ if par ["mode" ] == "byot" :
250+ Dop .trav = np .asarray (Dop .trav )
251+ else :
252+ Dop .trav_srcs = np .asarray (Dop .trav_srcs )
253+ Dop .trav_recs = np .asarray (Dop .trav_recs )
254+ if par ["mode" ] == "dynamic" :
255+ Dop .amp_srcs = np .asarray (Dop .amp_srcs )
256+ Dop .amp_recs = np .asarray (Dop .amp_recs )
257+
258+ assert dottest (
259+ Dop ,
260+ PAR ["nsx" ] * PAR ["nrx" ] * PAR ["nt" ],
261+ PAR ["nz" ] * PAR ["nx" ],
262+ backend = backend ,
249263 )
250- assert dottest (Dop , PAR ["nsx" ] * PAR ["nrx" ] * PAR ["nt" ], PAR ["nz" ] * PAR ["nx" ])
251264
252265
253- @pytest .mark .skipif (
254- int (os .environ .get ("TEST_CUPY_PYLOPS" , 0 )) == 1 , reason = "Not CuPy enabled"
255- )
256266@pytest .mark .parametrize ("par" , [(par1 ), (par2 ), (par3 )])
257267def test_kirchhoff3d (par ):
258268 """Dot-test for Kirchhoff operator"""
@@ -287,17 +297,22 @@ def test_kirchhoff3d(par):
287297 y = y ,
288298 trav = trav ,
289299 mode = par ["mode" ],
300+ engine = backend ,
290301 )
302+ if par ["mode" ] == "byot" :
303+ Dop .trav = np .asarray (Dop .trav )
304+ else :
305+ Dop .trav_srcs = np .asarray (Dop .trav_srcs )
306+ Dop .trav_recs = np .asarray (Dop .trav_recs )
307+
291308 assert dottest (
292309 Dop ,
293310 PAR ["nsx" ] * PAR ["nrx" ] * PAR ["nsy" ] * PAR ["nry" ] * PAR ["nt" ],
294311 PAR ["nz" ] * PAR ["nx" ] * PAR ["ny" ],
312+ backend = backend ,
295313 )
296314
297315
298- @pytest .mark .skipif (
299- int (os .environ .get ("TEST_CUPY_PYLOPS" , 0 )) == 1 , reason = "Not CuPy enabled"
300- )
301316@pytest .mark .parametrize (
302317 "par" ,
303318 [
@@ -323,7 +338,13 @@ def test_kirchhoff2d_trav_vs_travsrcrec(par):
323338 mode = par ["mode" ],
324339 dynamic = par ["dynamic" ],
325340 angleaperture = None ,
341+ engine = backend ,
326342 )
343+ Dop .trav_srcs = np .asarray (Dop .trav_srcs )
344+ Dop .trav_recs = np .asarray (Dop .trav_recs )
345+ if par ["dynamic" ]:
346+ Dop .amp_srcs = np .asarray (Dop .amp_srcs )
347+ Dop .amp_recs = np .asarray (Dop .amp_recs )
327348
328349 # old behaviour
329350 trav = Dop .trav_srcs .reshape (
@@ -348,7 +369,13 @@ def test_kirchhoff2d_trav_vs_travsrcrec(par):
348369 mode = par ["mode" ],
349370 dynamic = par ["dynamic" ],
350371 angleaperture = None ,
372+ engine = backend ,
351373 )
374+ D1op .trav_srcs = np .asarray (D1op .trav_srcs )
375+ D1op .trav_recs = np .asarray (D1op .trav_recs )
376+ if par ["dynamic" ]:
377+ D1op .amp_srcs = np .asarray (D1op .amp_srcs )
378+ D1op .amp_recs = np .asarray (D1op .amp_recs )
352379
353380 # forward
354381 xx = np .random .normal (0 , 1 , PAR ["nx" ] * PAR ["nz" ])
@@ -359,9 +386,6 @@ def test_kirchhoff2d_trav_vs_travsrcrec(par):
359386 assert_array_almost_equal (Dop .H @ yy , D1op .H @ yy , decimal = 2 )
360387
361388
362- @pytest .mark .skipif (
363- int (os .environ .get ("TEST_CUPY_PYLOPS" , 0 )) == 1 , reason = "Not CuPy enabled"
364- )
365389@pytest .mark .parametrize (
366390 "par" ,
367391 [
@@ -384,7 +408,13 @@ def test_kirchhoff3d_trav_vs_travsrcrec(par):
384408 wavc ,
385409 y = y ,
386410 mode = par ["mode" ],
411+ engine = backend ,
387412 )
413+ Dop .trav_srcs = np .asarray (Dop .trav_srcs )
414+ Dop .trav_recs = np .asarray (Dop .trav_recs )
415+ if par ["dynamic" ]:
416+ Dop .amp_srcs = np .asarray (Dop .amp_srcs )
417+ Dop .amp_recs = np .asarray (Dop .amp_recs )
388418
389419 # old behaviour
390420 trav = Dop .trav_srcs .reshape (
@@ -409,7 +439,13 @@ def test_kirchhoff3d_trav_vs_travsrcrec(par):
409439 y = y ,
410440 trav = trav ,
411441 mode = par ["mode" ],
442+ engine = backend ,
412443 )
444+ D1op .trav_srcs = np .asarray (D1op .trav_srcs )
445+ D1op .trav_recs = np .asarray (D1op .trav_recs )
446+ if par ["dynamic" ]:
447+ D1op .amp_srcs = np .asarray (D1op .amp_srcs )
448+ D1op .amp_recs = np .asarray (D1op .amp_recs )
413449
414450 # forward
415451 xx = np .random .normal (0 , 1 , PAR ["ny" ] * PAR ["nx" ] * PAR ["nz" ])
0 commit comments