Skip to content

Commit fcea800

Browse files
committed
api: fix xreplace for enrichedtuple
1 parent 41db4f4 commit fcea800

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

devito/symbolics/manipulation.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
from devito.symbolics.unevaluation import Mul as UnevalMul
1717
from devito.symbolics.unevaluation import Pow as UnevalPow
1818
from devito.symbolics.unevaluation import UnevaluableMixin
19-
from devito.tools import as_list, as_tuple, flatten, split, transitive_closure
19+
from devito.tools import (
20+
EnrichedTuple, as_list, as_tuple, flatten, split, transitive_closure
21+
)
2022
from devito.types.array import ComponentAccess
2123
from devito.types.basic import Basic, Indexed
2224
from devito.types.equation import Eq
@@ -130,6 +132,12 @@ def _(iterable, rule):
130132
return iterable.__class__(ret), changed
131133

132134

135+
@_uxreplace_dispatch.register(EnrichedTuple)
136+
def _(iterable, rule):
137+
retval, changed = _uxreplace_dispatch(tuple(iterable), rule)
138+
return iterable.__class__(*retval, getters=iterable.getters), changed
139+
140+
133141
@_uxreplace_dispatch.register(dict)
134142
def _(mapper, rule):
135143
ret = {}

devito/types/basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -762,7 +762,7 @@ def __new__(cls, *args, **kwargs):
762762
# Initialization. The following attributes must be available
763763
# when executing __init_finalize__
764764
newobj._name = name
765-
newobj._dimensions = DimensionTuple(*dimensions, getters=dimensions)
765+
newobj._dimensions = dimensions
766766
newobj._shape = cls.__shape_setup__(**kwargs)
767767
newobj._dtype = cls.__dtype_setup__(**kwargs)
768768

@@ -971,7 +971,7 @@ def origin(self):
971971
@property
972972
def dimensions(self):
973973
"""Tuple of Dimensions representing the object indices."""
974-
return self._dimensions
974+
return DimensionTuple(*self._dimensions, getters=self._dimensions)
975975

976976
@cached_property
977977
def space_dimensions(self):

0 commit comments

Comments
 (0)