Skip to content

Commit c725b2f

Browse files
EdCauntmloubout
authored andcommitted
misc: Refactoring and misc code style improvments
1 parent ad69ef7 commit c725b2f

File tree

6 files changed

+12
-21
lines changed

6 files changed

+12
-21
lines changed

devito/ir/support/basic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,8 @@ def lex_le(self, other):
317317
def lex_lt(self, other):
318318
return self.timestamp < other.timestamp
319319

320-
# Note: memoization yields mild compiler speedup
320+
# Note: memoization yields mild compiler speedup. Will need to be made
321+
# thread-safe for multithreading the compiler.
321322
@memoized_meth
322323
def distance(self, other):
323324
"""
@@ -369,7 +370,7 @@ def distance(self, other):
369370
# To avoid evaluating expensive symbolic Lt or Gt operations,
370371
# we pre-empt such operations by checking if the values to be compared
371372
# to are symbolic, and skip this case if not.
372-
if not any(isinstance(i, sympy.core.Basic)
373+
if not any(isinstance(i, sympy.Basic)
373374
for i in (sit.symbolic_min, sit.symbolic_max)):
374375

375376
for v in (self[n], other[n]):

devito/passes/clusters/cse.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from devito.finite_differences.differentiable import IndexDerivative
1414
from devito.ir import Cluster, Scope, cluster_pass
1515
from devito.symbolics import estimate_cost, q_leaf, q_terminal
16-
from devito.symbolics.search import retrieve_ctemps
16+
from devito.symbolics.search import search
1717
from devito.symbolics.manipulation import _uxreplace
1818
from devito.tools import DAG, as_list, as_tuple, frozendict, extract_dtype
1919
from devito.types import Eq, Symbol, Temp
@@ -26,11 +26,15 @@ class CTemp(Temp):
2626
"""
2727
A cluster-level Temp, similar to Temp, ensured to have different priority
2828
"""
29-
is_CTemp = True
3029

3130
ordering_of_classes.insert(ordering_of_classes.index('Temp') + 1, 'CTemp')
3231

3332

33+
def retrieve_ctemps(exprs, mode='all'):
34+
"""Shorthand to retrieve the CTemps in `exprs`"""
35+
return search(exprs, lambda expr: isinstance(expr, CTemp), mode, 'dfs')
36+
37+
3438
@cluster_pass
3539
def cse(cluster, sregistry=None, options=None, **kwargs):
3640
"""
@@ -229,12 +233,12 @@ def _compact(exprs, exclude):
229233
mapper = {e.lhs: e.rhs for e in candidates if q_leaf(e.rhs)}
230234

231235
# Find all the CTemps in expression right-hand-sides without removing duplicates
232-
ctemps = retrieve_ctemps([e.rhs for e in exprs])
233-
ctemp_count = Counter(ctemps)
236+
ctemps = retrieve_ctemps(e.rhs for e in exprs)
234237

235238
# If there are ctemps in the expressions, then add any to the mapper which only
236239
# appear once
237240
if ctemps:
241+
ctemp_count = Counter(ctemps)
238242
mapper.update({e.lhs: e.rhs for e in candidates
239243
if ctemp_count[e.lhs] == 1})
240244

devito/passes/clusters/misc.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,6 @@ def is_cross(source, sink):
361361
# (intuitively, "the loop nests are to be kept separated")
362362
# * All ClusterGroups between `cg0` and `cg1` must precede `cg1`
363363
# * All ClusterGroups after `cg1` cannot precede `cg1`
364-
365364
if any(i.cause & prefix for i in scope.d_anti_gen()):
366365
for cg2 in cgroups[n:cgroups.index(cg1)]:
367366
dag.add_edge(cg2, cg1)

devito/symbolics/queries.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,6 @@ def q_symbol(expr):
3232
return False
3333

3434

35-
def q_ctemp(expr):
36-
try:
37-
return expr.is_CTemp
38-
except AttributeError:
39-
return False
40-
41-
4235
def q_comp_acc(expr):
4336
return isinstance(expr, ComponentAccess)
4437

devito/symbolics/search.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import sympy
22

33
from devito.symbolics.queries import (q_indexed, q_function, q_terminal, q_leaf,
4-
q_symbol, q_ctemp, q_dimension, q_derivative)
4+
q_symbol, q_dimension, q_derivative)
55
from devito.tools import as_tuple
66

77
__all__ = ['retrieve_indexed', 'retrieve_functions', 'retrieve_function_carriers',
@@ -159,11 +159,6 @@ def retrieve_symbols(exprs, mode='all'):
159159
return search(exprs, q_symbol, mode, 'dfs')
160160

161161

162-
def retrieve_ctemps(exprs, mode='all'):
163-
"""Shorthand to retrieve the CTemps in `exprs`"""
164-
return search(exprs, q_ctemp, mode, 'dfs')
165-
166-
167162
def retrieve_function_carriers(exprs, mode='all'):
168163
"""
169164
Shorthand to retrieve the DiscreteFunction carriers in ``exprs``. An

devito/types/basic.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,6 @@ class Basic(CodeSymbol):
298298
is_Object = False
299299
is_LocalObject = False
300300
is_LocalType = False
301-
is_CTemp = False
302301

303302
# Created by the user
304303
is_Input = False

0 commit comments

Comments
 (0)