Skip to content

Commit 5f680d4

Browse files
committed
compiler: Introduce Terminal mixin for SymPy subclasses
1 parent 582784e commit 5f680d4

File tree

2 files changed

+22
-16
lines changed

2 files changed

+22
-16
lines changed

devito/symbolics/extended_sympy.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from devito.types.basic import Basic
1919

2020
__all__ = ['CondEq', 'CondNe', 'BitwiseNot', 'BitwiseXor', 'IntDiv', # noqa
21-
'CallFromPointer', 'CallFromComposite', 'FieldFromPointer',
21+
'Terminal', 'CallFromPointer', 'CallFromComposite', 'FieldFromPointer',
2222
'FieldFromComposite', 'ListInitializer', 'Byref', 'IndexedPointer',
2323
'Cast', 'DefFunction', 'MathFunction', 'InlineIf', 'Reserved',
2424
'ReservedWord', 'Keyword', 'String', 'Macro', 'Class', 'MacroArgument',
@@ -129,6 +129,17 @@ def __mul__(self, other):
129129
return super().__mul__(other)
130130

131131

132+
class Terminal:
133+
134+
"""
135+
Abstract base class for all terminal objects, that is, those objects
136+
collected by `retrieve_terminals` in addition to all other SymPy atoms
137+
such as `Symbol`, `Number`, etc.
138+
"""
139+
140+
pass
141+
142+
132143
class BasicWrapperMixin:
133144

134145
"""
@@ -170,7 +181,7 @@ def _sympystr(self, printer):
170181
return str(self)
171182

172183

173-
class CallFromPointer(sympy.Expr, Pickable, BasicWrapperMixin):
184+
class CallFromPointer(Expr, Pickable, BasicWrapperMixin, Terminal):
174185

175186
"""
176187
Symbolic representation of the C notation ``pointer->call(params)``.
@@ -238,7 +249,7 @@ def free_symbols(self):
238249
__reduce_ex__ = Pickable.__reduce_ex__
239250

240251

241-
class CallFromComposite(CallFromPointer, Pickable):
252+
class CallFromComposite(CallFromPointer):
242253

243254
"""
244255
Symbolic representation of the C notation ``composite.call(params)``.
@@ -251,7 +262,7 @@ def __str__(self):
251262
__repr__ = __str__
252263

253264

254-
class FieldFromPointer(CallFromPointer, Pickable):
265+
class FieldFromPointer(CallFromPointer):
255266

256267
"""
257268
Symbolic representation of the C notation ``pointer->field``.
@@ -272,7 +283,7 @@ def field(self):
272283
__repr__ = __str__
273284

274285

275-
class FieldFromComposite(CallFromPointer, Pickable):
286+
class FieldFromComposite(CallFromPointer):
276287

277288
"""
278289
Symbolic representation of the C notation ``composite.field``,
@@ -334,7 +345,7 @@ def is_numeric(self):
334345
__reduce_ex__ = Pickable.__reduce_ex__
335346

336347

337-
class UnaryOp(sympy.Expr, Pickable, BasicWrapperMixin):
348+
class UnaryOp(Expr, Pickable, BasicWrapperMixin):
338349

339350
"""
340351
Symbolic representation of a unary C operator.
@@ -472,7 +483,7 @@ def __str__(self):
472483
return f"{self._op}{self.base}"
473484

474485

475-
class IndexedPointer(sympy.Expr, Pickable, BasicWrapperMixin):
486+
class IndexedPointer(Expr, Pickable, BasicWrapperMixin, Terminal):
476487

477488
"""
478489
Symbolic representation of the C notation ``symbol[...]``

devito/symbolics/queries.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from sympy import Eq, IndexedBase, Mod, S, diff, nan
22

3-
from devito.symbolics.extended_sympy import (FieldFromComposite, FieldFromPointer,
4-
IndexedPointer, IntDiv)
3+
from devito.symbolics.extended_sympy import IntDiv, Terminal
54
from devito.tools import as_tuple, is_integer
65
from devito.types.basic import AbstractFunction
76
from devito.types.constant import Constant
@@ -16,13 +15,9 @@
1615
'q_dimension', 'q_positive', 'q_negative']
1716

1817

19-
# The following SymPy objects are considered tree leaves:
20-
#
21-
# * Number
22-
# * Symbol
23-
# * Indexed
24-
extra_leaves = (FieldFromPointer, FieldFromComposite, IndexedBase, AbstractObject,
25-
IndexedPointer)
18+
# The following SymPy objects are considered tree leaves in addition to the classic
19+
# SymPy atoms such as Number, Symbol, Indexed, etc
20+
extra_leaves = (IndexedBase, AbstractObject, Terminal)
2621

2722

2823
def q_symbol(expr):

0 commit comments

Comments
 (0)