Skip to content

Commit 2d32230

Browse files
committed
api: prevent evaluated derivatives to be re-evaluted
1 parent 24c4cfb commit 2d32230

File tree

9 files changed

+63
-40
lines changed

9 files changed

+63
-40
lines changed

devito/finite_differences/derivative.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -335,20 +335,6 @@ def __call__(self, x0=None, fd_order=None, side=None, method=None, **kwargs):
335335
except AttributeError:
336336
raise TypeError("fd_order incompatible with dimensions") from None
337337

338-
if isinstance(self.expr, Derivative):
339-
# In case this was called on a perfect cross-derivative `u.dxdy`
340-
# we need to propagate the call to the nested derivative
341-
rkwe = dict(rkw)
342-
rkwe.pop('weights', None)
343-
if 'x0' in rkwe:
344-
rkwe['x0'] = self._filter_dims(self.expr._filter_dims(rkw['x0']),
345-
neg=True)
346-
if fd_order is not None:
347-
fdo = self.expr._filter_dims(_fd_order)
348-
if fdo:
349-
rkwe['fd_order'] = fdo
350-
rkw['expr'] = self.expr(**rkwe)
351-
352338
if fd_order is not None:
353339
rkw['fd_order'] = self._filter_dims(_fd_order, as_tuple=True)
354340

devito/finite_differences/differentiable.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -950,7 +950,11 @@ def _evaluate(self, **kwargs):
950950

951951

952952
class DiffDerivative(IndexDerivative, DifferentiableOp):
953-
pass
953+
954+
def _eval_at(self, func):
955+
# Like EvalDerivative, a DiffDerivative must have already been evaluated
956+
# at a valid x0 and should not be re-evaluated at a different location
957+
return self
954958

955959

956960
# SymPy args ordering is the same for Derivatives and IndexDerivatives
@@ -998,6 +1002,11 @@ def _new_rawargs(self, *args, **kwargs):
9981002
kwargs.pop('is_commutative', None)
9991003
return self.func(*args, **kwargs)
10001004

1005+
def _eval_at(self, func):
1006+
# An EvalDerivative must have already been evaluated at a valid x0
1007+
# and should not be re-evaluated at a different location
1008+
return self
1009+
10011010

10021011
class diffify:
10031012

devito/types/basic.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -989,25 +989,48 @@ def c0(self):
989989
def _eval_deriv(self):
990990
return self
991991

992-
@property
992+
@cached_property
993993
def _grid_map(self):
994994
"""
995995
Mapper of off-grid interpolation points indices for each dimension.
996996
"""
997997
mapper = {}
998+
subs = {}
998999
for i, j, d in zip(self.indices, self.indices_ref, self.dimensions):
9991000
# Two indices are aligned if they differ by an Integer*spacing.
1000-
v = (i - j)/d.spacing
1001+
if not i.has(d):
1002+
# Maybe a SubDimension
1003+
dims = {sd for sd in i.free_symbols if getattr(sd, 'is_Dimension', False)
1004+
and d in sd._defines}
1005+
1006+
# More than one Dimension, cannot handle
1007+
if len(dims) != 1:
1008+
continue
1009+
1010+
# SubDimensions -> Dimension substitutions for interpolation
1011+
sd = dims.pop()
1012+
v = (i - j._subs(d, sd))/d.spacing
1013+
i = i._subs(sd, d)
1014+
subs[d] = sd
1015+
else:
1016+
v = (i - j)/d.spacing
1017+
10011018
try:
10021019
if not isinstance(v, sympy.Number) or int(v) == v:
1020+
# Skip if index is on grid
10031021
continue
1004-
# Skip if index is just a Symbol or integer
10051022
elif (i.is_Symbol and not i.has(d)) or i.is_Integer:
1023+
# Skip if index is just a Symbol or integer
10061024
continue
10071025
else:
10081026
mapper.update({d: i})
10091027
except (AttributeError, TypeError):
10101028
mapper.update({d: i})
1029+
1030+
# Substitutions for SubDimensions
1031+
if mapper:
1032+
mapper['subs'] = subs
1033+
10111034
return mapper
10121035

