Skip to content

Commit 4d4758d

Browse files
committed
compiler: Cache rcompile runs
1 parent 7827262 commit 4d4758d

File tree

1 file changed

+40
-5
lines changed

1 file changed

+40
-5
lines changed

devito/operator/operator.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@
2525
from devito.operator.registry import operator_selector
2626
from devito.mpi import MPI
2727
from devito.parameters import configuration
28-
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
29-
generate_macros, minimize_symbols, unevaluate,
30-
error_mapper, is_on_device, lower_dtypes)
28+
from devito.passes import (
29+
Graph, lower_index_derivatives, generate_implicit, generate_macros,
30+
minimize_symbols, unevaluate, error_mapper, is_on_device, lower_dtypes
31+
)
3132
from devito.symbolics import estimate_cost, subs_op_args
3233
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple,
3334
flatten, filter_sorted, frozendict, is_integer,
@@ -1150,6 +1151,35 @@ def __setstate__(self, state):
11501151
)
11511152

11521153

1154+
# *** Recursive compilation ("rcompile") machinery
1155+
1156+
1157+
class RCompiles(CacheInstances):
1158+
1159+
"""
1160+
A cache for abstract Callables obtained from lowering expressions.
1161+
Here, "abstract Callable" means that any user-level symbolic object appearing
1162+
in the input expressions is replaced by a corresponding abstract object.
1163+
"""
1164+
1165+
_instance_cache_size = None
1166+
1167+
def __init__(self, exprs, cls):
1168+
self.exprs = exprs
1169+
self.cls = cls
1170+
1171+
# NOTE: Constructed lazily at `__call__` time because `**kwargs` is
1172+
# unhashable for historical reasons (e.g., Compiler objects are mutable,
1173+
# though in practice they are unique per Operator, so only "locally"
1174+
# mutable)
1175+
self._output = None
1176+
1177+
def compile(self, **kwargs):
1178+
if self._output is None:
1179+
self._output = self.cls._lower(self.exprs, **kwargs)
1180+
return self._output
1181+
1182+
11531183
# Default action (perform or bypass) for selected compilation passes upon
11541184
# recursive compilation
11551185
# NOTE: it may not only be pointless to apply the following passes recursively
@@ -1167,6 +1197,7 @@ def rcompile(expressions, kwargs, options, target=None):
11671197
"""
11681198
Perform recursive compilation on an ordered sequence of symbolic expressions.
11691199
"""
1200+
expressions = as_tuple(expressions)
11701201
options = {**options, **rcompile_registry}
11711202

11721203
if target is None:
@@ -1181,10 +1212,14 @@ def rcompile(expressions, kwargs, options, target=None):
11811212
# Recursive profiling not supported -- would be a complete mess
11821213
kwargs.pop('profiler', None)
11831214

1184-
return cls._lower(expressions, **kwargs)
1215+
# Recursive compilation is expensive, so we cache the result because sometimes
1216+
# it is called multiple times for the same input
1217+
compiled = RCompiles(expressions, cls).compile(**kwargs)
1218+
1219+
return compiled
11851220

11861221

1187-
# Misc helpers
1222+
# *** Misc helpers
11881223

11891224

11901225
IRs = namedtuple('IRs', 'expressions clusters stree uiet iet')

0 commit comments

Comments
 (0)