@@ -44,12 +44,17 @@ def buffering(clusters, key, sregistry, options, **kwargs):
4444 ModuloDimensions. This might help relieve the synchronization
4545 overhead when asynchronous operations are used (these are however
4646 implemented by other passes).
47+ * 'buf-reuse': If True, the pass will try to reuse existing Buffers for
48+ different buffered Functions. By default, False.
4749 **kwargs
4850 Additional compilation options.
4951 Accepted: ['opt_init_onwrite', 'opt_reuse', 'opt_buffer'].
5052 * 'opt_init_onwrite': By default, a written buffer does not trigger the
5153 generation of an initializing Cluster. With `opt_init_onwrite=True`,
5254 instead, the buffer gets initialized to zero.
55+ * 'opt_reuse': A callback that takes a buffering candidate `bf` as input
56+ and returns True if the pass can reuse pre-existing Buffers for
57+ buffering `bf`, which would otherwise default to False.
5358 * 'opt_buffer': A callback that takes a buffering candidate as input
5459 and returns a buffer, which would otherwise default to an Array.
5560
@@ -98,6 +103,7 @@ def key(f):
98103 options .update ({
99104 'buf-init-onwrite' : init_onwrite ,
100105 'buf-callback' : kwargs .get ('opt_buffer' ),
106+ 'buf-reuse' : kwargs .get ('opt_reuse' , options ['buf-reuse' ]),
101107 })
102108
103109 # Escape hatch to selectively disable buffering
@@ -246,10 +252,13 @@ def callback(self, clusters, prefix):
246252 processed .append (Cluster (expr , ispace , guards , properties , syncs ))
247253
248254 # Lift {write,read}-only buffers into separate IterationSpaces
249- if self .options ['fuse-tasks' ]:
250- return init + processed
251- else :
252- return init + self ._optimize (processed , descriptors )
255+ if not self .options ['fuse-tasks' ]:
256+ processed = self ._optimize (processed , descriptors )
257+
258+ if self .options ['buf-reuse' ]:
259+ init , processed = self ._reuse (init , processed , descriptors )
260+
261+ return init + processed
253262
254263 def _optimize (self , clusters , descriptors ):
255264 for b , v in descriptors .items ():
@@ -285,6 +294,48 @@ def _optimize(self, clusters, descriptors):
285294
286295 return clusters
287296
def _reuse(self, init, clusters, descriptors):
    """
    Reuse existing Buffers for buffering candidates.

    The keys of `descriptors` (the Buffers) are grouped by their
    `_signature`. Within each group, all selected Buffers are replaced by
    a single, freshly named Buffer built from the first one in the group.

    Parameters
    ----------
    init : list of Cluster
        The Clusters initializing the Buffers.
    clusters : list of Cluster
        The Clusters already processed by the pass.
    descriptors : dict
        Mapper from Buffers to their descriptors. Only the `f` attribute
        of a descriptor is used here, as argument to the reuse callback.

    Returns
    -------
    tuple
        The pair `(init, clusters)` after substitution. The `init`
        Clusters writing to a Buffer that has been replaced, other than
        the first of its group, are dropped.

    Notes
    -----
    If `options['buf-reuse']` is a callable, it is used to select the
    Buffers eligible for reuse; it is invoked exactly once per Buffer.
    Otherwise, all Buffers sharing a signature are eligible.
    """
    buf_reuse = self.options['buf-reuse']

    if callable(buf_reuse):
        def select(buffers):
            return [b for b in buffers if buf_reuse(descriptors[b].f)]
    else:
        def select(buffers):
            return list(buffers)

    # Group the Buffers by signature, filtering each group only once so
    # that a user-provided callback is not invoked twice per candidate
    by_signature = as_mapper(descriptors, key=lambda b: b._signature)
    groups = []
    for _, buffers in by_signature.items():
        candidates = select(buffers)
        if candidates:
            groups.append(candidates)

    subs = {}
    drop = set()
    for retain, *others in groups:
        drop.update(others)

        # Even the retained Buffer is replaced, since the shared Buffer
        # gets a brand new name
        name = self.sregistry.make_name(prefix='r')
        shared = retain.func(name=name)

        for b in (retain, *others):
            subs[b] = shared
            subs[b.indexed] = shared.indexed

    def substitute(cluster):
        exprs = [uxreplace(e, subs) for e in cluster.exprs]
        return cluster.rebuild(exprs=exprs)

    # The initialization of a dropped Buffer is no longer needed
    init = [substitute(c) for c in init if not set(c.scope.writes) & drop]
    processed = [substitute(c) for c in clusters]

    return init, processed
338+
288339
# Lightweight record with two fields, `b` and `f` (presumably a Buffer and
# the Function it buffers -- TODO confirm against the users of Map)
Map = namedtuple('Map', ['b', 'f'])
290341
0 commit comments