Skip to content

Commit b73ee7b

Browse files
majosminducer
authored andcommitted
tweak the definition of 'multiple successors' in MPMS materializer to handle indexing with heavy reuse
add more explanation for MPMS reuse tweak
1 parent 5394104 commit b73ee7b

File tree

1 file changed

+34
-15
lines changed

1 file changed

+34
-15
lines changed

pytato/transform/materialize.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def add(
172172

173173

174174
def _materialize_if_mpms(expr: Array,
175-
nsuccessors: int,
175+
successors: list[ArrayOrNames],
176176
predecessors: Iterable[MPMSMaterializerAccumulator]
177177
) -> MPMSMaterializerAccumulator:
178178
"""
@@ -189,6 +189,24 @@ def _materialize_if_mpms(expr: Array,
189189
(pred.materialized_predecessors for pred in predecessors),
190190
cast("frozenset[Array]", frozenset()))
191191

192+
nsuccessors = 0
193+
for successor in successors:
194+
# Handle indexing with heavy reuse, if the sizes are known ahead of time.
195+
# This can occur when the elements of a smaller array are used repeatedly to
196+
# compute the elements of a larger array. (Example: In meshmode's direct
197+
# connection code, this happens when injecting data from a smaller
198+
# discretization into a larger one, such as BTAG_ALL -> FACE_RESTR_ALL.)
199+
#
200+
# In this case, we would like to bias towards materialization by
201+
# making one successor seem like n of them, if it is n times bigger.
202+
if (
203+
isinstance(successor, IndexBase)
204+
and isinstance(successor.size, int)
205+
and isinstance(expr.size, int)):
206+
nsuccessors += (successor.size // expr.size) if expr.size else 0
207+
else:
208+
nsuccessors += 1
209+
192210
if nsuccessors > 1 and len(materialized_predecessors) > 1:
193211
new_expr = expr.tagged(ImplStored())
194212
return MPMSMaterializerAccumulator(frozenset([new_expr]), new_expr)
@@ -201,14 +219,15 @@ class MPMSMaterializer(
201219
"""
202220
See :func:`materialize_with_mpms` for an explanation.
203221
204-
.. attribute:: nsuccessors
222+
.. attribute:: successors
205223
206224
A mapping from a node in the expression graph (i.e. an
207-
:class:`~pytato.Array`) to its number of successors.
225+
:class:`~pytato.Array`) to a list of its successors (possibly including
226+
multiple references to the same successor if it uses the node multiple times).
208227
"""
209228
def __init__(
210229
self,
211-
nsuccessors: Mapping[Array, int],
230+
successors: Mapping[Array, list[ArrayOrNames]],
212231
_cache: MPMSMaterializerCache | None = None):
213232
err_on_collision = __debug__
214233
err_on_created_duplicate = __debug__
@@ -221,7 +240,7 @@ def __init__(
221240
# Does not support functions, so function_cache is ignored
222241
super().__init__(err_on_collision=err_on_collision, _cache=_cache)
223242

224-
self.nsuccessors: Mapping[Array, int] = nsuccessors
243+
self.successors: Mapping[Array, list[ArrayOrNames]] = successors
225244

226245
@override
227246
def _cache_add(
@@ -269,38 +288,38 @@ def map_index_lambda(self, expr: IndexLambda) -> MPMSMaterializerAccumulator:
269288
for bnd_name, bnd in children_rec.items()})
270289
return _materialize_if_mpms(
271290
expr.replace_if_different(bindings=new_children),
272-
self.nsuccessors[expr],
291+
self.successors[expr],
273292
children_rec.values())
274293

275294
def map_stack(self, expr: Stack) -> MPMSMaterializerAccumulator:
276295
rec_arrays = [self.rec(ary) for ary in expr.arrays]
277296
new_arrays = tuple(ary.expr for ary in rec_arrays)
278297
return _materialize_if_mpms(
279298
expr.replace_if_different(arrays=new_arrays),
280-
self.nsuccessors[expr],
299+
self.successors[expr],
281300
rec_arrays)
282301

283302
def map_concatenate(self, expr: Concatenate) -> MPMSMaterializerAccumulator:
284303
rec_arrays = [self.rec(ary) for ary in expr.arrays]
285304
new_arrays = tuple(ary.expr for ary in rec_arrays)
286305
return _materialize_if_mpms(
287306
expr.replace_if_different(arrays=new_arrays),
288-
self.nsuccessors[expr],
307+
self.successors[expr],
289308
rec_arrays)
290309

291310
def map_roll(self, expr: Roll) -> MPMSMaterializerAccumulator:
292311
rec_array = self.rec(expr.array)
293312
return _materialize_if_mpms(
294313
expr.replace_if_different(array=rec_array.expr),
295-
self.nsuccessors[expr],
314+
self.successors[expr],
296315
(rec_array,))
297316

298317
def map_axis_permutation(self, expr: AxisPermutation
299318
) -> MPMSMaterializerAccumulator:
300319
rec_array = self.rec(expr.array)
301320
return _materialize_if_mpms(
302321
expr.replace_if_different(array=rec_array.expr),
303-
self.nsuccessors[expr],
322+
self.successors[expr],
304323
(rec_array,))
305324

306325
def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator:
@@ -319,7 +338,7 @@ def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator:
319338
else new_indices)
320339
return _materialize_if_mpms(
321340
expr.replace_if_different(array=rec_array.expr, indices=new_indices),
322-
self.nsuccessors[expr],
341+
self.successors[expr],
323342
(rec_array, *tuple(rec_indices.values())))
324343

325344
def map_basic_index(self, expr: BasicIndex) -> MPMSMaterializerAccumulator:
@@ -338,15 +357,15 @@ def map_reshape(self, expr: Reshape) -> MPMSMaterializerAccumulator:
338357
rec_array = self.rec(expr.array)
339358
return _materialize_if_mpms(
340359
expr.replace_if_different(array=rec_array.expr),
341-
self.nsuccessors[expr],
360+
self.successors[expr],
342361
(rec_array,))
343362

344363
def map_einsum(self, expr: Einsum) -> MPMSMaterializerAccumulator:
345364
rec_args = [self.rec(ary) for ary in expr.args]
346365
new_args = tuple(ary.expr for ary in rec_args)
347366
return _materialize_if_mpms(
348367
expr.replace_if_different(args=new_args),
349-
self.nsuccessors[expr],
368+
self.successors[expr],
350369
rec_args)
351370

352371
def map_dict_of_named_arrays(self, expr: DictOfNamedArrays
@@ -427,8 +446,8 @@ def materialize_with_mpms(expr: ArrayOrNamesTc) -> ArrayOrNamesTc:
427446
====== ======== =======
428447
429448
"""
430-
from pytato.analysis import get_num_nodes, get_num_tags_of_type, get_nusers
431-
materializer = MPMSMaterializer(get_nusers(expr))
449+
from pytato.analysis import get_list_of_users, get_num_nodes, get_num_tags_of_type
450+
materializer = MPMSMaterializer(get_list_of_users(expr))
432451

433452
if isinstance(expr, Array):
434453
res = materializer(expr).expr

0 commit comments

Comments
 (0)