    Designed to run with n processes
    $ mpiexec -n 10 pytest test_blockdiag.py --with-mpi
"""
+ import os
+
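+ # Backend switch: when the TEST_CUPY_PYLOPS environment variable is set to a
+ # non-zero value, cupy is imported in place of numpy so the test arrays live
+ # on the GPU; e.g. (assuming a CUDA-capable environment):
+ #   $ TEST_CUPY_PYLOPS=1 mpiexec -n 2 pytest test_blockdiag.py --with-mpi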
+ if int(os.environ.get("TEST_CUPY_PYLOPS", 0)):
+     import cupy as np
+     from cupy.testing import assert_allclose
+
+     backend = "cupy"
+ else:
+     import numpy as np
+     from numpy.testing import assert_allclose
+
+     backend = "numpy"
from mpi4py import MPI
- import numpy as np
- from numpy.testing import assert_allclose
import pytest

import pylops
par2j = {'ny': 301, 'nx': 101, 'dtype': np.complex128}

np.random.seed(42)
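+ # When running with the cupy backend, each MPI rank is pinned to a GPU
+ # round-robin (rank modulo the number of visible devices), so processes are
+ # spread across the available GPUs when there are more ranks than devices.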
+ rank = MPI.COMM_WORLD.Get_rank()
+ if backend == "cupy":
+     device_id = rank % np.cuda.runtime.getDeviceCount()
+     np.cuda.Device(device_id).use()


@pytest.mark.mpi(min_size=2)
@@ -27,11 +41,11 @@ def test_blockdiag(par):
    Op = pylops.MatrixMult(A=((rank + 1) * np.ones(shape=(par['ny'], par['nx']))).astype(par['dtype']))
    BDiag_MPI = pylops_mpi.MPIBlockDiag(ops=[Op, ])

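+     # engine=backend selects the array module used for the DistributedArray's
+     # local buffer ("numpy" or "cupy"), so the test vectors are created on the
+     # same backend as the operator above.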
-     x = pylops_mpi.DistributedArray(global_shape=size * par['nx'], dtype=par['dtype'])
+     x = pylops_mpi.DistributedArray(global_shape=size * par['nx'], dtype=par['dtype'], engine=backend)
    x[:] = np.ones(shape=par['nx'], dtype=par['dtype'])
    x_global = x.asarray()

-     y = pylops_mpi.DistributedArray(global_shape=size * par['ny'], dtype=par['dtype'])
+     y = pylops_mpi.DistributedArray(global_shape=size * par['ny'], dtype=par['dtype'], engine=backend)
    y[:] = np.ones(shape=par['ny'], dtype=par['dtype'])
    y_global = y.asarray()

@@ -68,16 +82,16 @@ def test_stacked_blockdiag(par):
    FirstDeriv_MPI = pylops_mpi.MPIFirstDerivative(dims=(par['ny'], par['nx']), dtype=par['dtype'])
    StackedBDiag_MPI = pylops_mpi.MPIStackedBlockDiag(ops=[BDiag_MPI, FirstDeriv_MPI])

-     dist1 = pylops_mpi.DistributedArray(global_shape=size * par['nx'], dtype=par['dtype'])
+     dist1 = pylops_mpi.DistributedArray(global_shape=size * par['nx'], dtype=par['dtype'], engine=backend)
    dist1[:] = np.ones(dist1.local_shape, dtype=par['dtype'])
-     dist2 = pylops_mpi.DistributedArray(global_shape=par['nx'] * par['ny'], dtype=par['dtype'])
+     dist2 = pylops_mpi.DistributedArray(global_shape=par['nx'] * par['ny'], dtype=par['dtype'], engine=backend)
    dist2[:] = np.ones(dist2.local_shape, dtype=par['dtype'])
    x = pylops_mpi.StackedDistributedArray(distarrays=[dist1, dist2])
    x_global = x.asarray()

-     dist1 = pylops_mpi.DistributedArray(global_shape=size * par['ny'], dtype=par['dtype'])
+     dist1 = pylops_mpi.DistributedArray(global_shape=size * par['ny'], dtype=par['dtype'], engine=backend)
    dist1[:] = np.ones(dist1.local_shape, dtype=par['dtype'])
-     dist2 = pylops_mpi.DistributedArray(global_shape=par['nx'] * par['ny'], dtype=par['dtype'])
+     dist2 = pylops_mpi.DistributedArray(global_shape=par['nx'] * par['ny'], dtype=par['dtype'], engine=backend)
    dist2[:] = np.ones(dist2.local_shape, dtype=par['dtype'])
    y = pylops_mpi.StackedDistributedArray(distarrays=[dist1, dist2])
    y_global = y.asarray()