Skip to content

Commit 5dd4f97

Browse files
committed
api: revamp staggered internal for correct rebuilding
1 parent c1c5806 commit 5dd4f97

File tree

5 files changed

+84
-46
lines changed

5 files changed

+84
-46
lines changed

devito/types/basic.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ def __new__(cls, *args, **kwargs):
699699
args, kwargs = cls.__args_setup__(*args, **kwargs)
700700

701701
# Extract the `indices`, as perhaps they're explicitly provided
702-
dimensions, indices = cls.__indices_setup__(*args, **kwargs)
702+
dimensions, indices, staggered = cls.__indices_setup__(*args, **kwargs)
703703

704704
# If it's an alias or simply has a different name, ignore `function`.
705705
# These cases imply the construction of a new AbstractFunction off
@@ -743,6 +743,7 @@ def __new__(cls, *args, **kwargs):
743743
# when executing __init_finalize__
744744
newobj._name = name
745745
newobj._dimensions = dimensions
746+
newobj._staggered = staggered
746747
newobj._shape = cls.__shape_setup__(**kwargs)
747748
newobj._dtype = cls.__dtype_setup__(**kwargs)
748749

@@ -925,6 +926,11 @@ def indices(self):
925926
"""The indices of the object."""
926927
return DimensionTuple(*self.args, getters=self.dimensions)
927928

929+
@property
930+
def staggered(self):
931+
"""The staggered indices of the object."""
932+
return DimensionTuple(*self._staggered, getters=self.dimensions)
933+
928934
@property
929935
def indices_ref(self):
930936
"""The reference indices of the object (indices at first creation)."""
@@ -1428,8 +1434,7 @@ def _new(cls, *args, **kwargs):
14281434
return sympy.ImmutableDenseMatrix(*args)
14291435
# Initialized with constructed object
14301436
newobj.__init_finalize__(newobj.rows, newobj.cols, newobj.flat(),
1431-
grid=grid, dimensions=dimensions,
1432-
name=kwargs['name'])
1437+
grid=grid, dimensions=dimensions)
14331438
else:
14341439
# Initialize components and create new Matrix from standard
14351440
# Devito inputs
@@ -1480,7 +1485,15 @@ def grid(self):
14801485

14811486
@property
14821487
def name(self):
1483-
return self._name
1488+
for c in self.values():
1489+
try:
1490+
return c.name.split('_')[0]
1491+
except AttributeError:
1492+
# `c` is not a devito object
1493+
pass
1494+
# If we end up here, then we have no devito objects
1495+
# in the matrix, so we ust return the class name
1496+
return self.__class__.__name__
14841497

14851498
def _rebuild(self, *args, **kwargs):
14861499
# We need to rebuild the components with the new name then
@@ -1489,7 +1502,7 @@ def _rebuild(self, *args, **kwargs):
14891502
comps = [f.func(*args, name=f.name.replace(self.name, newname), **kwargs)
14901503
for f in self.flat()]
14911504
# Rebuild the matrix with the new components
1492-
return self._new(comps, name=newname)
1505+
return self._new(comps)
14931506

14941507
func = _rebuild
14951508

devito/types/dense.py

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,6 @@ class DiscreteFunction(AbstractFunction, ArgProvider, Differentiable):
6666
__rkwargs__ = AbstractFunction.__rkwargs__ + ('staggered', 'coefficients')
6767

6868
def __init_finalize__(self, *args, function=None, **kwargs):
69-
# Staggering metadata
70-
self._staggered = self.__staggered_setup__(**kwargs)
71-
7269
# Now that *all* __X_setup__ hooks have been called, we can let the
7370
# superclass constructor do its job
7471
super().__init_finalize__(*args, **kwargs)
@@ -180,18 +177,6 @@ def __coefficients_setup__(self, **kwargs):
180177
" not %s" % (str(fd_weights_registry), coeffs))
181178
return coeffs
182179

183-
def __staggered_setup__(self, **kwargs):
184-
"""
185-
Setup staggering-related metadata. This method assigns:
186-
187-
* 0 to non-staggered dimensions;
188-
* 1 to staggered dimensions.
189-
"""
190-
staggered = kwargs.get('staggered', None)
191-
if staggered is CELL:
192-
staggered = self.dimensions
193-
return staggered
194-
195180
@cached_property
196181
def _functions(self):
197182
return {self.function}
@@ -208,10 +193,6 @@ def _mem_external(self):
208193
def _mem_heap(self):
209194
return True
210195

211-
@property
212-
def staggered(self):
213-
return self._staggered
214-
215196
@property
216197
def coefficients(self):
217198
"""Form of the coefficients of the function."""
@@ -1077,34 +1058,49 @@ def _eval_at(self, func):
10771058
return self.subs(mapper)
10781059
return self
10791060

1061+
@classmethod
1062+
def __staggered_setup__(cls, dimensions, **kwargs):
1063+
"""
1064+
Setup staggering-related metadata. This method assigns:
1065+
1066+
* 0 to non-staggered dimensions;
1067+
* 1 to staggered dimensions.
1068+
"""
1069+
stagg = kwargs.get('staggered', None)
1070+
if stagg is CELL:
1071+
staggered = (sympy.S.One for d in dimensions)
1072+
elif stagg in [None, NODE]:
1073+
staggered = (sympy.S.Zero for d in dimensions)
1074+
elif all(is_integer(s) for s in as_tuple(stagg)):
1075+
# Staggering is already a tuple likely from rebuild
1076+
assert len(stagg) == len(dimensions)
1077+
return tuple(stagg)
1078+
else:
1079+
staggered = (sympy.S.One if d in as_tuple(stagg) else sympy.S.Zero
1080+
for d in dimensions)
1081+
return tuple(staggered)
1082+
10801083
@classmethod
10811084
def __indices_setup__(cls, *args, **kwargs):
10821085
grid = kwargs.get('grid')
10831086
dimensions = kwargs.get('dimensions')
1087+
staggered = kwargs.get('staggered')
1088+
10841089
if grid is None:
10851090
if dimensions is None:
10861091
raise TypeError("Need either `grid` or `dimensions`")
10871092
elif dimensions is None:
10881093
dimensions = grid.dimensions
10891094

