22 Designed to run with n processes
33 $ mpiexec -n 10 pytest test_blockdiag.py --with-mpi
44"""
5+ import os
6+
7+ if int (os .environ .get ("TEST_CUPY_PYLOPS" , 0 )):
8+ import cupy as np
9+ from cupy .testing import assert_allclose
10+
11+ backend = "cupy"
12+ else :
13+ import numpy as np
14+ from numpy .testing import assert_allclose
15+
16+ backend = "numpy"
517from mpi4py import MPI
6- import numpy as np
7- from numpy .testing import assert_allclose
818import pytest
919
1020import pylops
1727par2j = {'ny' : 301 , 'nx' : 101 , 'dtype' : np .complex128 }
1828
1929np .random .seed (42 )
30+ rank = MPI .COMM_WORLD .Get_rank ()
31+ if backend == "cupy" :
32+ device_id = rank % np .cuda .runtime .getDeviceCount ()
33+ np .cuda .Device (device_id ).use ()
2034
2135
2236@pytest .mark .mpi (min_size = 2 )
@@ -27,11 +41,11 @@ def test_blockdiag(par):
2741 Op = pylops .MatrixMult (A = ((rank + 1 ) * np .ones (shape = (par ['ny' ], par ['nx' ]))).astype (par ['dtype' ]))
2842 BDiag_MPI = pylops_mpi .MPIBlockDiag (ops = [Op , ])
2943
30- x = pylops_mpi .DistributedArray (global_shape = size * par ['nx' ], dtype = par ['dtype' ])
44+ x = pylops_mpi .DistributedArray (global_shape = size * par ['nx' ], dtype = par ['dtype' ], engine = backend )
3145 x [:] = np .ones (shape = par ['nx' ], dtype = par ['dtype' ])
3246 x_global = x .asarray ()
3347
34- y = pylops_mpi .DistributedArray (global_shape = size * par ['ny' ], dtype = par ['dtype' ])
48+ y = pylops_mpi .DistributedArray (global_shape = size * par ['ny' ], dtype = par ['dtype' ], engine = backend )
3549 y [:] = np .ones (shape = par ['ny' ], dtype = par ['dtype' ])
3650 y_global = y .asarray ()
3751
@@ -68,16 +82,16 @@ def test_stacked_blockdiag(par):
6882 FirstDeriv_MPI = pylops_mpi .MPIFirstDerivative (dims = (par ['ny' ], par ['nx' ]), dtype = par ['dtype' ])
6983 StackedBDiag_MPI = pylops_mpi .MPIStackedBlockDiag (ops = [BDiag_MPI , FirstDeriv_MPI ])
7084
71- dist1 = pylops_mpi .DistributedArray (global_shape = size * par ['nx' ], dtype = par ['dtype' ])
85+ dist1 = pylops_mpi .DistributedArray (global_shape = size * par ['nx' ], dtype = par ['dtype' ], engine = backend )
7286 dist1 [:] = np .ones (dist1 .local_shape , dtype = par ['dtype' ])
73- dist2 = pylops_mpi .DistributedArray (global_shape = par ['nx' ] * par ['ny' ], dtype = par ['dtype' ])
87+ dist2 = pylops_mpi .DistributedArray (global_shape = par ['nx' ] * par ['ny' ], dtype = par ['dtype' ], engine = backend )
7488 dist2 [:] = np .ones (dist2 .local_shape , dtype = par ['dtype' ])
7589 x = pylops_mpi .StackedDistributedArray (distarrays = [dist1 , dist2 ])
7690 x_global = x .asarray ()
7791
78- dist1 = pylops_mpi .DistributedArray (global_shape = size * par ['ny' ], dtype = par ['dtype' ])
92+ dist1 = pylops_mpi .DistributedArray (global_shape = size * par ['ny' ], dtype = par ['dtype' ], engine = backend )
7993 dist1 [:] = np .ones (dist1 .local_shape , dtype = par ['dtype' ])
80- dist2 = pylops_mpi .DistributedArray (global_shape = par ['nx' ] * par ['ny' ], dtype = par ['dtype' ])
94+ dist2 = pylops_mpi .DistributedArray (global_shape = par ['nx' ] * par ['ny' ], dtype = par ['dtype' ], engine = backend )
8195 dist2 [:] = np .ones (dist2 .local_shape , dtype = par ['dtype' ])
8296 y = pylops_mpi .StackedDistributedArray (distarrays = [dist1 , dist2 ])
8397 y_global = y .asarray ()
0 commit comments