Skip to content

Commit eb9b407

Browse files
committed
Replace MPI.COMM_WORLD with self.base_comm, Fix dottest
1 parent f9f5d13 commit eb9b407

File tree

3 files changed

+7
-12
lines changed

3 files changed

+7
-12
lines changed

pylops_mpi/DistributedArray.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def mask(self):
256256
257257
Returns
258258
-------
259-
engine : :obj:`list`
259+
mask : :obj:`list`
260260
"""
261261
return self._mask
262262

@@ -288,11 +288,7 @@ def rank(self):
288288
-------
289289
rank : :obj:`int`
290290
"""
291-
# cp.cuda.Device().id will give local rank
292-
# It works ok in the single-node multi-gpu environment.
293-
# But in multi-node environment, the function will break.
294-
# So we have to use MPI.COMM_WORLD() in both cases of base_comm (MPI and NCCL)
295-
return MPI.COMM_WORLD.Get_rank()
291+
return self.base_comm.Get_rank()
296292

297293
@property
298294
def size(self):
@@ -303,7 +299,7 @@ def size(self):
303299
-------
304300
size : :obj:`int`
305301
"""
306-
return MPI.COMM_WORLD.Get_size()
302+
return self.base_comm.Get_size()
307303

308304
@property
309305
def axis(self):
@@ -812,8 +808,8 @@ def __init__(self, distarrays: List, base_comm: MPI.Comm = MPI.COMM_WORLD):
812808
self.distarrays = distarrays
813809
self.narrays = len(distarrays)
814810
self.base_comm = base_comm
815-
self.rank = MPI.COMM_WORLD.Get_rank()
816-
self.size = MPI.COMM_WORLD.Get_size()
811+
self.rank = self.base_comm.Get_rank()
812+
self.size = self.base_comm.Get_size()
817813

818814
def __getitem__(self, index):
819815
return self.distarrays[index]

pylops_mpi/utils/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# isort: skip_file
22

3-
# currently dottest create circular dependency with DistributedArray.py
4-
# from .dottest import *
3+
from .dottest import *
54
from .deps import *

pylops_mpi/utils/dottest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66

7-
from pylops_mpi.DistributedArray import DistributedArray
7+
from pylops_mpi import DistributedArray
88
from pylops.utils.backend import to_numpy
99

1010

0 commit comments

Comments
 (0)