Skip to content

Commit 5423002

Browse files
committed
tweak the definition of 'multiple successors' in MPMS materializer to handle indexing with heavy reuse
1 parent e5dcca7 commit 5423002

File tree

1 file changed

+27
-15
lines changed

1 file changed

+27
-15
lines changed

pytato/transform/materialize.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def add(
172172

173173

174174
def _materialize_if_mpms(expr: Array,
175-
nsuccessors: int,
175+
successors: list[ArrayOrNames],
176176
predecessors: Iterable[MPMSMaterializerAccumulator]
177177
) -> MPMSMaterializerAccumulator:
178178
"""
@@ -189,6 +189,17 @@ def _materialize_if_mpms(expr: Array,
189189
(pred.materialized_predecessors for pred in predecessors),
190190
cast("frozenset[Array]", frozenset()))
191191

192+
nsuccessors = 0
193+
for successor in successors:
194+
# Handle indexing with heavy reuse, if the sizes are known ahead of time
195+
if (
196+
isinstance(successor, IndexBase)
197+
and isinstance(successor.size, int)
198+
and isinstance(expr.size, int)):
199+
nsuccessors += (successor.size // expr.size) if expr.size else 0
200+
else:
201+
nsuccessors += 1
202+
192203
if nsuccessors > 1 and len(materialized_predecessors) > 1:
193204
new_expr = expr.tagged(ImplStored())
194205
return MPMSMaterializerAccumulator(frozenset([new_expr]), new_expr)
@@ -201,14 +212,15 @@ class MPMSMaterializer(
201212
"""
202213
See :func:`materialize_with_mpms` for an explanation.
203214
204-
.. attribute:: nsuccessors
215+
.. attribute:: successors
205216
206217
A mapping from a node in the expression graph (i.e. an
207-
:class:`~pytato.Array`) to its number of successors.
218+
:class:`~pytato.Array`) to a list of its successors (possibly including
219+
multiple references to the same successor if it uses the node multiple times).
208220
"""
209221
def __init__(
210222
self,
211-
nsuccessors: Mapping[Array, int],
223+
successors: Mapping[Array, list[ArrayOrNames]],
212224
_cache: MPMSMaterializerCache | None = None):
213225
err_on_collision = __debug__
214226
err_on_created_duplicate = __debug__
@@ -221,7 +233,7 @@ def __init__(
221233
# Does not support functions, so function_cache is ignored
222234
super().__init__(err_on_collision=err_on_collision, _cache=_cache)
223235

224-
self.nsuccessors: Mapping[Array, int] = nsuccessors
236+
self.successors: Mapping[Array, list[ArrayOrNames]] = successors
225237

226238
@override
227239
def _cache_add(
@@ -269,38 +281,38 @@ def map_index_lambda(self, expr: IndexLambda) -> MPMSMaterializerAccumulator:
269281
for bnd_name, bnd in children_rec.items()})
270282
return _materialize_if_mpms(
271283
expr.replace_if_different(bindings=new_children),
272-
self.nsuccessors[expr],
284+
self.successors[expr],
273285
children_rec.values())
274286

275287
def map_stack(self, expr: Stack) -> MPMSMaterializerAccumulator:
276288
rec_arrays = [self.rec(ary) for ary in expr.arrays]
277289
new_arrays = tuple(ary.expr for ary in rec_arrays)
278290
return _materialize_if_mpms(
279291
expr.replace_if_different(arrays=new_arrays),
280-
self.nsuccessors[expr],
292+
self.successors[expr],
281293
rec_arrays)
282294

283295
def map_concatenate(self, expr: Concatenate) -> MPMSMaterializerAccumulator:
284296
rec_arrays = [self.rec(ary) for ary in expr.arrays]
285297
new_arrays = tuple(ary.expr for ary in rec_arrays)
286298
return _materialize_if_mpms(
287299
expr.replace_if_different(arrays=new_arrays),
288-
self.nsuccessors[expr],
300+
self.successors[expr],
289301
rec_arrays)
290302

291303
def map_roll(self, expr: Roll) -> MPMSMaterializerAccumulator:
292304
rec_array = self.rec(expr.array)
293305
return _materialize_if_mpms(
294306
expr.replace_if_different(array=rec_array.expr),
295-
self.nsuccessors[expr],
307+
self.successors[expr],
296308
(rec_array,))
297309

298310
def map_axis_permutation(self, expr: AxisPermutation
299311
) -> MPMSMaterializerAccumulator:
300312
rec_array = self.rec(expr.array)
301313
return _materialize_if_mpms(
302314
expr.replace_if_different(array=rec_array.expr),
303-
self.nsuccessors[expr],
315+
self.successors[expr],
304316
(rec_array,))
305317

306318
def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator:
@@ -319,7 +331,7 @@ def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator:
319331
else new_indices)
320332
return _materialize_if_mpms(
321333
expr.replace_if_different(array=rec_array.expr, indices=new_indices),
322-
self.nsuccessors[expr],
334+
self.successors[expr],
323335
(rec_array, *tuple(rec_indices.values())))
324336

325337
def map_basic_index(self, expr: BasicIndex) -> MPMSMaterializerAccumulator:
@@ -338,15 +350,15 @@ def map_reshape(self, expr: Reshape) -> MPMSMaterializerAccumulator:
338350
rec_array = self.rec(expr.array)
339351
return _materialize_if_mpms(
340352
expr.replace_if_different(array=rec_array.expr),
341-
self.nsuccessors[expr],
353+
self.successors[expr],
342354
(rec_array,))
343355

344356
def map_einsum(self, expr: Einsum) -> MPMSMaterializerAccumulator:
345357
rec_args = [self.rec(ary) for ary in expr.args]
346358
new_args = tuple(ary.expr for ary in rec_args)
347359
return _materialize_if_mpms(
348360
expr.replace_if_different(args=new_args),
349-
self.nsuccessors[expr],
361+
self.successors[expr],
350362
rec_args)
351363

352364
def map_dict_of_named_arrays(self, expr: DictOfNamedArrays
@@ -427,8 +439,8 @@ def materialize_with_mpms(expr: ArrayOrNamesTc) -> ArrayOrNamesTc:
427439
====== ======== =======
428440
429441
"""
430-
from pytato.analysis import get_num_nodes, get_num_tags_of_type, get_nusers
431-
materializer = MPMSMaterializer(get_nusers(expr))
442+
from pytato.analysis import get_list_of_users, get_num_nodes, get_num_tags_of_type
443+
materializer = MPMSMaterializer(get_list_of_users(expr))
432444

433445
if isinstance(expr, Array):
434446
res = materializer(expr).expr

0 commit comments

Comments
 (0)