Skip to content

Commit b54ca3f

Browse files
authored
Merge pull request #2674 from devitocodes/min-buffer-mem
compiler: Add buf-reuse opt-option and rcompile caching
2 parents 822f146 + 4d4758d commit b54ca3f

File tree

5 files changed

+122
-9
lines changed

5 files changed

+122
-9
lines changed

devito/core/cpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def _normalize_kwargs(cls, **kwargs):
3434

3535
# Buffering
3636
o['buf-async-degree'] = oo.pop('buf-async-degree', None)
37+
o['buf-reuse'] = oo.pop('buf-reuse', None)
3738

3839
# Fusion
3940
o['fuse-tasks'] = oo.pop('fuse-tasks', False)

devito/core/gpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def _normalize_kwargs(cls, **kwargs):
4444

4545
# Buffering
4646
o['buf-async-degree'] = oo.pop('buf-async-degree', None)
47+
o['buf-reuse'] = oo.pop('buf-reuse', None)
4748

4849
# Fusion
4950
o['fuse-tasks'] = oo.pop('fuse-tasks', False)

devito/operator/operator.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@
2525
from devito.operator.registry import operator_selector
2626
from devito.mpi import MPI
2727
from devito.parameters import configuration
28-
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
29-
generate_macros, minimize_symbols, unevaluate,
30-
error_mapper, is_on_device, lower_dtypes)
28+
from devito.passes import (
29+
Graph, lower_index_derivatives, generate_implicit, generate_macros,
30+
minimize_symbols, unevaluate, error_mapper, is_on_device, lower_dtypes
31+
)
3132
from devito.symbolics import estimate_cost, subs_op_args
3233
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple,
3334
flatten, filter_sorted, frozendict, is_integer,
@@ -1150,6 +1151,35 @@ def __setstate__(self, state):
11501151
)
11511152

11521153

1154+
# *** Recursive compilation ("rcompile") machinery
1155+
1156+
1157+
class RCompiles(CacheInstances):
1158+
1159+
"""
1160+
A cache for abstract Callables obtained from lowering expressions.
1161+
Here, "abstract Callable" means that any user-level symbolic object appearing
1162+
in the input expressions is replaced by a corresponding abstract object.
1163+
"""
1164+
1165+
_instance_cache_size = None
1166+
1167+
def __init__(self, exprs, cls):
1168+
self.exprs = exprs
1169+
self.cls = cls
1170+
1171+
# NOTE: Constructed lazily at `__call__` time because `**kwargs` is
1172+
# unhashable for historical reasons (e.g., Compiler objects are mutable,
1173+
# though in practice they are unique per Operator, so only "locally"
1174+
# mutable)
1175+
self._output = None
1176+
1177+
def compile(self, **kwargs):
1178+
if self._output is None:
1179+
self._output = self.cls._lower(self.exprs, **kwargs)
1180+
return self._output
1181+
1182+
11531183
# Default action (perform or bypass) for selected compilation passes upon
11541184
# recursive compilation
11551185
# NOTE: it may not only be pointless to apply the following passes recursively
@@ -1167,6 +1197,7 @@ def rcompile(expressions, kwargs, options, target=None):
11671197
"""
11681198
Perform recursive compilation on an ordered sequence of symbolic expressions.
11691199
"""
1200+
expressions = as_tuple(expressions)
11701201
options = {**options, **rcompile_registry}
11711202

11721203
if target is None:
@@ -1181,10 +1212,14 @@ def rcompile(expressions, kwargs, options, target=None):
11811212
# Recursive profiling not supported -- would be a complete mess
11821213
kwargs.pop('profiler', None)
11831214

1184-
return cls._lower(expressions, **kwargs)
1215+
# Recursive compilation is expensive, so we cache the result because sometimes
1216+
# it is called multiple times for the same input
1217+
compiled = RCompiles(expressions, cls).compile(**kwargs)
1218+
1219+
return compiled
11851220

11861221

1187-
# Misc helpers
1222+
# *** Misc helpers
11881223

11891224

11901225
IRs = namedtuple('IRs', 'expressions clusters stree uiet iet')

