Skip to content

Commit 7827262

Browse files
committed
compiler: Add buf-reuse opt-option
1 parent 822f146 commit 7827262

File tree

4 files changed

+82
-4
lines changed

4 files changed

+82
-4
lines changed

devito/core/cpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def _normalize_kwargs(cls, **kwargs):
3434

3535
# Buffering
3636
o['buf-async-degree'] = oo.pop('buf-async-degree', None)
37+
o['buf-reuse'] = oo.pop('buf-reuse', None)
3738

3839
# Fusion
3940
o['fuse-tasks'] = oo.pop('fuse-tasks', False)

devito/core/gpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def _normalize_kwargs(cls, **kwargs):
4444

4545
# Buffering
4646
o['buf-async-degree'] = oo.pop('buf-async-degree', None)
47+
o['buf-reuse'] = oo.pop('buf-reuse', None)
4748

4849
# Fusion
4950
o['fuse-tasks'] = oo.pop('fuse-tasks', False)

devito/passes/clusters/buffering.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,17 @@ def buffering(clusters, key, sregistry, options, **kwargs):
4444
ModuloDimensions. This might help relieving the synchronization
4545
overhead when asynchronous operations are used (these are however
4646
implemented by other passes).
47+
* 'buf-reuse': If True, the pass will try to reuse existing Buffers for
48+
different buffered Functions. By default, False.
4749
**kwargs
4850
Additional compilation options.
4951
Accepted: ['opt_init_onwrite', 'opt_reuse', 'opt_buffer'].
5052
* 'opt_init_onwrite': By default, a written buffer does not trigger the
5153
generation of an initializing Cluster. With `opt_init_onwrite=True`,
5254
instead, the buffer gets initialized to zero.
55+
* 'opt_reuse': A callback that takes a buffering candidate `bf` as input
56+
and returns True if the pass can reuse pre-existing Buffers for
57+
buffering `bf`, which would otherwise default to False.
5358
* 'opt_buffer': A callback that takes a buffering candidate as input
5459
and returns a buffer, which would otherwise default to an Array.
5560
@@ -98,6 +103,7 @@ def key(f):
98103
options.update({
99104
'buf-init-onwrite': init_onwrite,
100105
'buf-callback': kwargs.get('opt_buffer'),
106+
'buf-reuse': kwargs.get('opt_reuse', options['buf-reuse']),
101107
})
102108

103109
# Escape hatch to selectively disable buffering
@@ -246,10 +252,13 @@ def callback(self, clusters, prefix):
246252
processed.append(Cluster(expr, ispace, guards, properties, syncs))
247253

248254
# Lift {write,read}-only buffers into separate IterationSpaces
249-
if self.options['fuse-tasks']:
250-
return init + processed
251-
else:
252-
return init + self._optimize(processed, descriptors)
255+
if not self.options['fuse-tasks']:
256+
processed = self._optimize(processed, descriptors)
257+
258+
if self.options['buf-reuse']:
259+
init, processed = self._reuse(init, processed, descriptors)
260+
261+
return init + processed
253262

254263
def _optimize(self, clusters, descriptors):
255264
for b, v in descriptors.items():
@@ -285,6 +294,48 @@ def _optimize(self, clusters, descriptors):
285294

286295
return clusters
287296

297+
def _reuse(self, init, clusters, descriptors):
    """
    Reuse existing Buffers for buffering candidates.

    The buffers in `descriptors` are grouped by their `_signature`. All
    buffers within a group are replaced by one freshly-named buffer, built
    from the first buffer of the group via `func(name=...)`. The `init`
    Clusters writing to any of the other buffers of a group are discarded.

    If the 'buf-reuse' option is a callable, it is used to filter the
    buffers that may take part in the reuse: a buffer `b` is retained only
    if `buf_reuse(descriptors[b].f)` is true. The callback is evaluated
    exactly once per buffer.

    Parameters
    ----------
    init : iterable of Cluster
        The Clusters initializing the buffers.
    clusters : iterable of Cluster
        The Clusters in which the buffers are accessed.
    descriptors : dict
        Mapper from buffers to their descriptors.

    Returns
    -------
    tuple
        The pair `(init, clusters)`, as two lists of rebuilt Clusters in
        which the replaced buffers have been substituted.
    """
    buf_reuse = self.options['buf-reuse']

    # Buffers sharing a signature are candidates for sharing storage
    mapper = as_mapper(descriptors, key=lambda b: b._signature)

    # Filter each group once, so that a user-provided callback isn't
    # invoked twice per candidate
    if callable(buf_reuse):
        groups = [[b for b in v if buf_reuse(descriptors[b].f)]
                  for v in mapper.values()]
    else:
        groups = [list(v) for v in mapper.values()]
    groups = [g for g in groups if g]

    subs = {}
    drop = set()
    for retain, *others in groups:
        # Only the initialization of the first buffer in a group survives
        drop.update(others)

        name = self.sregistry.make_name(prefix='r')
        b = retain.func(name=name)

        for i in (retain, *others):
            subs[i] = b
            subs[i.indexed] = b.indexed

    # Discard the initialization of the dropped buffers, then apply the
    # substitutions everywhere
    init = [c.rebuild(exprs=[uxreplace(e, subs) for e in c.exprs])
            for c in init if not set(c.scope.writes) & drop]

    processed = [c.rebuild(exprs=[uxreplace(e, subs) for e in c.exprs])
                 for c in clusters]

    return init, processed
338+
288339

289340
Map = namedtuple('Map', 'b f')
290341

tests/test_buffering.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,3 +724,28 @@ def test_stencil_issue_1915_v2(subdomain):
724724
op1.apply(time_M=nt-2, u=u1)
725725

726726
assert np.all(u.data == u1.data)
727+
728+
729+
def test_buffer_reuse():
    """
    Check that, with `buf-reuse` enabled, buffering `usave` and `vsave`
    yields a single heap-allocated Array in the generated Operator, and
    that the saved data is still what is expected.
    """
    nt = 10
    grid = Grid(shape=(4, 4))

    u = TimeFunction(name='u', grid=grid)
    usave = TimeFunction(name='usave', grid=grid, save=nt)
    vsave = TimeFunction(name='vsave', grid=grid, save=nt)

    update = Eq(u.forward, u + 1)
    save_u = Eq(usave, u.forward)
    save_v = Eq(vsave, u.forward + 1)

    op = Operator([update, save_u, save_v],
                  opt=('buffering', {'buf-reuse': True}))

    # Check generated code
    trees = retrieve_iteration_tree(op)
    assert len(trees) == 5
    buffers = []
    for sym in FindSymbols().visit(op):
        if sym.is_Array and sym._mem_heap:
            buffers.append(sym)
    assert len(buffers) == 1

    op.apply(time_M=nt-1)

    # The i-th saved slice (1-based) holds `i` in `usave`, `i + 1` in `vsave`
    for i in range(1, nt + 1):
        assert np.all(usave.data[i-1] == i)
    for i in range(1, nt + 1):
        assert np.all(vsave.data[i-1] == i + 1)

0 commit comments

Comments
 (0)