1313from devito .finite_differences .differentiable import IndexDerivative
1414from devito .ir import Cluster , Scope , cluster_pass
1515from devito .symbolics import estimate_cost , q_leaf , q_terminal
16- from devito .symbolics .search import retrieve_ctemps
16+ from devito .symbolics .search import search
1717from devito .symbolics .manipulation import _uxreplace
1818from devito .tools import DAG , as_list , as_tuple , frozendict , extract_dtype
1919from devito .types import Eq , Symbol , Temp
@@ -26,11 +26,15 @@ class CTemp(Temp):
2626 """
2727 A cluster-level Temp, similar to Temp, ensured to have different priority
2828 """
29- is_CTemp = True
3029
3130 ordering_of_classes .insert (ordering_of_classes .index ('Temp' ) + 1 , 'CTemp' )
3231
3332
33+ def retrieve_ctemps (exprs , mode = 'all' ):
34+ """Shorthand to retrieve the CTemps in `exprs`"""
35+ return search (exprs , lambda expr : isinstance (expr , CTemp ), mode , 'dfs' )
36+
37+
3438@cluster_pass
3539def cse (cluster , sregistry = None , options = None , ** kwargs ):
3640 """
@@ -229,12 +233,12 @@ def _compact(exprs, exclude):
229233 mapper = {e .lhs : e .rhs for e in candidates if q_leaf (e .rhs )}
230234
231235 # Find all the CTemps in expression right-hand-sides without removing duplicates
232- ctemps = retrieve_ctemps ([e .rhs for e in exprs ])
233- ctemp_count = Counter (ctemps )
236+ ctemps = retrieve_ctemps (e .rhs for e in exprs )
234237
235238 # If there are ctemps in the expressions, then add any to the mapper which only
236239 # appear once
237240 if ctemps :
241+ ctemp_count = Counter (ctemps )
238242 mapper .update ({e .lhs : e .rhs for e in candidates
239243 if ctemp_count [e .lhs ] == 1 })
240244
0 commit comments