@@ -172,7 +172,7 @@ def add(
172172
173173
174174def _materialize_if_mpms (expr : Array ,
175- nsuccessors : int ,
175+ successors : list [ ArrayOrNames ] ,
176176 predecessors : Iterable [MPMSMaterializerAccumulator ]
177177 ) -> MPMSMaterializerAccumulator :
178178 """
@@ -189,6 +189,17 @@ def _materialize_if_mpms(expr: Array,
189189 (pred .materialized_predecessors for pred in predecessors ),
190190 cast ("frozenset[Array]" , frozenset ()))
191191
192+ nsuccessors = 0
193+ for successor in successors :
194+ # Handle indexing with heavy reuse, if the sizes are known ahead of time
195+ if (
196+ isinstance (successor , IndexBase )
197+ and isinstance (successor .size , int )
198+ and isinstance (expr .size , int )):
199+ nsuccessors += (successor .size // expr .size ) if expr .size else 0
200+ else :
201+ nsuccessors += 1
202+
192203 if nsuccessors > 1 and len (materialized_predecessors ) > 1 :
193204 new_expr = expr .tagged (ImplStored ())
194205 return MPMSMaterializerAccumulator (frozenset ([new_expr ]), new_expr )
@@ -201,14 +212,15 @@ class MPMSMaterializer(
201212 """
202213 See :func:`materialize_with_mpms` for an explanation.
203214
204- .. attribute:: nsuccessors
215+ .. attribute:: successors
205216
206217 A mapping from a node in the expression graph (i.e. an
207- :class:`~pytato.Array`) to its number of successors.
218+ :class:`~pytato.Array`) to a list of its successors (possibly including
219+ multiple references to the same successor if it uses the node multiple times).
208220 """
209221 def __init__ (
210222 self ,
211- nsuccessors : Mapping [Array , int ],
223+ successors : Mapping [Array , list [ ArrayOrNames ] ],
212224 _cache : MPMSMaterializerCache | None = None ):
213225 err_on_collision = __debug__
214226 err_on_created_duplicate = __debug__
@@ -221,7 +233,7 @@ def __init__(
221233 # Does not support functions, so function_cache is ignored
222234 super ().__init__ (err_on_collision = err_on_collision , _cache = _cache )
223235
224- self .nsuccessors : Mapping [Array , int ] = nsuccessors
236+ self .successors : Mapping [Array , list [ ArrayOrNames ]] = successors
225237
226238 @override
227239 def _cache_add (
@@ -269,38 +281,38 @@ def map_index_lambda(self, expr: IndexLambda) -> MPMSMaterializerAccumulator:
269281 for bnd_name , bnd in children_rec .items ()})
270282 return _materialize_if_mpms (
271283 expr .replace_if_different (bindings = new_children ),
272- self .nsuccessors [expr ],
284+ self .successors [expr ],
273285 children_rec .values ())
274286
275287 def map_stack (self , expr : Stack ) -> MPMSMaterializerAccumulator :
276288 rec_arrays = [self .rec (ary ) for ary in expr .arrays ]
277289 new_arrays = tuple (ary .expr for ary in rec_arrays )
278290 return _materialize_if_mpms (
279291 expr .replace_if_different (arrays = new_arrays ),
280- self .nsuccessors [expr ],
292+ self .successors [expr ],
281293 rec_arrays )
282294
283295 def map_concatenate (self , expr : Concatenate ) -> MPMSMaterializerAccumulator :
284296 rec_arrays = [self .rec (ary ) for ary in expr .arrays ]
285297 new_arrays = tuple (ary .expr for ary in rec_arrays )
286298 return _materialize_if_mpms (
287299 expr .replace_if_different (arrays = new_arrays ),
288- self .nsuccessors [expr ],
300+ self .successors [expr ],
289301 rec_arrays )
290302
291303 def map_roll (self , expr : Roll ) -> MPMSMaterializerAccumulator :
292304 rec_array = self .rec (expr .array )
293305 return _materialize_if_mpms (
294306 expr .replace_if_different (array = rec_array .expr ),
295- self .nsuccessors [expr ],
307+ self .successors [expr ],
296308 (rec_array ,))
297309
298310 def map_axis_permutation (self , expr : AxisPermutation
299311 ) -> MPMSMaterializerAccumulator :
300312 rec_array = self .rec (expr .array )
301313 return _materialize_if_mpms (
302314 expr .replace_if_different (array = rec_array .expr ),
303- self .nsuccessors [expr ],
315+ self .successors [expr ],
304316 (rec_array ,))
305317
306318 def _map_index_base (self , expr : IndexBase ) -> MPMSMaterializerAccumulator :
@@ -319,7 +331,7 @@ def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator:
319331 else new_indices )
320332 return _materialize_if_mpms (
321333 expr .replace_if_different (array = rec_array .expr , indices = new_indices ),
322- self .nsuccessors [expr ],
334+ self .successors [expr ],
323335 (rec_array , * tuple (rec_indices .values ())))
324336
325337 def map_basic_index (self , expr : BasicIndex ) -> MPMSMaterializerAccumulator :
@@ -338,15 +350,15 @@ def map_reshape(self, expr: Reshape) -> MPMSMaterializerAccumulator:
338350 rec_array = self .rec (expr .array )
339351 return _materialize_if_mpms (
340352 expr .replace_if_different (array = rec_array .expr ),
341- self .nsuccessors [expr ],
353+ self .successors [expr ],
342354 (rec_array ,))
343355
344356 def map_einsum (self , expr : Einsum ) -> MPMSMaterializerAccumulator :
345357 rec_args = [self .rec (ary ) for ary in expr .args ]
346358 new_args = tuple (ary .expr for ary in rec_args )
347359 return _materialize_if_mpms (
348360 expr .replace_if_different (args = new_args ),
349- self .nsuccessors [expr ],
361+ self .successors [expr ],
350362 rec_args )
351363
352364 def map_dict_of_named_arrays (self , expr : DictOfNamedArrays
@@ -427,8 +439,8 @@ def materialize_with_mpms(expr: ArrayOrNamesTc) -> ArrayOrNamesTc:
427439 ====== ======== =======
428440
429441 """
430- from pytato .analysis import get_num_nodes , get_num_tags_of_type , get_nusers
431- materializer = MPMSMaterializer (get_nusers (expr ))
442+ from pytato .analysis import get_list_of_users , get_num_nodes , get_num_tags_of_type
443+ materializer = MPMSMaterializer (get_list_of_users (expr ))
432444
433445 if isinstance (expr , Array ):
434446 res = materializer (expr ).expr
0 commit comments