Skip to content

Commit b89c74b

Browse files
committed
api: fix staggered evaluation for add/mul
1 parent 183bbef commit b89c74b

File tree

4 files changed

+50
-2
lines changed

4 files changed

+50
-2
lines changed

devito/finite_differences/derivative.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,9 +530,12 @@ def _eval_at(self, func):
530530
# it into `u(x + h_x/2).dx` and `v(x).dx`, since they require
531531
# different FD indices
532532
mapper = as_mapper(self.expr._args_diff, lambda i: i.staggered)
533+
if len(mapper) == 1:
534+
# All terms have the same staggering, we can use expr as is
535+
return self._rebuild(self.expr, **rkw)
533536
args = [self.expr.func(*v) for v in mapper.values()]
534537
args.extend([a for a in self.expr.args if a not in self.expr._args_diff])
535-
args = [self._rebuild(a, **rkw) for a in args]
538+
args = [self._rebuild(a)._eval_at(func) for a in args]
536539
return self.expr.func(*args)
537540
elif self.expr.is_Mul:
538541
# For Mul, We treat the basic case `u(x + h_x/2) * v(x) which is what appear

devito/finite_differences/differentiable.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,6 +1148,13 @@ def _(expr, x0, **kwargs):
11481148
from devito.finite_differences.derivative import Derivative
11491149
x0_expr = {d: v for d, v in x0.items() if v is not expr.indices_ref[d]}
11501150
if expr.is_parameter:
1151+
# Parameter might not have been evaluated at x0 yet
1152+
# E.g., in expressions such as u[x]*f[x] evaluated at x+dx/2
1153+
# `f` will not be evaluated at it since gather_for_diff will pick `u` as
1154+
# higher priority function.
1155+
for d, v in x0.items():
1156+
if expr.indices[d] is d:
1157+
expr = expr._subs(d, v)
11511158
return expr
11521159
elif x0_expr:
11531160
dims = tuple((d, 0) for d in x0_expr)

tests/test_derivatives.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,44 @@ def test_param_stagg_inner(self):
780780
eqne = eqn.evaluate.rhs
781781
assert simplify(eqne - (p._subs(y, yp).evaluate * f).dx(x0=xp).evaluate) == 0
782782

783+
def test_param_stagg_add(self):
784+
space_order = 2
785+
nx, ny = 5, 5
786+
extent = (nx-1, ny-1)
787+
788+
grid = Grid(shape=(nx, ny), extent=extent)
789+
x, y = grid.dimensions
790+
yp = y + y.spacing / 2
791+
xp = x + x.spacing / 2
792+
793+
x, y = grid.dimensions
794+
795+
vx = TimeFunction(name="vx", grid=grid, space_order=space_order,
796+
time_order1=1, staggered=x)
797+
txx = TimeFunction(name="txx", grid=grid, space_order=space_order,
798+
time_order=1, staggered=NODE)
799+
txy = TimeFunction(name="txy", grid=grid, space_order=space_order,
800+
time_order=1, staggered=(x, y))
801+
c11 = Function(name="c11", grid=grid, space_order=space_order, parameter=True)
802+
c66 = Function(name="c66", grid=grid, space_order=space_order, parameter=True)
803+
804+
eq0 = Eq(vx, (c66 * txy).dy)
805+
eq1 = Eq(vx, (c11 * txx).dy)
806+
eq2 = Eq(vx, (c11 * txx + c66 * txy).dy)
807+
808+
# C66 is a paramater. Expects to evaluate c66 at xp then the derivative at yp
809+
# and the derivative will interpolate txy at xp
810+
expect0 = (c66.subs({x: xp, y: yp}).evaluate * txy).dy.evaluate
811+
assert simplify(eq0.evaluate.rhs - expect0) == 0
812+
813+
# C11 is a paramater and txy is staggered in x.
814+
# Expects to evaluate c11 and txy xp then the derivative at yp
815+
expect1 = (c11._subs(x, xp).evaluate * txx._subs(x, xp).evaluate).dy.evaluate
816+
assert simplify(eq1.evaluate.rhs - expect1) == 0
817+
818+
# Addition should apply the same logic as above for each term
819+
assert simplify(eq2.evaluate.rhs - (expect1 + expect0)) == 0
820+
783821

784822
class TestTwoStageEvaluation:
785823

tests/test_staggered_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def test_gather_for_diff(expr, expected):
9999

100100
@pytest.mark.parametrize('expr, expected', [
101101
('((a + b).dx._eval_at(a)).is_Add', 'True'),
102-
('(a + b).dx._eval_at(a)', 'a.dx(x0=a.indices_ref.getters) + b.dx._eval_at(a)'),
102+
('(a + b).dx._eval_at(a)', 'a.dx + b.dx._eval_at(a)'),
103103
('(a*b).dx._eval_at(a).expr', 'a.subs({x: x0}) * b'),
104104
('(a * b.dx).dx._eval_at(b).expr._eval_deriv ',
105105
'a.subs({x: x0}) * b.dx.evaluate')])

0 commit comments

Comments
 (0)