From 41db4f4fa78ff9ee5d02d08d97947cba3105802b Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 21 Jan 2026 08:05:14 -0500 Subject: [PATCH 1/2] api: fix evaluation with different time dims --- devito/types/basic.py | 2 +- devito/types/dense.py | 3 ++- tests/test_staggered_utils.py | 13 +++++++++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/devito/types/basic.py b/devito/types/basic.py index 75aed9d32f..2eca46d3d1 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -762,7 +762,7 @@ def __new__(cls, *args, **kwargs): # Initialization. The following attributes must be available # when executing __init_finalize__ newobj._name = name - newobj._dimensions = dimensions + newobj._dimensions = DimensionTuple(*dimensions, getters=dimensions) newobj._shape = cls.__shape_setup__(**kwargs) newobj._dtype = cls.__dtype_setup__(**kwargs) diff --git a/devito/types/dense.py b/devito/types/dense.py index 551dcad14f..ccee21bffb 100644 --- a/devito/types/dense.py +++ b/devito/types/dense.py @@ -1128,7 +1128,8 @@ def _eval_at(self, func): for d in self.dimensions: try: if self.indices_ref[d] is not func.indices_ref[d]: - mapper[self.indices_ref[d]] = func.indices_ref[d] + f_idx = func.indices_ref[d]._subs(func.dimensions[d], d) + mapper[self.indices_ref[d]] = f_idx except KeyError: pass diff --git a/tests/test_staggered_utils.py b/tests/test_staggered_utils.py index daec7f2108..9da58d34a4 100644 --- a/tests/test_staggered_utils.py +++ b/tests/test_staggered_utils.py @@ -187,3 +187,16 @@ def test_staggered_rebuild(stagg): assert f2.indices[nd] == nd + nd.spacing / 2 else: assert f2.indices[nd] == nd + + +def test_eval_at_different_dim(): + grid = Grid(shape=(31, 17, 25)) + nt = 5 + x, _, _ = grid.dimensions + + v = TimeFunction(name="v", grid=grid, staggered=x) + tau = TimeFunction(name="tau", grid=grid, save=nt) + + eq = Eq(tau.forward, v).evaluate + + assert grid.time_dim not in eq.rhs.free_symbols From fcea80066352124fe8a942baebdba9135040dc8d Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 21 Jan 2026 09:08:11 -0500 Subject: [PATCH 2/2] api: fix xreplace for enrichedtuple --- devito/symbolics/manipulation.py | 10 +++++++++- devito/types/basic.py | 4 ++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index d90c366bb2..57d9314e16 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -16,7 +16,9 @@ from devito.symbolics.unevaluation import Mul as UnevalMul from devito.symbolics.unevaluation import Pow as UnevalPow from devito.symbolics.unevaluation import UnevaluableMixin -from devito.tools import as_list, as_tuple, flatten, split, transitive_closure +from devito.tools import ( + EnrichedTuple, as_list, as_tuple, flatten, split, transitive_closure +) from devito.types.array import ComponentAccess from devito.types.basic import Basic, Indexed from devito.types.equation import Eq @@ -130,6 +132,12 @@ def _(iterable, rule): return iterable.__class__(ret), changed +@_uxreplace_dispatch.register(EnrichedTuple) +def _(iterable, rule): + retval, changed = _uxreplace_dispatch(tuple(iterable), rule) + return iterable.__class__(*retval, getters=iterable.getters), changed + + @_uxreplace_dispatch.register(dict) def _(mapper, rule): ret = {} diff --git a/devito/types/basic.py b/devito/types/basic.py index 2eca46d3d1..a3b8af4533 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -762,7 +762,7 @@ def __new__(cls, *args, **kwargs): # Initialization. The following attributes must be available # when executing __init_finalize__ newobj._name = name - newobj._dimensions = DimensionTuple(*dimensions, getters=dimensions) + newobj._dimensions = dimensions newobj._shape = cls.__shape_setup__(**kwargs) newobj._dtype = cls.__dtype_setup__(**kwargs) @@ -971,7 +971,7 @@ def origin(self): @property def dimensions(self): """Tuple of Dimensions representing the object indices.""" - return self._dimensions + return DimensionTuple(*self._dimensions, getters=self._dimensions) @cached_property def space_dimensions(self):