devito/passes/clusters/buffering.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,17 @@ def buffering(clusters, key, sregistry, options, **kwargs):
4444
ModuloDimensions. This might help relieving the synchronization
4545
overhead when asynchronous operations are used (these are however
4646
implemented by other passes).
47+
* 'buf-reuse': If True, the pass will try to reuse existing Buffers for
48+
different buffered Functions. By default, False.
4749
**kwargs
4850
Additional compilation options.
4951
Accepted: ['opt_init_onwrite', 'opt_buffer'].
5052
* 'opt_init_onwrite': By default, a written buffer does not trigger the
5153
generation of an initializing Cluster. With `opt_init_onwrite=True`,
5254
instead, the buffer gets initialized to zero.
55+
* 'opt_reuse': A callback that takes a buffering candidate `bf` as input
56+
and returns True if the pass can reuse pre-existing Buffers for
57+
buffering `bf`, which would otherwise default to False.
5358
* 'opt_buffer': A callback that takes a buffering candidate as input
5459
and returns a buffer, which would otherwise default to an Array.
5560
@@ -98,6 +103,7 @@ def key(f):
98103
options.update({
99104
'buf-init-onwrite': init_onwrite,
100105
'buf-callback': kwargs.get('opt_buffer'),
106+
'buf-reuse': kwargs.get('opt_reuse', options['buf-reuse']),
101107
})
102108

103109
# Escape hatch to selectively disable buffering
@@ -246,10 +252,13 @@ def callback(self, clusters, prefix):
246252
processed.append(Cluster(expr, ispace, guards, properties, syncs))
247253

248254
# Lift {write,read}-only buffers into separate IterationSpaces
249-
if self.options['fuse-tasks']:
250-
return init + processed
251-
else:
252-
return init + self._optimize(processed, descriptors)
255+
if not self.options['fuse-tasks']:
256+
processed = self._optimize(processed, descriptors)
257+
258+
if self.options['buf-reuse']:
259+
init, processed = self._reuse(init, processed, descriptors)
260+
261+
return init + processed
253262

254263
def _optimize(self, clusters, descriptors):
255264
for b, v in descriptors.items():
@@ -285,6 +294,48 @@ def _optimize(self, clusters, descriptors):
285294

286295
return clusters
287296

297+
def _reuse(self, init, clusters, descriptors):
298+
"""
299+
Reuse existing Buffers for buffering candidates.
300+
"""
301+
buf_reuse = self.options['buf-reuse']
302+
303+
if callable(buf_reuse):
304+
cbk = lambda v: [i for i in v if buf_reuse(descriptors[i].f)]
305+
else:
306+
cbk = lambda v: v
307+
308+
mapper = as_mapper(descriptors, key=lambda b: b._signature)
309+
mapper = {k: cbk(v) for k, v in mapper.items() if cbk(v)}
310+
311+
subs = {}
312+
drop = set()
313+
for reusable in mapper.values():
314+
retain = reusable.pop(0)
315+
drop.update(reusable)
316+
317+
name = self.sregistry.make_name(prefix='r')
318+
b = retain.func(name=name)
319+
320+
for i in (retain, *reusable):
321+
subs.update({i: b, i.indexed: b.indexed})
322+
323+
processed = []
324+
for c in init:
325+
if set(c.scope.writes) & drop:
326+
continue
327+
328+
exprs = [uxreplace(e, subs) for e in c.exprs]
329+
processed.append(c.rebuild(exprs=exprs))
330+
init = processed
331+
332+
processed = []
333+
for c in clusters:
334+
exprs = [uxreplace(e, subs) for e in c.exprs]
335+
processed.append(c.rebuild(exprs=exprs))
336+
337+
return init, processed
338+
288339

289340
Map = namedtuple('Map', 'b f')
290341

tests/test_buffering.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,3 +724,28 @@ def test_stencil_issue_1915_v2(subdomain):
724724
op1.apply(time_M=nt-2, u=u1)
725725

726726
assert np.all(u.data == u1.data)
727+
728+
729+
def test_buffer_reuse():
730+
nt = 10
731+
grid = Grid(shape=(4, 4))
732+
733+
u = TimeFunction(name='u', grid=grid)
734+
usave = TimeFunction(name='usave', grid=grid, save=nt)
735+
vsave = TimeFunction(name='vsave', grid=grid, save=nt)
736+
737+
eqns = [Eq(u.forward, u + 1),
738+
Eq(usave, u.forward),
739+
Eq(vsave, u.forward + 1)]
740+
741+
op = Operator(eqns, opt=('buffering', {'buf-reuse': True}))
742+
743+
# Check generated code
744+
assert len(retrieve_iteration_tree(op)) == 5
745+
buffers = [i for i in FindSymbols().visit(op) if i.is_Array and i._mem_heap]
746+
assert len(buffers) == 1
747+
748+
op.apply(time_M=nt-1)
749+
750+
assert all(np.all(usave.data[i-1] == i) for i in range(1, nt + 1))
751+
assert all(np.all(vsave.data[i-1] == i + 1) for i in range(1, nt + 1))

0 commit comments

Comments
 (0)