1+ """Test the derivative classes
2+ Designed to run with n processes
3+ $ mpiexec -n 10 pytest test_derivative.py --with-mpi
4+ """
15import numpy as np
26from mpi4py import MPI
37from numpy .testing import assert_allclose
48import pytest
59
610import pylops
711import pylops_mpi
12+ from pylops_mpi .utils .dottest import dottest
813
914np .random .seed (42 )
1015rank = MPI .COMM_WORLD .Get_rank ()
@@ -194,6 +199,8 @@ def test_first_derivative_forward(par):
194199 # Adjoint
195200 y_adj_dist = Fop_MPI .H @ x
196201 y_adj = y_adj_dist .asarray ()
202+ # Dot test
203+ dottest (Fop_MPI , x , y_dist , np .prod (par ['nz' ]), np .prod (par ['nz' ]))
197204
198205 if rank == 0 :
199206 Fop = pylops .FirstDerivative (dims = par ['nz' ], axis = 0 ,
@@ -226,6 +233,8 @@ def test_first_derivative_backward(par):
226233 # Adjoint
227234 y_adj_dist = Fop_MPI .H @ x
228235 y_adj = y_adj_dist .asarray ()
236+ # Dot test
237+ dottest (Fop_MPI , x , y_dist , np .prod (par ['nz' ]), np .prod (par ['nz' ]))
229238
230239 if rank == 0 :
231240 Fop = pylops .FirstDerivative (dims = par ['nz' ], axis = 0 ,
@@ -259,6 +268,9 @@ def test_first_derivative_centered(par):
259268 # Adjoint
260269 y_adj_dist = Fop_MPI .H @ x
261270 y_adj = y_adj_dist .asarray ()
271+ # Dot test
272+ dottest (Fop_MPI , x , y_dist , np .prod (par ['nz' ]), np .prod (par ['nz' ]))
273+
262274 if rank == 0 :
263275 Fop = pylops .FirstDerivative (dims = par ['nz' ], axis = 0 ,
264276 sampling = par ['dz' ],
@@ -290,6 +302,8 @@ def test_second_derivative_forward(par):
290302 # Adjoint
291303 y_adj_dist = Sop_MPI .H @ x
292304 y_adj = y_adj_dist .asarray ()
305+ # Dot test
306+ dottest (Sop_MPI , x , y_dist , np .prod (par ['nz' ]), np .prod (par ['nz' ]))
293307
294308 if rank == 0 :
295309 Sop = pylops .SecondDerivative (dims = par ['nz' ], axis = 0 ,
@@ -322,6 +336,8 @@ def test_second_derivative_backward(par):
322336 # Adjoint
323337 y_adj_dist = Sop_MPI .H @ x
324338 y_adj = y_adj_dist .asarray ()
339+ # Dot test
340+ dottest (Sop_MPI , x , y_dist , np .prod (par ['nz' ]), np .prod (par ['nz' ]))
325341
326342 if rank == 0 :
327343 Sop = pylops .SecondDerivative (dims = par ['nz' ], axis = 0 ,
@@ -354,6 +370,8 @@ def test_second_derivative_centered(par):
354370 # Adjoint
355371 y_adj_dist = Sop_MPI .H @ x
356372 y_adj = y_adj_dist .asarray ()
373+ # Dot test
374+ dottest (Sop_MPI , x , y_dist , np .prod (par ['nz' ]), np .prod (par ['nz' ]))
357375
358376 if rank == 0 :
359377 Sop = pylops .SecondDerivative (dims = par ['nz' ], axis = 0 ,
@@ -385,6 +403,8 @@ def test_laplacian(par):
385403 # Adjoint
386404 y_adj_dist = Lop_MPI .H @ x
387405 y_adj = y_adj_dist .asarray ()
406+ # Dot test
407+ dottest (Lop_MPI , x , y_dist , np .prod (par ['n' ]), np .prod (par ['n' ]))
388408
389409 if rank == 0 :
390410 Lop = pylops .Laplacian (dims = par ['n' ], axes = par ['axes' ],
@@ -409,6 +429,7 @@ def test_gradient(par):
409429 x_fwd = pylops_mpi .DistributedArray (global_shape = np .prod (par ['n' ]), dtype = par ['dtype' ])
410430 x_fwd [:] = np .random .normal (rank , 10 , x_fwd .local_shape )
411431 x_global = x_fwd .asarray ()
432+
412433 # Forward
413434 y_dist = Gop_MPI @ x_fwd
414435 assert isinstance (y_dist , pylops_mpi .StackedDistributedArray )
@@ -421,15 +442,15 @@ def test_gradient(par):
421442 x_adj_dist2 [:] = np .random .normal (rank , 20 , x_adj_dist2 .local_shape )
422443 x_adj_dist3 = pylops_mpi .DistributedArray (global_shape = int (np .prod (par ['n' ])), dtype = par ['dtype' ])
423444 x_adj_dist3 [:] = np .random .normal (rank , 30 , x_adj_dist3 .local_shape )
424-
425445 x_adj = pylops_mpi .StackedDistributedArray (distarrays = [x_adj_dist1 , x_adj_dist2 , x_adj_dist3 ])
426-
427446 x_adj_global = x_adj .asarray ()
428447 y_adj_dist = Gop_MPI .H @ x_adj
429448 assert isinstance (y_adj_dist , pylops_mpi .DistributedArray )
430-
431449 y_adj = y_adj_dist .asarray ()
432450
451+ # Dot test
452+ dottest (Gop_MPI , x_fwd , y_dist , len (par ['n' ]) * np .prod (par ['n' ]), np .prod (par ['n' ]))
453+
433454 if rank == 0 :
434455 Gop = pylops .Gradient (dims = par ['n' ], sampling = par ['sampling' ],
435456 kind = kind , edge = par ['edge' ],
0 commit comments