|
19 | 19 | from mpi4py import MPI |
20 | 20 | import pytest |
21 | 21 |
|
22 | | -from pylops.basicoperators import FirstDerivative, Identity |
| 22 | +from pylops.basicoperators import FirstDerivative |
23 | 23 | from pylops_mpi import DistributedArray, Partition |
24 | 24 | from pylops_mpi.basicoperators import MPIMatrixMult, MPIBlockDiag |
25 | 25 |
|
26 | 26 | np.random.seed(42) |
27 | 27 | base_comm = MPI.COMM_WORLD |
28 | 28 | size = base_comm.Get_size() |
29 | | - |
| 29 | +rank = MPI.COMM_WORLD.Get_rank() |
| 30 | +if backend == "cupy":
| 31 | +    device_count = np.cuda.runtime.getDeviceCount()
| 32 | +    device_id = int(
| 33 | +        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
| 34 | +        or rank % device_count
| 35 | +    )
| 36 | +    np.cuda.Device(device_id).use()  # bind this rank to its GPU
30 | 37 | # Define test cases: (N, K, M, dtype_str) |
31 | 38 | # M, K, N are matrix dimensions A(N,K), B(K,M) |
32 | 39 | # P_prime will be ceil(sqrt(size)). |
|
39 | 46 |     pytest.param(2, 1, 3, "float32", id="f32_2_1_3",), |
40 | 47 | ] |
41 | 48 |
|
| 49 | + |
42 | 50 | def _reorganize_local_matrix(x_dist, N, M, blk_cols, p_prime): |
43 | 51 | """Re-organize distributed array in local matrix |
44 | 52 | """ |
@@ -66,9 +74,9 @@ def test_MPIMatrixMult(N, K, M, dtype_str): |
66 | 74 |     cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0 |
67 | 75 |     base_float_dtype = np.float32 if dtype == np.complex64 else np.float64 |
68 | 76 |
|
69 | | -    comm, rank, row_id, col_id, is_active = \
70 | | -        MPIMatrixMult.active_grid_comm(base_comm, N, M)
71 | | -    if not is_active: return
| 77 | +    comm, rank, row_id, col_id, is_active = MPIMatrixMult.active_grid_comm(base_comm, N, M)
| 78 | +    if not is_active:
| 79 | +        return
72 | 80 |
|
73 | 81 |     size = comm.Get_size() |
74 | 82 |     p_prime = math.isqrt(size) |
|
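For reference, the per-rank GPU binding added in this diff can be exercised on its own. The sketch below is a minimal standalone version, assuming cupy and mpi4py are installed and the script is launched with mpirun; the file and variable names are illustrative and not part of the test file. OMPI_COMM_WORLD_LOCAL_RANK is exported by Open MPI launchers, so it is preferred when present, with a round-robin assignment over the visible devices as the fallback.

    import os

    import cupy as cp
    from mpi4py import MPI

    rank = MPI.COMM_WORLD.Get_rank()
    device_count = cp.cuda.runtime.getDeviceCount()

    # Prefer the local rank exported by the Open MPI launcher; otherwise
    # assign ranks to devices round-robin.
    local_rank = os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
    device_id = int(local_rank) if local_rank is not None else rank % device_count

    cp.cuda.Device(device_id).use()  # bind this rank to its GPU
    print(f"rank {rank} -> GPU {device_id}")

Launching this sketch with, e.g., mpirun -n 4 should report each rank bound to a distinct device when at least four GPUs are visible on the node.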