Skip to content

Commit ce3e1e1

Browse files
committed
api: ensure SparseFunction distributes according to input distributor
1 parent edc9288 commit ce3e1e1

File tree

4 files changed

+20
-10
lines changed

4 files changed

+20
-10
lines changed

devito/data/data.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -396,9 +396,6 @@ def __setitem__(self, glb_idx, val, comm_type):
396396
processed.append(j)
397397
val_idx = as_tuple(processed)
398398
val = val[val_idx]
399-
else:
400-
# `val` is replicated`, `self` is replicated -> plain ndarray.__setitem__
401-
pass
402399
super().__setitem__(glb_idx, val)
403400
elif isinstance(val, Iterable):
404401
if self._is_decomposed:

devito/operator/operator.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,8 @@ def apply(self, **kwargs):
991991
# Build the arguments list to invoke the kernel function
992992
with self._profiler.timer_on('arguments'):
993993
args = self.arguments(**kwargs)
994+
with switch_log_level(comm=args.comm):
995+
self._emit_args_profiling()
994996

995997
# Invoke kernel function with args
996998
arg_values = [args[p.name] for p in self.parameters]
@@ -1009,7 +1011,10 @@ def apply(self, **kwargs):
10091011
raise
10101012

10111013
# Perform error checking
1012-
self._postprocess_errors(retval)
1014+
with self._profiler.timer_on('postarguments'):
1015+
self._postprocess_errors(retval)
1016+
with switch_log_level(comm=args.comm):
1017+
self._emit_args_profiling('post')
10131018

10141019
# Post-process runtime arguments
10151020
self._postprocess_arguments(args, **kwargs)
@@ -1020,6 +1025,11 @@ def apply(self, **kwargs):
10201025

10211026
# Performance profiling
10221027

1028+
def _emit_args_profiling(self, tag=''):
1029+
fround = lambda i, n=100: ceil(i * n) / n
1030+
elapsed = fround(self._profiler.py_timers[f'{tag}arguments'])
1031+
debug(f"Operator `{self.name}` arguments {tag} processed in {elapsed:.2f} s")
1032+
10231033
def _emit_build_profiling(self):
10241034
if not is_log_enabled_for('PERF'):
10251035
return

devito/types/sparse.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ def __shape_setup__(cls, **kwargs):
9898
shape = kwargs.get('shape', kwargs.get('shape_global'))
9999
dimensions = kwargs.get('dimensions')
100100
npoint = kwargs.get('npoint', kwargs.get('npoint_global'))
101-
glb_npoint = SparseDistributor.decompose(npoint, grid.distributor)
101+
distributor = kwargs.get('distributor', SparseDistributor)
102+
glb_npoint = distributor.decompose(npoint, grid.distributor)
102103
# Plain SparseFunction construction with npoint.
103104
if shape is None:
104105
loc_shape = (glb_npoint[grid.distributor.myrank],)
@@ -184,7 +185,6 @@ def __subfunc_setup__(self, suffix, keys, dtype=None, inkwargs=False, **kwargs):
184185

185186
# Given an array or nothing, create dimension and SubFunction
186187
if key is not None:
187-
dimensions = (self._sparse_dim, Dimension(name='d'))
188188
if key.ndim > 2:
189189
dimensions = (self._sparse_dim, Dimension(name='d'),
190190
*mkdims("i", n=key.ndim-2))
@@ -211,14 +211,17 @@ def __subfunc_setup__(self, suffix, keys, dtype=None, inkwargs=False, **kwargs):
211211
else:
212212
dtype = dtype or self.dtype
213213

214+
if kwargs.get('init_subfunc', True):
215+
init = {'initializer': key}
216+
else:
217+
init = {}
214218
# Complex coordinates are not valid, so fall back to corresponding
215219
# real floating point type if dtype is complex.
216220
dtype = dtype(0).real.__class__
217-
218221
sf = SparseSubFunction(
219222
name=name, dtype=dtype, dimensions=dimensions,
220-
shape=shape, space_order=0, initializer=key, alias=self.alias,
221-
distributor=self._distributor, parent=self
223+
shape=shape, space_order=0, alias=self.alias,
224+
distributor=self._distributor, parent=self, **init
222225
)
223226

224227
if self.npoint == 0:

tests/test_rebuild.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,6 @@ def test_none_subfunc(self, sfunc):
9999

100100
assert s.coordinates is not None
101101

102-
# Explicity set coordinates to None
102+
# Explicitly set coordinates to None
103103
sr = s._rebuild(function=None, initializer=None, coordinates=None)
104104
assert sr.coordinates is None

0 commit comments

Comments
 (0)