Skip to content

Commit 874b8d2

Browse files
committed
mpi: prevent crashes with single rank
1 parent 4c654ff commit 874b8d2

File tree

2 files changed

+9
-29
lines changed

2 files changed

+9
-29
lines changed

devito/builtins/arithmetic.py

Lines changed: 8 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -180,33 +180,13 @@ def inner(f, g):
180180
return f.dtype(n.data[0])
181181

182182

183-
@dv.switchconfig(log_level='ERROR')
184-
@check_builtins_args
185-
def mmin(f):
186-
"""
187-
Retrieve the minimum.
188-
189-
Parameters
190-
----------
191-
f : array_like or Function
192-
Input operand.
193-
"""
194-
if isinstance(f, dv.Constant):
195-
return f.data
196-
elif isinstance(f, dv.types.dense.DiscreteFunction):
197-
v = np.min(f.data_ro_domain)
198-
if f.grid is None or not dv.configuration['mpi']:
199-
return v.item()
200-
else:
201-
comm = f.grid.distributor.comm
202-
return comm.allreduce(v, dv.mpi.MPI.MIN).item()
203-
else:
204-
raise ValueError("Expected Function, got `%s`" % type(f))
183+
mmin = lambda f: _reduce_func(f, np.min, dv.mpi.MPI.MIN)
184+
mmax = lambda f: _reduce_func(f, np.max, dv.mpi.MPI.MAX)
205185

206186

207187
@dv.switchconfig(log_level='ERROR')
208188
@check_builtins_args
209-
def mmax(f):
189+
def _reduce_func(f, func, mfunc):
210190
"""
211191
Retrieve the maximum.
212192
@@ -218,11 +198,11 @@ def mmax(f):
218198
if isinstance(f, dv.Constant):
219199
return f.data
220200
elif isinstance(f, dv.types.dense.DiscreteFunction):
221-
v = np.max(f.data_ro_domain)
222-
if f.grid is None or not dv.configuration['mpi']:
223-
return v.item()
224-
else:
201+
v = func(f.data_ro_domain)
202+
if f.data._is_decomposed:
225203
comm = f.grid.distributor.comm
226-
return comm.allreduce(v, dv.mpi.MPI.MAX).item()
204+
return comm.allreduce(v, mfunc).item()
205+
else:
206+
return v.item()
227207
else:
228208
raise ValueError("Expected Function, got `%s`" % type(f))

devito/data/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def wrapper(data, *args, **kwargs):
197197
@property
198198
def _is_decomposed(self):
199199
return self._is_distributed and configuration['mpi'] and \
200-
self._distributor.comm.size > 1
200+
self._distributor.is_parallel
201201

202202
def __repr__(self):
203203
return super(Data, self._local).__repr__()

0 commit comments

Comments
 (0)