diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index d90c366bb2..57d9314e16 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -16,7 +16,9 @@ from devito.symbolics.unevaluation import Mul as UnevalMul from devito.symbolics.unevaluation import Pow as UnevalPow from devito.symbolics.unevaluation import UnevaluableMixin -from devito.tools import as_list, as_tuple, flatten, split, transitive_closure +from devito.tools import ( + EnrichedTuple, as_list, as_tuple, flatten, split, transitive_closure +) from devito.types.array import ComponentAccess from devito.types.basic import Basic, Indexed from devito.types.equation import Eq @@ -130,6 +132,12 @@ def _(iterable, rule): return iterable.__class__(ret), changed +@_uxreplace_dispatch.register(EnrichedTuple) +def _(iterable, rule): + retval, changed = _uxreplace_dispatch(tuple(iterable), rule) + return iterable.__class__(*retval, getters=iterable.getters), changed + + @_uxreplace_dispatch.register(dict) def _(mapper, rule): ret = {} diff --git a/devito/types/basic.py b/devito/types/basic.py index 75aed9d32f..a3b8af4533 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -971,7 +971,7 @@ def origin(self): @property def dimensions(self): """Tuple of Dimensions representing the object indices.""" - return self._dimensions + return DimensionTuple(*self._dimensions, getters=self._dimensions) @cached_property def space_dimensions(self): diff --git a/devito/types/dense.py b/devito/types/dense.py index 551dcad14f..ccee21bffb 100644 --- a/devito/types/dense.py +++ b/devito/types/dense.py @@ -1128,7 +1128,8 @@ def _eval_at(self, func): for d in self.dimensions: try: if self.indices_ref[d] is not func.indices_ref[d]: - mapper[self.indices_ref[d]] = func.indices_ref[d] + f_idx = func.indices_ref[d]._subs(func.dimensions[d], d) + mapper[self.indices_ref[d]] = f_idx except KeyError: pass diff --git a/tests/test_staggered_utils.py b/tests/test_staggered_utils.py index daec7f2108..9da58d34a4 100644 --- a/tests/test_staggered_utils.py +++ b/tests/test_staggered_utils.py @@ -187,3 +187,16 @@ def test_staggered_rebuild(stagg): assert f2.indices[nd] == nd + nd.spacing / 2 else: assert f2.indices[nd] == nd + + +def test_eval_at_different_dim(): + grid = Grid(shape=(31, 17, 25)) + nt = 5 + x, _, _ = grid.dimensions + + v = TimeFunction(name="v", grid=grid, staggered=x) + tau = TimeFunction(name="tau", grid=grid, save=nt) + + eq = Eq(tau.forward, v).evaluate + + assert grid.time_dim not in eq.rhs.free_symbols