2626THE SOFTWARE.
2727"""
2828
29+ from collections import defaultdict
2930from typing import TYPE_CHECKING , Any , overload
3031
3132from orderedsets import FrozenOrderedSet
6364.. currentmodule:: pytato.analysis
6465
6566.. autofunction:: get_nusers
67+ .. autofunction:: get_list_of_users
6668
6769.. autofunction:: is_einsum_similar_to_subscript
6870
8284"""
8385
8486
85- # {{{ NUserCollector
87+ # {{{ ListOfUsersCollector
8688
87- class NUserCollector (Mapper [None , None , []]):
89+ class ListOfUsersCollector (Mapper [None , None , []]):
8890 """
89- A :class:`pytato.transform.CachedWalkMapper` that records the number of
90- times an array expression is a direct dependency of other nodes .
91+ A :class:`pytato.transform.CachedWalkMapper` that records, for each array
92+ expression, the nodes that directly depend on it .
9193
9294 .. note::
9395
@@ -97,10 +99,9 @@ class NUserCollector(Mapper[None, None, []]):
9799 send's data.
98100 """
99101 def __init__ (self ) -> None :
100- from collections import defaultdict
101102 super ().__init__ ()
102103 self ._visited_ids : set [int ] = set ()
103- self .nusers : dict [Array , int ] = defaultdict (lambda : 0 )
104+ self .array_to_users : dict [Array , list [ ArrayOrNames ]] = defaultdict (list )
104105
105106 def rec (self , expr : ArrayOrNames ) -> None :
106107 # See CachedWalkMapper.rec on why we chose id(x) as the cache key.
@@ -113,38 +114,38 @@ def rec(self, expr: ArrayOrNames) -> None:
113114
114115 def map_index_lambda (self , expr : IndexLambda ) -> None :
115116 for ary in expr .bindings .values ():
116- self .nusers [ary ] += 1
117+ self .array_to_users [ary ]. append ( expr )
117118 self .rec (ary )
118119
119120 for dim in expr .shape :
120121 if isinstance (dim , Array ):
121- self .nusers [dim ] += 1
122+ self .array_to_users [dim ]. append ( expr )
122123 self .rec (dim )
123124
124125 def map_stack (self , expr : Stack ) -> None :
125126 for ary in expr .arrays :
126- self .nusers [ary ] += 1
127+ self .array_to_users [ary ]. append ( expr )
127128 self .rec (ary )
128129
129130 def map_concatenate (self , expr : Concatenate ) -> None :
130131 for ary in expr .arrays :
131- self .nusers [ary ] += 1
132+ self .array_to_users [ary ]. append ( expr )
132133 self .rec (ary )
133134
134135 def map_loopy_call (self , expr : LoopyCall ) -> None :
135136 for ary in expr .bindings .values ():
136137 if isinstance (ary , Array ):
137- self .nusers [ary ] += 1
138+ self .array_to_users [ary ]. append ( expr )
138139 self .rec (ary )
139140
140141 def map_einsum (self , expr : Einsum ) -> None :
141142 for ary in expr .args :
142- self .nusers [ary ] += 1
143+ self .array_to_users [ary ]. append ( expr )
143144 self .rec (ary )
144145
145146 for dim in expr .shape :
146147 if isinstance (dim , Array ):
147- self .nusers [dim ] += 1
148+ self .array_to_users [dim ]. append ( expr )
148149 self .rec (dim )
149150
150151 def map_named_array (self , expr : NamedArray ) -> None :
@@ -155,20 +156,20 @@ def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None:
155156 self .rec (child )
156157
157158 def _map_index_base (self , expr : IndexBase ) -> None :
158- self .nusers [expr .array ] += 1
159+ self .array_to_users [expr .array ]. append ( expr )
159160 self .rec (expr .array )
160161
161162 for idx in expr .indices :
162163 if isinstance (idx , Array ):
163- self .nusers [idx ] += 1
164+ self .array_to_users [idx ]. append ( expr )
164165 self .rec (idx )
165166
166167 map_basic_index = _map_index_base
167168 map_contiguous_advanced_index = _map_index_base
168169 map_non_contiguous_advanced_index = _map_index_base
169170
170171 def _map_index_remapping_base (self , expr : IndexRemappingBase ) -> None :
171- self .nusers [expr .array ] += 1
172+ self .array_to_users [expr .array ]. append ( expr )
172173 self .rec (expr .array )
173174
174175 map_roll = _map_index_remapping_base
@@ -178,7 +179,7 @@ def _map_index_remapping_base(self, expr: IndexRemappingBase) -> None:
178179 def _map_input_base (self , expr : InputArgumentBase ) -> None :
179180 for dim in expr .shape :
180181 if isinstance (dim , Array ):
181- self .nusers [dim ] += 1
182+ self .array_to_users [dim ]. append ( expr )
182183 self .rec (dim )
183184
184185 map_placeholder = _map_input_base
@@ -189,20 +190,20 @@ def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder
189190 ) -> None :
190191 # Note: We do not consider 'expr.send.data' as a predecessor of *expr*,
191192 # as there is no dataflow from *expr.send.data* to *expr*
192- self .nusers [expr .passthrough_data ] += 1
193+ self .array_to_users [expr .passthrough_data ]. append ( expr )
193194 self .rec (expr .passthrough_data )
194195 self .rec (expr .send .data )
195196
196197 def map_distributed_recv (self , expr : DistributedRecv ) -> None :
197198 for dim in expr .shape :
198199 if isinstance (dim , Array ):
199- self .nusers [dim ] += 1
200+ self .array_to_users [dim ]. append ( expr )
200201 self .rec (dim )
201202
202203 def map_call (self , expr : Call ) -> None :
203204 for ary in expr .bindings .values ():
204205 if isinstance (ary , Array ):
205- self .nusers [ary ] += 1
206+ self .array_to_users [ary ]. append ( expr )
206207 self .rec (ary )
207208
208209 def map_named_call_result (self , expr : NamedCallResult ) -> None :
@@ -216,9 +217,21 @@ def get_nusers(outputs: ArrayOrNames) -> Mapping[Array, int]:
216217 For the DAG *outputs*, returns the mapping from each array node to the number of
217218 nodes using its value within the DAG given by *outputs*.
218219 """
219- nuser_collector = NUserCollector ()
220- nuser_collector (outputs )
221- return nuser_collector .nusers
220+ list_of_users_collector = ListOfUsersCollector ()
221+ list_of_users_collector (outputs )
222+ return defaultdict (int , {
223+ ary : len (users )
224+ for ary , users in list_of_users_collector .array_to_users .items ()})
225+
226+
227+ def get_list_of_users (outputs : ArrayOrNames ) -> Mapping [Array , list [ArrayOrNames ]]:
228+ """
229+ For the DAG *outputs*, returns the mapping from each array node to the list of
230+ nodes using its value within the DAG given by *outputs*.
231+ """
232+ list_of_users_collector = ListOfUsersCollector ()
233+ list_of_users_collector (outputs )
234+ return list_of_users_collector .array_to_users
222235
223236
224237# {{{ is_einsum_similar_to_subscript
@@ -482,7 +495,6 @@ def __init__(
482495 ) -> None :
483496 super ().__init__ (_visited_functions = _visited_functions )
484497
485- from collections import defaultdict
486498 self .expr_type_counts : dict [type [Any ], int ] = defaultdict (int )
487499 self .count_duplicates = count_duplicates
488500
@@ -562,7 +574,6 @@ class NodeMultiplicityMapper(CachedWalkMapper[[]]):
562574 def __init__ (self , _visited_functions : set [Any ] | None = None ) -> None :
563575 super ().__init__ (_visited_functions = _visited_functions )
564576
565- from collections import defaultdict
566577 self .expr_multiplicity_counts : dict [Array , int ] = defaultdict (int )
567578
568579 def get_cache_key (self , expr : ArrayOrNames ) -> int :
0 commit comments