
Commit 66e3b16

ensure unique gpu device for each mpi rank in CuPy MPI tests
1 parent 33121a5 commit 66e3b16

11 files changed: +89 −9 lines
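For context, here is a minimal standalone sketch of the rank-to-GPU pinning pattern this commit adds to each test module. It is an illustration, not the repository's exact code: the tests use their own `backend`/`np` aliases, and `OMPI_COMM_WORLD_LOCAL_RANK` is only exported by Open MPI launchers, hence the fallback to rank modulo device count.

# Sketch: give every MPI rank its own GPU before any CuPy allocations happen.
# Assumes mpi4py and CuPy are installed; run with e.g. `mpirun -n 2 python script.py`.
import os

import cupy as cp
from mpi4py import MPI

rank = MPI.COMM_WORLD.Get_rank()
device_count = cp.cuda.runtime.getDeviceCount()
# Prefer the local rank exported by the MPI launcher; otherwise round-robin
# ranks over the visible devices so ranks do not all pile onto GPU 0.
device_id = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK") or rank % device_count)
cp.cuda.Device(device_id).use()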

pylops_mpi/DistributedArray.py

Lines changed: 1 addition & 1 deletion

@@ -697,7 +697,7 @@ def _compute_vector_norm(self, local_array: NDArray,
     # TODO (tharitt): currently CuPy + MPI does not work well with buffered communication, particularly
     # with MAX, MIN operator. Here we copy the array back to CPU, transfer, and copy them back to GPUs
     send_buf = ncp.max(ncp.abs(local_array), axis=axis).astype(ncp.float64)
-    if self.engine=="cupy" and self.base_comm_nccl is None:
+    if self.engine == "cupy" and self.base_comm_nccl is None:
         recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MAX)
         recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
     else:
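The TODO touched here describes staging the reduction through host memory, since buffered MPI communication of CuPy arrays misbehaves for MAX/MIN. A minimal sketch of that host-staged pattern follows (hedged illustration only, not the pylops-mpi API; assumes mpi4py and CuPy):

# Host-staged max-allreduce of CuPy data: copy to CPU, reduce with MPI,
# then move the result back to the GPU.
import cupy as cp
from mpi4py import MPI

comm = MPI.COMM_WORLD
local = cp.random.rand(1000)                   # per-rank data on the GPU
send = float(cp.abs(local).max())              # device -> host scalar
global_max = comm.allreduce(send, op=MPI.MAX)  # CPU-side reduction
result = cp.asarray(global_max)                # host -> device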

tests/test_blockdiag.py

Lines changed: 8 additions & 0 deletions

@@ -27,6 +27,14 @@
 par2j = {'ny': 301, 'nx': 101, 'dtype': np.complex128}

 np.random.seed(42)
+rank = MPI.COMM_WORLD.Get_rank()
+if backend == "cupy":
+    device_count = np.cuda.runtime.getDeviceCount()
+    device_id = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
+        or rank % np.cuda.runtime.getDeviceCount()
+    )
+    np.cuda.Device(device_id).use()


 @pytest.mark.mpi(min_size=2)

tests/test_derivative.py

Lines changed: 8 additions & 0 deletions

@@ -25,6 +25,14 @@
 np.random.seed(42)
 rank = MPI.COMM_WORLD.Get_rank()
 size = MPI.COMM_WORLD.Get_size()
+if backend == "cupy":
+    device_count = np.cuda.runtime.getDeviceCount()
+    device_id = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
+        or rank % np.cuda.runtime.getDeviceCount()
+    )
+    np.cuda.Device(device_id).use()
+

 par1 = {
     "nz": 600,

tests/test_distributedarray.py

Lines changed: 9 additions & 1 deletion

@@ -21,6 +21,14 @@
 from pylops_mpi.DistributedArray import local_split

 np.random.seed(42)
+rank = MPI.COMM_WORLD.Get_rank()
+if backend == "cupy":
+    device_count = np.cuda.runtime.getDeviceCount()
+    device_id = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
+        or rank % np.cuda.runtime.getDeviceCount()
+    )
+    np.cuda.Device(device_id).use()

 par1 = {'global_shape': (500, 501),
         'partition': Partition.SCATTER, 'dtype': np.float64,

@@ -206,7 +214,7 @@ def test_distributed_norm(par):

     # TODO (tharitt): FAIL with CuPy + MPI for inf norm
     assert_allclose(arr.norm(ord=np.inf, axis=par['axis']),
-        np.linalg.norm(par['x'], ord=np.inf, axis=par['axis']), rtol=1e-14)
+                    np.linalg.norm(par['x'], ord=np.inf, axis=par['axis']), rtol=1e-14)
     assert_allclose(arr.norm(), np.linalg.norm(par['x'].flatten()), rtol=1e-13)
tests/test_fredholm.py

Lines changed: 7 additions & 0 deletions

@@ -29,6 +29,13 @@
 np.random.seed(42)
 rank = MPI.COMM_WORLD.Get_rank()
 size = MPI.COMM_WORLD.Get_size()
+if backend == "cupy":
+    device_count = np.cuda.runtime.getDeviceCount()
+    device_id = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
+        or rank % np.cuda.runtime.getDeviceCount()
+    )
+    np.cuda.Device(device_id).use()

 par1 = {
     "nsl": 21,

tests/test_linearop.py

Lines changed: 9 additions & 2 deletions

@@ -31,6 +31,13 @@
 np.random.seed(42)
 rank = MPI.COMM_WORLD.Get_rank()
 size = MPI.COMM_WORLD.Get_size()
+if backend == "cupy":
+    device_count = np.cuda.runtime.getDeviceCount()
+    device_id = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
+        or rank % np.cuda.runtime.getDeviceCount()
+    )
+    np.cuda.Device(device_id).use()

 par1 = {'ny': 101, 'nx': 101, 'dtype': np.float64}
 par1j = {'ny': 101, 'nx': 101, 'dtype': np.complex128}

@@ -142,7 +149,7 @@ def test_power(par):
     Op = pylops.MatrixMult(A=((rank + 1) * np.ones(shape=(par['ny'], par['nx']))).astype(par['dtype']),
                            dtype=par['dtype'])
     BDiag_MPI = MPIBlockDiag(ops=[Op, ])
-
+
     # Power Operator
     Pop_MPI = BDiag_MPI ** 3

@@ -166,7 +173,7 @@
     ops = [pylops.MatrixMult((i + 1) * np.ones(shape=(par['ny'], par['nx'])).astype(par['dtype'])) for i in
            range(size)]
     BDiag = pylops.BlockDiag(ops=ops)
-    Pop = BDiag * BDiag * BDiag ## temporarely replaced BDiag ** 3 until bug in PyLops is fixed
+    Pop = BDiag * BDiag * BDiag  # temporarely replaced BDiag ** 3 until bug in PyLops is fixed
     assert_allclose(Pop_x_np, Pop @ x_global, rtol=1e-9)
     assert_allclose(Pop_y_np, Pop.H @ y_global, rtol=1e-9)

tests/test_matrixmult.py

Lines changed: 13 additions & 5 deletions

@@ -19,14 +19,21 @@
 from mpi4py import MPI
 import pytest

-from pylops.basicoperators import FirstDerivative, Identity
+from pylops.basicoperators import FirstDerivative
 from pylops_mpi import DistributedArray, Partition
 from pylops_mpi.basicoperators import MPIMatrixMult, MPIBlockDiag

 np.random.seed(42)
 base_comm = MPI.COMM_WORLD
 size = base_comm.Get_size()
-
+rank = MPI.COMM_WORLD.Get_rank()
+if backend == "cupy":
+    device_count = np.cuda.runtime.getDeviceCount()
+    device_id = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
+        or rank % np.cuda.runtime.getDeviceCount()
+    )
+    np.cuda.Device(device_id).use()
 # Define test cases: (N, K, M, dtype_str)
 # M, K, N are matrix dimensions A(N,K), B(K,M)
 # P_prime will be ceil(sqrt(size)).

@@ -39,6 +46,7 @@
     pytest.param(2, 1, 3, "float32", id="f32_2_1_3",),
 ]

+
 def _reorganize_local_matrix(x_dist, N, M, blk_cols, p_prime):
     """Re-organize distributed array in local matrix
     """

@@ -66,9 +74,9 @@ def test_MPIMatrixMult(N, K, M, dtype_str):
     cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0
     base_float_dtype = np.float32 if dtype == np.complex64 else np.float64

-    comm, rank, row_id, col_id, is_active = \
-        MPIMatrixMult.active_grid_comm(base_comm, N, M)
-    if not is_active: return
+    comm, rank, row_id, col_id, is_active = MPIMatrixMult.active_grid_comm(base_comm, N, M)
+    if not is_active:
+        return

     size = comm.Get_size()
     p_prime = math.isqrt(size)

tests/test_solver.py

Lines changed: 7 additions & 0 deletions

@@ -38,6 +38,13 @@

 size = MPI.COMM_WORLD.Get_size()
 rank = MPI.COMM_WORLD.Get_rank()
+if backend == "cupy":
+    device_count = np.cuda.runtime.getDeviceCount()
+    device_id = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
+        or rank % np.cuda.runtime.getDeviceCount()
+    )
+    np.cuda.Device(device_id).use()

 par1 = {
     "ny": 11,

tests/test_stack.py

Lines changed: 9 additions & 0 deletions

@@ -21,6 +21,15 @@
 import pylops_mpi
 from pylops_mpi.utils.dottest import dottest

+rank = MPI.COMM_WORLD.Get_rank()
+if backend == "cupy":
+    device_count = np.cuda.runtime.getDeviceCount()
+    device_id = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
+        or rank % np.cuda.runtime.getDeviceCount()
+    )
+    np.cuda.Device(device_id).use()
+
 par1 = {'ny': 101, 'nx': 101, 'imag': 0, 'dtype': np.float64}
 par1j = {'ny': 101, 'nx': 101, 'imag': 1j, 'dtype': np.complex128}
 par2 = {'ny': 301, 'nx': 101, 'imag': 0, 'dtype': np.float64}

tests/test_stackedarray.py

Lines changed: 10 additions & 0 deletions

@@ -17,9 +17,19 @@
 import numpy as npp
 import pytest

+from mpi4py import MPI
 from pylops_mpi import DistributedArray, Partition, StackedDistributedArray

 np.random.seed(42)
+rank = MPI.COMM_WORLD.Get_rank()
+if backend == "cupy":
+    device_count = np.cuda.runtime.getDeviceCount()
+    device_id = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
+        or rank % np.cuda.runtime.getDeviceCount()
+    )
+    np.cuda.Device(device_id).use()
+

 par1 = {'global_shape': (500, 501),
         'partition': Partition.SCATTER, 'dtype': np.float64,