1095+
staggered = cls.__staggered_setup__(dimensions, staggered=staggered)
10901096
if args:
10911097
assert len(args) == len(dimensions)
1092-
return tuple(dimensions), tuple(args)
1093-
1094-
# Staggered indices
1095-
staggered = kwargs.get("staggered", None)
1096-
if staggered in [None, NODE]:
1097-
staggered_indices = dimensions
1098-
elif staggered == CELL:
1099-
staggered_indices = [d + d.spacing / 2 for d in dimensions]
1098+
staggered_indices = tuple(args)
11001099
else:
1101-
mapper = {d: d for d in dimensions}
1102-
for s in as_tuple(staggered):
1103-
c, s = s.as_coeff_Mul()
1104-
mapper.update({s: s + c * s.spacing / 2})
1105-
staggered_indices = mapper.values()
1106-
1107-
return tuple(dimensions), tuple(staggered_indices)
1100+
# Staggered indices
1101+
staggered_indices = (d + i * d.spacing / 2
1102+
for d, i in zip(dimensions, staggered))
1103+
return tuple(dimensions), tuple(staggered_indices), staggered
11081104

11091105
@property
11101106
def is_Staggered(self):
@@ -1604,7 +1600,7 @@ def __indices_setup__(cls, **kwargs):
16041600
# Sanity check
16051601
assert not any(d.is_NonlinearDerived for d in dimensions)
16061602

1607-
return dimensions, dimensions
1603+
return dimensions, dimensions, (sympy.S.Zero for _ in dimensions)
16081604

16091605
def __halo_setup__(self, **kwargs):
16101606
pointer_dim = kwargs.get('pointer_dim')

devito/types/tensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ def __init_finalize__(self, *args, **kwargs):
7373
super().__init_finalize__(*args, **kwargs)
7474
grid = kwargs.get('grid')
7575
dimensions = kwargs.get('dimensions')
76-
inds, _ = Function.__indices_setup__(grid=grid,
77-
dimensions=dimensions)
76+
inds, _, _ = Function.__indices_setup__(grid=grid,
77+
dimensions=dimensions)
7878
self._space_dimensions = inds
7979

8080
@classmethod

tests/test_staggered_utils.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
import numpy as np
33
from sympy import simplify
44

5-
from devito import (Function, Grid, NODE, VectorTimeFunction,
6-
TimeFunction, Eq, Operator, div)
7-
from devito.tools import powerset
5+
from devito import (Function, Grid, NODE, CELL, VectorTimeFunction,
6+
TimeFunction, Eq, Operator, div, Dimension)
7+
from devito.tools import powerset, as_tuple
88

99

1010
@pytest.mark.parametrize('ndim', [1, 2, 3])
@@ -160,3 +160,30 @@ def test_staggered_div():
160160
op2.apply(time_M=0)
161161

162162
assert np.allclose(p1.data[:], p2.data[:], atol=0, rtol=1e-5)
163+
164+
165+
@pytest.mark.parametrize('stagg', [
166+
'NODE', 'CELL', 'x', 'y', 'z',
167+
'(x, y)', '(x, z)', '(y, z)', '(x, y, z)'])
168+
def test_staggered_rebuild(stagg):
169+
grid = Grid(shape=(5, 5, 5))
170+
x, y, z = grid.dimensions # noqa
171+
stagg = eval(stagg)
172+
173+
f = Function(name='f', grid=grid, space_order=4, staggered=stagg)
174+
assert tuple(f.staggered.getters.keys()) == grid.dimensions
175+
176+
new_dims = (Dimension('x1'), Dimension('y1'), Dimension('z1'))
177+
f2 = f.func(dimensions=new_dims)
178+
179+
assert f2.dimensions == new_dims
180+
assert tuple(f2.staggered) == tuple(f.staggered)
181+
assert tuple(f2.staggered.getters.keys()) == new_dims
182+
183+
# Check that rebuild correctly set the staggered indices
184+
# with the new dimensions
185+
for (d, nd) in zip(grid.dimensions, new_dims):
186+
if d in as_tuple(stagg) or stagg is CELL:
187+
assert f2.indices[nd] == nd + nd.spacing / 2
188+
else:
189+
assert f2.indices[nd] == nd

tests/test_tensors.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,11 +455,13 @@ def test_rebuild(func1):
455455
assert j.dimensions == i.dimensions
456456

457457
new_dims = [Dimension(name=f'{i.name}1') for i in grid.dimensions]
458+
if f1.is_TimeDependent:
459+
new_dims = [f1[0].time_dim] + new_dims
458460
f3 = f1.func(dimensions=new_dims)
459461
assert f3.grid == grid
460462
assert f3.name == f1.name
461463

462464
for (i, j) in zip(f1.flat(), f3.flat()):
463465
assert j.name == i.name
464466
assert j.grid == i.grid
465-
assert j.dimensions == new_dims
467+
assert j.dimensions == tuple(new_dims)

0 commit comments

Comments
 (0)