Skip to content

Commit ec4b8c1

Browse files
Fix dottest dependency and use base_comm instead of MPI.COMM_WORLD (#134)
* Replace MPI.COMM_WORLD with self.base_comm, Fix dottest * Fix flake8 for local dev * Minor change
1 parent f9f5d13 commit ec4b8c1

File tree

6 files changed

+9
-15
lines changed

6 files changed

+9
-15
lines changed

pylops_mpi/DistributedArray.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def mask(self):
256256
257257
Returns
258258
-------
259-
engine : :obj:`list`
259+
mask : :obj:`list`
260260
"""
261261
return self._mask
262262

@@ -288,11 +288,7 @@ def rank(self):
288288
-------
289289
rank : :obj:`int`
290290
"""
291-
# cp.cuda.Device().id will give local rank
292-
# It works ok in the single-node multi-gpu environment.
293-
# But in multi-node environment, the function will break.
294-
# So we have to use MPI.COMM_WORLD() in both cases of base_comm (MPI and NCCL)
295-
return MPI.COMM_WORLD.Get_rank()
291+
return self.base_comm.Get_rank()
296292

297293
@property
298294
def size(self):
@@ -303,7 +299,7 @@ def size(self):
303299
-------
304300
size : :obj:`int`
305301
"""
306-
return MPI.COMM_WORLD.Get_size()
302+
return self.base_comm.Get_size()
307303

308304
@property
309305
def axis(self):
@@ -812,8 +808,8 @@ def __init__(self, distarrays: List, base_comm: MPI.Comm = MPI.COMM_WORLD):
812808
self.distarrays = distarrays
813809
self.narrays = len(distarrays)
814810
self.base_comm = base_comm
815-
self.rank = MPI.COMM_WORLD.Get_rank()
816-
self.size = MPI.COMM_WORLD.Get_size()
811+
self.rank = base_comm.Get_rank()
812+
self.size = base_comm.Get_size()
817813

818814
def __getitem__(self, index):
819815
return self.distarrays[index]

pylops_mpi/utils/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# isort: skip_file
22

3-
# currently dottest create circular dependency with DistributedArray.py
4-
# from .dottest import *
3+
from .dottest import *
54
from .deps import *

pylops_mpi/utils/dottest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66

7-
from pylops_mpi.DistributedArray import DistributedArray
7+
from pylops_mpi import DistributedArray
88
from pylops.utils.backend import to_numpy
99

1010

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ python_files = tests/*.py tests_nccl/*.py
55
[flake8]
66
ignore = E203, E501, W503, E402
77
per-file-ignores =
8-
__init__.py: F401, F403, F405
8+
__init__.py: F401, F403, F405
99
max-line-length = 88

tests/test_fredholm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def test_Fredholm1(par):
135135
y_adj_dist = Fop_MPI.H @ y_dist
136136
y_adj = y_adj_dist.asarray()
137137
# Dot test
138-
dottest(Fop_MPI, x, y_dist, par["nsl"] * par["nx"] * par["nz"],par["nsl"] * par["ny"] * par["nz"])
138+
dottest(Fop_MPI, x, y_dist, par["nsl"] * par["nx"] * par["nz"], par["nsl"] * par["ny"] * par["nz"])
139139

140140
if rank == 0:
141141
Fop = pylops.signalprocessing.Fredholm1(

tutorials/mdd.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
"""
1515

1616
import numpy as np
17-
from scipy.signal import filtfilt
1817
from matplotlib import pyplot as plt
1918
from mpi4py import MPI
2019

0 commit comments

Comments
 (0)