Skip to content

Commit e5dcca7

Browse files
committed
add get_list_of_users function in analysis
1 parent 37bca6b commit e5dcca7

File tree

1 file changed

+36
-25
lines changed

1 file changed

+36
-25
lines changed

pytato/analysis/__init__.py

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
THE SOFTWARE.
2727
"""
2828

29+
from collections import defaultdict
2930
from typing import TYPE_CHECKING, Any, overload
3031

3132
from orderedsets import FrozenOrderedSet
@@ -63,6 +64,7 @@
6364
.. currentmodule:: pytato.analysis
6465
6566
.. autofunction:: get_nusers
67+
.. autofunction:: get_list_of_users
6668
6769
.. autofunction:: is_einsum_similar_to_subscript
6870
@@ -82,12 +84,12 @@
8284
"""
8385

8486

85-
# {{{ NUserCollector
87+
# {{{ ListOfUsersCollector
8688

87-
class NUserCollector(Mapper[None, None, []]):
89+
class ListOfUsersCollector(Mapper[None, None, []]):
8890
"""
89-
A :class:`pytato.transform.CachedWalkMapper` that records the number of
90-
times an array expression is a direct dependency of other nodes.
91+
A :class:`pytato.transform.CachedWalkMapper` that records, for each array
92+
expression, the nodes that directly depend on it.
9193
9294
.. note::
9395
@@ -97,10 +99,9 @@ class NUserCollector(Mapper[None, None, []]):
9799
send's data.
98100
"""
99101
def __init__(self) -> None:
100-
from collections import defaultdict
101102
super().__init__()
102103
self._visited_ids: set[int] = set()
103-
self.nusers: dict[Array, int] = defaultdict(lambda: 0)
104+
self.array_to_users: dict[Array, list[ArrayOrNames]] = defaultdict(list)
104105

105106
def rec(self, expr: ArrayOrNames) -> None:
106107
# See CachedWalkMapper.rec on why we chose id(x) as the cache key.
@@ -113,38 +114,38 @@ def rec(self, expr: ArrayOrNames) -> None:
113114

114115
def map_index_lambda(self, expr: IndexLambda) -> None:
115116
for ary in expr.bindings.values():
116-
self.nusers[ary] += 1
117+
self.array_to_users[ary].append(expr)
117118
self.rec(ary)
118119

119120
for dim in expr.shape:
120121
if isinstance(dim, Array):
121-
self.nusers[dim] += 1
122+
self.array_to_users[dim].append(expr)
122123
self.rec(dim)
123124

124125
def map_stack(self, expr: Stack) -> None:
125126
for ary in expr.arrays:
126-
self.nusers[ary] += 1
127+
self.array_to_users[ary].append(expr)
127128
self.rec(ary)
128129

129130
def map_concatenate(self, expr: Concatenate) -> None:
130131
for ary in expr.arrays:
131-
self.nusers[ary] += 1
132+
self.array_to_users[ary].append(expr)
132133
self.rec(ary)
133134

134135
def map_loopy_call(self, expr: LoopyCall) -> None:
135136
for ary in expr.bindings.values():
136137
if isinstance(ary, Array):
137-
self.nusers[ary] += 1
138+
self.array_to_users[ary].append(expr)
138139
self.rec(ary)
139140

140141
def map_einsum(self, expr: Einsum) -> None:
141142
for ary in expr.args:
142-
self.nusers[ary] += 1
143+
self.array_to_users[ary].append(expr)
143144
self.rec(ary)
144145

145146
for dim in expr.shape:
146147
if isinstance(dim, Array):
147-
self.nusers[dim] += 1
148+
self.array_to_users[dim].append(expr)
148149
self.rec(dim)
149150

150151
def map_named_array(self, expr: NamedArray) -> None:
@@ -155,20 +156,20 @@ def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None:
155156
self.rec(child)
156157

157158
def _map_index_base(self, expr: IndexBase) -> None:
158-
self.nusers[expr.array] += 1
159+
self.array_to_users[expr.array].append(expr)
159160
self.rec(expr.array)
160161

