Skip to content

Commit 99eb5ed

Browse files
committed
api: ensure SparseFunction distributes according to input distributor
1 parent edc9288 commit 99eb5ed

File tree

3 files changed

+14
-7
lines changed

3 files changed

+14
-7
lines changed

devito/data/data.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -396,9 +396,6 @@ def __setitem__(self, glb_idx, val, comm_type):
396396
processed.append(j)
397397
val_idx = as_tuple(processed)
398398
val = val[val_idx]
399-
else:
400-
# `val` is replicated`, `self` is replicated -> plain ndarray.__setitem__
401-
pass
402399
super().__setitem__(glb_idx, val)
403400
elif isinstance(val, Iterable):
404401
if self._is_decomposed:

devito/operator/operator.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,8 @@ def apply(self, **kwargs):
991991
# Build the arguments list to invoke the kernel function
992992
with self._profiler.timer_on('arguments'):
993993
args = self.arguments(**kwargs)
994+
with switch_log_level(comm=args.comm):
995+
self._emit_args_profiling()
994996

995997
# Invoke kernel function with args
996998
arg_values = [args[p.name] for p in self.parameters]
@@ -1009,7 +1011,10 @@ def apply(self, **kwargs):
10091011
raise
10101012

10111013
# Perform error checking
1012-
self._postprocess_errors(retval)
1014+
with self._profiler.timer_on('post-arguments'):
1015+
self._postprocess_errors(retval)
1016+
with switch_log_level(comm=args.comm):
1017+
self._emit_args_profiling('post')
10131018

10141019
# Post-process runtime arguments
10151020
self._postprocess_arguments(args, **kwargs)
@@ -1020,6 +1025,11 @@ def apply(self, **kwargs):
10201025

10211026
# Performance profiling
10221027

1028+
def _emit_args_profiling(self, tag=''):
1029+
fround = lambda i, n=100: ceil(i * n) / n
1030+
elapsed = fround(self._profiler.py_timers[f'{tag}-arguments'])
1031+
debug(f"Operator `{self.name}` arguments {tag}-processed in {elapsed:.2f} s")
1032+
10231033
def _emit_build_profiling(self):
10241034
if not is_log_enabled_for('PERF'):
10251035
return

devito/types/sparse.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ def __shape_setup__(cls, **kwargs):
9898
shape = kwargs.get('shape', kwargs.get('shape_global'))
9999
dimensions = kwargs.get('dimensions')
100100
npoint = kwargs.get('npoint', kwargs.get('npoint_global'))
101-
glb_npoint = SparseDistributor.decompose(npoint, grid.distributor)
101+
distributor = kwargs.get('distributor', SparseDistributor)
102+
glb_npoint = distributor.decompose(npoint, grid.distributor)
102103
# Plain SparseFunction construction with npoint.
103104
if shape is None:
104105
loc_shape = (glb_npoint[grid.distributor.myrank],)
@@ -146,7 +147,7 @@ def __subfunc_setup__(self, suffix, keys, dtype=None, inkwargs=False, **kwargs):
146147
for k in keys:
147148
if k not in kwargs:
148149
continue
149-
elif kwargs[k] is None:
150+
elif kwargs[k] is None and inkwargs:
150151
# In cases such as rebuild,
151152
# the subfunction may be passed explicitly as None
152153
return None
@@ -214,7 +215,6 @@ def __subfunc_setup__(self, suffix, keys, dtype=None, inkwargs=False, **kwargs):
214215
# Complex coordinates are not valid, so fall back to corresponding
215216
# real floating point type if dtype is complex.
216217
dtype = dtype(0).real.__class__
217-
218218
sf = SparseSubFunction(
219219
name=name, dtype=dtype, dimensions=dimensions,
220220
shape=shape, space_order=0, initializer=key, alias=self.alias,

0 commit comments

Comments
 (0)