Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion devito/symbolics/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from devito.symbolics.unevaluation import Mul as UnevalMul
from devito.symbolics.unevaluation import Pow as UnevalPow
from devito.symbolics.unevaluation import UnevaluableMixin
from devito.tools import as_list, as_tuple, flatten, split, transitive_closure
from devito.tools import (
EnrichedTuple, as_list, as_tuple, flatten, split, transitive_closure
)
from devito.types.array import ComponentAccess
from devito.types.basic import Basic, Indexed
from devito.types.equation import Eq
Expand Down Expand Up @@ -130,6 +132,12 @@ def _(iterable, rule):
return iterable.__class__(ret), changed


@_uxreplace_dispatch.register(EnrichedTuple)
def _(iterable, rule):
retval, changed = _uxreplace_dispatch(tuple(iterable), rule)
return iterable.__class__(*retval, getters=iterable.getters), changed


@_uxreplace_dispatch.register(dict)
def _(mapper, rule):
ret = {}
Expand Down
2 changes: 1 addition & 1 deletion devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,7 @@ def origin(self):
@property
def dimensions(self):
"""Tuple of Dimensions representing the object indices."""
return self._dimensions
return DimensionTuple(*self._dimensions, getters=self._dimensions)

@cached_property
def space_dimensions(self):
Expand Down
3 changes: 2 additions & 1 deletion devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,7 +1128,8 @@ def _eval_at(self, func):
for d in self.dimensions:
try:
if self.indices_ref[d] is not func.indices_ref[d]:
mapper[self.indices_ref[d]] = func.indices_ref[d]
f_idx = func.indices_ref[d]._subs(func.dimensions[d], d)
mapper[self.indices_ref[d]] = f_idx
except KeyError:
pass

Expand Down
13 changes: 13 additions & 0 deletions tests/test_staggered_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,16 @@ def test_staggered_rebuild(stagg):
assert f2.indices[nd] == nd + nd.spacing / 2
else:
assert f2.indices[nd] == nd


def test_eval_at_different_dim():
grid = Grid(shape=(31, 17, 25))
nt = 5
x, _, _ = grid.dimensions

v = TimeFunction(name="v", grid=grid, staggered=x)
tau = TimeFunction(name="tau", grid=grid, save=nt)

eq = Eq(tau.forward, v).evaluate

assert grid.time_dim not in eq.rhs.free_symbols
Loading