161162
for idx in expr.indices:
162163
if isinstance(idx, Array):
163-
self.nusers[idx] += 1
164+
self.array_to_users[idx].append(expr)
164165
self.rec(idx)
165166

166167
map_basic_index = _map_index_base
167168
map_contiguous_advanced_index = _map_index_base
168169
map_non_contiguous_advanced_index = _map_index_base
169170

170171
def _map_index_remapping_base(self, expr: IndexRemappingBase) -> None:
171-
self.nusers[expr.array] += 1
172+
self.array_to_users[expr.array].append(expr)
172173
self.rec(expr.array)
173174

174175
map_roll = _map_index_remapping_base
@@ -178,7 +179,7 @@ def _map_index_remapping_base(self, expr: IndexRemappingBase) -> None:
178179
def _map_input_base(self, expr: InputArgumentBase) -> None:
179180
for dim in expr.shape:
180181
if isinstance(dim, Array):
181-
self.nusers[dim] += 1
182+
self.array_to_users[dim].append(expr)
182183
self.rec(dim)
183184

184185
map_placeholder = _map_input_base
@@ -189,20 +190,20 @@ def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder
189190
) -> None:
190191
# Note: We do not consider 'expr.send.data' as a predecessor of *expr*,
191192
# as there is no dataflow from *expr.send.data* to *expr*
192-
self.nusers[expr.passthrough_data] += 1
193+
self.array_to_users[expr.passthrough_data].append(expr)
193194
self.rec(expr.passthrough_data)
194195
self.rec(expr.send.data)
195196

196197
def map_distributed_recv(self, expr: DistributedRecv) -> None:
197198
for dim in expr.shape:
198199
if isinstance(dim, Array):
199-
self.nusers[dim] += 1
200+
self.array_to_users[dim].append(expr)
200201
self.rec(dim)
201202

202203
def map_call(self, expr: Call) -> None:
203204
for ary in expr.bindings.values():
204205
if isinstance(ary, Array):
205-
self.nusers[ary] += 1
206+
self.array_to_users[ary].append(expr)
206207
self.rec(ary)
207208

208209
def map_named_call_result(self, expr: NamedCallResult) -> None:
@@ -216,9 +217,21 @@ def get_nusers(outputs: ArrayOrNames) -> Mapping[Array, int]:
216217
For the DAG *outputs*, returns the mapping from each array node to the number of
217218
nodes using its value within the DAG given by *outputs*.
218219
"""
219-
nuser_collector = NUserCollector()
220-
nuser_collector(outputs)
221-
return nuser_collector.nusers
220+
list_of_users_collector = ListOfUsersCollector()
221+
list_of_users_collector(outputs)
222+
return defaultdict(int, {
223+
ary: len(users)
224+
for ary, users in list_of_users_collector.array_to_users.items()})
225+
226+
227+
def get_list_of_users(outputs: ArrayOrNames) -> Mapping[Array, list[ArrayOrNames]]:
228+
"""
229+
For the DAG *outputs*, returns the mapping from each array node to the list of
230+
nodes using its value within the DAG given by *outputs*.
231+
"""
232+
list_of_users_collector = ListOfUsersCollector()
233+
list_of_users_collector(outputs)
234+
return list_of_users_collector.array_to_users
222235

223236

224237
# {{{ is_einsum_similar_to_subscript
@@ -482,7 +495,6 @@ def __init__(
482495
) -> None:
483496
super().__init__(_visited_functions=_visited_functions)
484497

485-
from collections import defaultdict
486498
self.expr_type_counts: dict[type[Any], int] = defaultdict(int)
487499
self.count_duplicates = count_duplicates
488500

@@ -562,7 +574,6 @@ class NodeMultiplicityMapper(CachedWalkMapper[[]]):
562574
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
563575
super().__init__(_visited_functions=_visited_functions)
564576

565-
from collections import defaultdict
566577
self.expr_multiplicity_counts: dict[Array, int] = defaultdict(int)
567578

568579
def get_cache_key(self, expr: ArrayOrNames) -> int:

0 commit comments

Comments
 (0)