2525from devito .operator .registry import operator_selector
2626from devito .mpi import MPI
2727from devito .parameters import configuration
28- from devito .passes import (Graph , lower_index_derivatives , generate_implicit ,
29- generate_macros , minimize_symbols , unevaluate ,
30- error_mapper , is_on_device , lower_dtypes )
28+ from devito .passes import (
29+ Graph , lower_index_derivatives , generate_implicit , generate_macros ,
30+ minimize_symbols , unevaluate , error_mapper , is_on_device , lower_dtypes
31+ )
3132from devito .symbolics import estimate_cost , subs_op_args
3233from devito .tools import (DAG , OrderedSet , Signer , ReducerMap , as_mapper , as_tuple ,
3334 flatten , filter_sorted , frozendict , is_integer ,
@@ -1150,6 +1151,35 @@ def __setstate__(self, state):
11501151 )
11511152
11521153
1154+ # *** Recursive compilation ("rcompile") machinery
1155+
1156+
1157+ class RCompiles (CacheInstances ):
1158+
1159+ """
1160+ A cache for abstract Callables obtained from lowering expressions.
1161+ Here, "abstract Callable" means that any user-level symbolic object appearing
1162+ in the input expressions is replaced by a corresponding abstract object.
1163+ """
1164+
1165+ _instance_cache_size = None
1166+
1167+ def __init__ (self , exprs , cls ):
1168+ self .exprs = exprs
1169+ self .cls = cls
1170+
1171+ # NOTE: Constructed lazily at `__call__` time because `**kwargs` is
1172+ # unhashable for historical reasons (e.g., Compiler objects are mutable,
1173+ # though in practice they are unique per Operator, so only "locally"
1174+ # mutable)
1175+ self ._output = None
1176+
1177+ def compile (self , ** kwargs ):
1178+ if self ._output is None :
1179+ self ._output = self .cls ._lower (self .exprs , ** kwargs )
1180+ return self ._output
1181+
1182+
11531183# Default action (perform or bypass) for selected compilation passes upon
11541184# recursive compilation
11551185# NOTE: it may not only be pointless to apply the following passes recursively
@@ -1167,6 +1197,7 @@ def rcompile(expressions, kwargs, options, target=None):
11671197 """
11681198 Perform recursive compilation on an ordered sequence of symbolic expressions.
11691199 """
1200+ expressions = as_tuple (expressions )
11701201 options = {** options , ** rcompile_registry }
11711202
11721203 if target is None :
@@ -1181,10 +1212,14 @@ def rcompile(expressions, kwargs, options, target=None):
11811212 # Recursive profiling not supported -- would be a complete mess
11821213 kwargs .pop ('profiler' , None )
11831214
1184- return cls ._lower (expressions , ** kwargs )
1215+ # Recursive compilation is expensive, so we cache the result because sometimes
1216+ # it is called multiple times for the same input
1217+ compiled = RCompiles (expressions , cls ).compile (** kwargs )
1218+
1219+ return compiled
11851220
11861221
1187- # Misc helpers
1222+ # *** Misc helpers
11881223
11891224
11901225IRs = namedtuple ('IRs' , 'expressions clusters stree uiet iet' )
0 commit comments