Skip to content

Commit cff361e

Browse files
committed
api: ensure SparseFunction distributes according to input distributor
1 parent edc9288 commit cff361e

File tree

3 files changed

+12
-2
lines changed

3 files changed

+12
-2
lines changed

devito/data/data.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,9 @@ def __setitem__(self, glb_idx, val, comm_type):
368368
# `val` is decomposed, `self` is replicated -> gatherall-like
369369
raise NotImplementedError
370370
elif isinstance(val, np.ndarray):
371+
if val.shape == self.view().shape:
372+
super().__setitem__(tuple(slice(None) for _ in val.shape), val)
373+
return
371374
if self._is_decomposed:
372375
# `val` is replicated, `self` is decomposed -> `val` gets decomposed
373376
glb_idx = self._normalize_index(glb_idx)

devito/operator/operator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,8 @@ def apply(self, **kwargs):
991991
# Build the arguments list to invoke the kernel function
992992
with self._profiler.timer_on('arguments'):
993993
args = self.arguments(**kwargs)
994+
with switch_log_level(comm=args.comm):
995+
self._emit_args_profiling()
994996

995997
# Invoke kernel function with args
996998
arg_values = [args[p.name] for p in self.parameters]
@@ -1020,6 +1022,11 @@ def apply(self, **kwargs):
10201022

10211023
# Performance profiling
10221024

1025+
def _emit_args_profiling(self):
1026+
fround = lambda i, n=100: ceil(i * n) / n
1027+
elapsed = fround(self._profiler.py_timers['arguments'])
1028+
info(f"Operator `{self.name}` arguments processed in {elapsed:.2f} s")
1029+
10231030
def _emit_build_profiling(self):
10241031
if not is_log_enabled_for('PERF'):
10251032
return

devito/types/sparse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ def __shape_setup__(cls, **kwargs):
9898
shape = kwargs.get('shape', kwargs.get('shape_global'))
9999
dimensions = kwargs.get('dimensions')
100100
npoint = kwargs.get('npoint', kwargs.get('npoint_global'))
101-
glb_npoint = SparseDistributor.decompose(npoint, grid.distributor)
101+
distributor = kwargs.get('distributor', SparseDistributor)
102+
glb_npoint = distributor.decompose(npoint, grid.distributor)
102103
# Plain SparseFunction construction with npoint.
103104
if shape is None:
104105
loc_shape = (glb_npoint[grid.distributor.myrank],)
@@ -214,7 +215,6 @@ def __subfunc_setup__(self, suffix, keys, dtype=None, inkwargs=False, **kwargs):
214215
# Complex coordinates are not valid, so fall back to corresponding
215216
# real floating point type if dtype is complex.
216217
dtype = dtype(0).real.__class__
217-
218218
sf = SparseSubFunction(
219219
name=name, dtype=dtype, dimensions=dimensions,
220220
shape=shape, space_order=0, initializer=key, alias=self.alias,

0 commit comments

Comments
 (0)