10131036
def _evaluate(self, **kwargs):
@@ -1019,29 +1042,32 @@ def _evaluate(self, **kwargs):
10191042
This allow to evaluate off grid points as EvalDerivative that are better
10201043
for the compiler.
10211044
"""
1045+
mapper = self._grid_map
1046+
subs = mapper.pop('subs', {})
10221047
# Average values if at a location not on the Function's grid
1023-
if not self._grid_map:
1048+
if not mapper:
10241049
return self
10251050

10261051
io = self.interp_order
1027-
# Base function
10281052
if self._avg_mode == 'harmonic':
1029-
retval = 1 / self.function
1053+
retval = 1 / self
10301054
else:
1031-
retval = self.function
1055+
retval = self
10321056

10331057
# Apply interpolation from inner most dim
1034-
for d, i in self._grid_map.items():
1058+
for d, i in mapper.items():
1059+
retval = retval._subs(i.subs(subs), self.indices_ref[d])
10351060
retval = retval.diff(d, deriv_order=0, fd_order=io, x0={d: i})
10361061

10371062
# Evaluate. Since we used `self.function` it will be on the grid when
10381063
# evaluate is called again within FD
10391064
retval = retval._evaluate(**kwargs)
1065+
retval = retval.subs(subs)
10401066

10411067
# If harmonic averaging, invert at the end
10421068
if self._avg_mode == 'harmonic':
10431069
from devito.finite_differences.differentiable import SafeInv
1044-
retval = SafeInv(retval, self.function)
1070+
retval = SafeInv(retval, self.function.subs(subs))
10451071

10461072
return retval
10471073

devito/types/dense.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,12 +1073,16 @@ def _fd_priority(self):
10731073
def _eval_at(self, func):
10741074
if self.staggered == func.staggered:
10751075
return self
1076-
mapper = {self.indices_ref[d]: func.indices_ref[d]
1077-
for d in self.dimensions
1078-
if self.indices_ref[d] is not func.indices_ref[d]}
1079-
if mapper:
1080-
return self.subs(mapper)
1081-
return self
1076+
1077+
mapper = {}
1078+
for d in self.dimensions:
1079+
try:
1080+
if self.indices_ref[d] is not func.indices_ref[d]:
1081+
mapper[self.indices_ref[d]] = func.indices_ref[d]
1082+
except KeyError:
1083+
pass
1084+
1085+
return self.subs(mapper)
10821086

10831087
@classmethod
10841088
def __staggered_setup__(cls, dimensions, staggered=None, **kwargs):

examples/userapi/07_functions_on_subdomains.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3003,7 +3003,7 @@
30033003
"metadata": {},
30043004
"outputs": [],
30053005
"source": [
3006-
"assert np.isclose(np.linalg.norm(rec.data), 4263.511, atol=0, rtol=1e-4)"
3006+
"assert np.isclose(np.linalg.norm(rec.data), 3640.584, atol=0, rtol=1e-4)"
30073007
]
30083008
}
30093009
],

tests/test_derivatives.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -805,13 +805,11 @@ def test_param_stagg_add(self):
805805
eq1 = Eq(vx, (c11 * txx).dy)
806806
eq2 = Eq(vx, (c11 * txx + c66 * txy).dy)
807807

808-
# C66 is a paramater. Expects to evaluate c66 at xp then the derivative at yp
809-
# and the derivative will interpolate txy at xp
808+
# Expects to evaluate c66 at xp then the derivative at yp
810809
expect0 = (c66.subs({x: xp, y: yp}).evaluate * txy).dy.evaluate
811810
assert simplify(eq0.evaluate.rhs - expect0) == 0
812811

813-
# C11 is a paramater and txy is staggered in x.
814-
# Expects to evaluate c11 and txy xp then the derivative at yp
812+
# Expects to evaluate c11 and txy at xp then the derivative at yp
815813
expect1 = (c11._subs(x, xp).evaluate * txx._subs(x, xp).evaluate).dy.evaluate
816814
assert simplify(eq1.evaluate.rhs - expect1) == 0
817815

tests/test_differentiable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_shift():
5757
assert a.shift(x, x.spacing).shift(x, -x.spacing) == a
5858
assert a.shift(x, x.spacing).shift(x, x.spacing) == a.shift(x, 2*x.spacing)
5959
assert a.dx.evaluate.shift(x, x.spacing) == a.shift(x, x.spacing).dx.evaluate
60-
assert a.shift(x, .5 * x.spacing)._grid_map == {x: x + .5 * x.spacing}
60+
assert a.shift(x, .5 * x.spacing)._grid_map == {x: x + .5 * x.spacing, 'subs': {}}
6161

6262

6363
def test_interp():

tests/test_mpi.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2895,9 +2895,9 @@ def test_staggering(self, mode):
28952895

28962896
op(time_M=2)
28972897
# Expected norms computed "manually" from sequential runs
2898-
assert np.isclose(norm(ux), 6054.139, rtol=1.e-4)
2899-
assert np.isclose(norm(uxx), 17814.95, rtol=1.e-4)
2900-
assert np.isclose(norm(uxy), 58712.22, rtol=1.e-4)
2898+
assert np.isclose(norm(ux), 7003.098, rtol=1.e-4)
2899+
assert np.isclose(norm(uxx), 78902.21, rtol=1.e-4)
2900+
assert np.isclose(norm(uxy), 71852.62, rtol=1.e-4)
29012901

29022902
@pytest.mark.parallel(mode=2)
29032903
def test_op_new_dist(self, mode):

tests/test_symbolics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,7 @@ def test_is_on_grid():
743743
u = Function(name="u", grid=grid, space_order=2)
744744

745745
assert u._grid_map == {}
746-
assert u.subs({x: x0})._grid_map == {x: x0}
746+
assert u.subs({x: x0})._grid_map == {x: x0, 'subs': {}}
747747
assert all(uu._grid_map == {} for uu in retrieve_functions(u.subs({x: x0}).evaluate))
748748

749749

0 commit comments

Comments
 (0)