@@ -172,7 +172,7 @@ def add(
172172
173173
174174def _materialize_if_mpms (expr : Array ,
175- nsuccessors : int ,
175+ successors : list [ ArrayOrNames ] ,
176176 predecessors : Iterable [MPMSMaterializerAccumulator ]
177177 ) -> MPMSMaterializerAccumulator :
178178 """
@@ -189,6 +189,24 @@ def _materialize_if_mpms(expr: Array,
189189 (pred .materialized_predecessors for pred in predecessors ),
190190 cast ("frozenset[Array]" , frozenset ()))
191191
192+ nsuccessors = 0
193+ for successor in successors :
194+ # Handle indexing with heavy reuse, if the sizes are known ahead of time.
195+ # This can occur when the elements of a smaller array are used repeatedly to
196+ # compute the elements of a larger array. (Example: In meshmode's direct
197+ # connection code, this happens when injecting data from a smaller
198+ # discretization into a larger one, such as BTAG_ALL -> FACE_RESTR_ALL.)
199+ #
200+ # In this case, we would like to bias towards materialization by
201+ # making one successor seem like n of them, if it is n times bigger.
202+ if (
203+ isinstance (successor , IndexBase )
204+ and isinstance (successor .size , int )
205+ and isinstance (expr .size , int )):
206+ nsuccessors += (successor .size // expr .size ) if expr .size else 0
207+ else :
208+ nsuccessors += 1
209+
192210 if nsuccessors > 1 and len (materialized_predecessors ) > 1 :
193211 new_expr = expr .tagged (ImplStored ())
194212 return MPMSMaterializerAccumulator (frozenset ([new_expr ]), new_expr )
@@ -201,14 +219,15 @@ class MPMSMaterializer(
201219 """
202220 See :func:`materialize_with_mpms` for an explanation.
203221
204- .. attribute:: nsuccessors
222+ .. attribute:: successors
205223
206224 A mapping from a node in the expression graph (i.e. an
207- :class:`~pytato.Array`) to its number of successors.
225+ :class:`~pytato.Array`) to a list of its successors (possibly including
226+ multiple references to the same successor if it uses the node multiple times).
208227 """
209228 def __init__ (
210229 self ,
211- nsuccessors : Mapping [Array , int ],
230+ successors : Mapping [Array , list [ ArrayOrNames ] ],
212231 _cache : MPMSMaterializerCache | None = None ):
213232 err_on_collision = __debug__
214233 err_on_created_duplicate = __debug__
@@ -221,7 +240,7 @@ def __init__(
221240 # Does not support functions, so function_cache is ignored
222241 super ().__init__ (err_on_collision = err_on_collision , _cache = _cache )
223242
224- self .nsuccessors : Mapping [Array , int ] = nsuccessors
243+ self .successors : Mapping [Array , list [ ArrayOrNames ]] = successors
225244
226245 @override
227246 def _cache_add (
@@ -269,38 +288,38 @@ def map_index_lambda(self, expr: IndexLambda) -> MPMSMaterializerAccumulator:
269288 for bnd_name , bnd in children_rec .items ()})
270289 return _materialize_if_mpms (
271290 expr .replace_if_different (bindings = new_children ),
272- self .nsuccessors [expr ],
291+ self .successors [expr ],
273292 children_rec .values ())
274293
275294 def map_stack (self , expr : Stack ) -> MPMSMaterializerAccumulator :
276295 rec_arrays = [self .rec (ary ) for ary in expr .arrays ]
277296 new_arrays = tuple (ary .expr for ary in rec_arrays )
278297 return _materialize_if_mpms (
279298 expr .replace_if_different (arrays = new_arrays ),
280- self .nsuccessors [expr ],
299+ self .successors [expr ],
281300 rec_arrays )
282301
283302 def map_concatenate (self , expr : Concatenate ) -> MPMSMaterializerAccumulator :
284303 rec_arrays = [self .rec (ary ) for ary in expr .arrays ]
285304 new_arrays = tuple (ary .expr for ary in rec_arrays )
286305 return _materialize_if_mpms (
287306 expr .replace_if_different (arrays = new_arrays ),
288- self .nsuccessors [expr ],
307+ self .successors [expr ],
289308 rec_arrays )
290309
291310 def map_roll (self , expr : Roll ) -> MPMSMaterializerAccumulator :
292311 rec_array = self .rec (expr .array )
293312 return _materialize_if_mpms (
294313 expr .replace_if_different (array = rec_array .expr ),
295- self .nsuccessors [expr ],
314+ self .successors [expr ],
296315 (rec_array ,))
297316
298317 def map_axis_permutation (self , expr : AxisPermutation
299318 ) -> MPMSMaterializerAccumulator :
300319 rec_array = self .rec (expr .array )
301320 return _materialize_if_mpms (
302321 expr .replace_if_different (array = rec_array .expr ),
303- self .nsuccessors [expr ],
322+ self .successors [expr ],
304323 (rec_array ,))
305324
306325 def _map_index_base (self , expr : IndexBase ) -> MPMSMaterializerAccumulator :
@@ -319,7 +338,7 @@ def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator:
319338 else new_indices )
320339 return _materialize_if_mpms (
321340 expr .replace_if_different (array = rec_array .expr , indices = new_indices ),
322- self .nsuccessors [expr ],
341+ self .successors [expr ],
323342 (rec_array , * tuple (rec_indices .values ())))
324343
325344 def map_basic_index (self , expr : BasicIndex ) -> MPMSMaterializerAccumulator :
@@ -338,15 +357,15 @@ def map_reshape(self, expr: Reshape) -> MPMSMaterializerAccumulator:
338357 rec_array = self .rec (expr .array )
339358 return _materialize_if_mpms (
340359 expr .replace_if_different (array = rec_array .expr ),
341- self .nsuccessors [expr ],
360+ self .successors [expr ],
342361 (rec_array ,))
343362
344363 def map_einsum (self , expr : Einsum ) -> MPMSMaterializerAccumulator :
345364 rec_args = [self .rec (ary ) for ary in expr .args ]
346365 new_args = tuple (ary .expr for ary in rec_args )
347366 return _materialize_if_mpms (
348367 expr .replace_if_different (args = new_args ),
349- self .nsuccessors [expr ],
368+ self .successors [expr ],
350369 rec_args )
351370
352371 def map_dict_of_named_arrays (self , expr : DictOfNamedArrays
@@ -427,8 +446,8 @@ def materialize_with_mpms(expr: ArrayOrNamesTc) -> ArrayOrNamesTc:
427446 ====== ======== =======
428447
429448 """
430- from pytato .analysis import get_num_nodes , get_num_tags_of_type , get_nusers
431- materializer = MPMSMaterializer (get_nusers (expr ))
449+ from pytato .analysis import get_list_of_users , get_num_nodes , get_num_tags_of_type
450+ materializer = MPMSMaterializer (get_list_of_users (expr ))
432451
433452 if isinstance (expr , Array ):
434453 res = materializer (expr ).expr
0 commit comments