Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 0 additions & 144 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -8147,150 +8147,6 @@
"lineCount": 1
}
},
{
"code": "reportUnannotatedClassAttribute",
"range": {
"startColumn": 13,
"endColumn": 37,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 11,
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 4,
"endColumn": 29,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 52,
"endColumn": 67,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 52,
"endColumn": 67,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 52,
"endColumn": 63,
"lineCount": 1
}
},
{
"code": "reportUnannotatedClassAttribute",
"range": {
"startColumn": 13,
"endColumn": 24,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 18,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 24,
"lineCount": 1
}
},
{
"code": "reportUnannotatedClassAttribute",
"range": {
"startColumn": 4,
"endColumn": 19,
"lineCount": 1
}
},
{
"code": "reportUnannotatedClassAttribute",
"range": {
"startColumn": 4,
"endColumn": 20,
"lineCount": 1
}
},
{
"code": "reportUnannotatedClassAttribute",
"range": {
"startColumn": 4,
"endColumn": 18,
"lineCount": 1
}
},
{
"code": "reportUnusedParameter",
"range": {
"startColumn": 30,
"endColumn": 34,
"lineCount": 1
}
},
{
"code": "reportUnannotatedClassAttribute",
"range": {
"startColumn": 4,
"endColumn": 19,
"lineCount": 1
}
},
{
"code": "reportUnannotatedClassAttribute",
"range": {
"startColumn": 4,
"endColumn": 33,
"lineCount": 1
}
},
{
"code": "reportUnannotatedClassAttribute",
"range": {
"startColumn": 4,
"endColumn": 37,
"lineCount": 1
}
},
{
"code": "reportUnusedParameter",
"range": {
"startColumn": 39,
"endColumn": 43,
"lineCount": 1
}
},
{
"code": "reportUnusedParameter",
"range": {
"startColumn": 36,
"endColumn": 40,
"lineCount": 1
}
},
{
"code": "reportIncompatibleMethodOverride",
"range": {
Expand Down
61 changes: 36 additions & 25 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
THE SOFTWARE.
"""

from collections import defaultdict
from typing import TYPE_CHECKING, Any, overload

from orderedsets import FrozenOrderedSet
Expand Down Expand Up @@ -63,6 +64,7 @@
.. currentmodule:: pytato.analysis

.. autofunction:: get_nusers
.. autofunction:: get_list_of_users

.. autofunction:: is_einsum_similar_to_subscript

Expand All @@ -82,12 +84,12 @@
"""


# {{{ NUserCollector
# {{{ ListOfUsersCollector

class NUserCollector(Mapper[None, None, []]):
class ListOfUsersCollector(Mapper[None, None, []]):
"""
A :class:`pytato.transform.CachedWalkMapper` that records the number of
times an array expression is a direct dependency of other nodes.
A :class:`pytato.transform.CachedWalkMapper` that records, for each array
expression, the nodes that directly depend on it.

.. note::

Expand All @@ -97,10 +99,9 @@ class NUserCollector(Mapper[None, None, []]):
send's data.
"""
def __init__(self) -> None:
from collections import defaultdict
super().__init__()
self._visited_ids: set[int] = set()
self.nusers: dict[Array, int] = defaultdict(lambda: 0)
self.array_to_users: dict[Array, list[ArrayOrNames]] = defaultdict(list)

def rec(self, expr: ArrayOrNames) -> None:
# See CachedWalkMapper.rec on why we chose id(x) as the cache key.
Expand All @@ -113,38 +114,38 @@ def rec(self, expr: ArrayOrNames) -> None:

def map_index_lambda(self, expr: IndexLambda) -> None:
for ary in expr.bindings.values():
self.nusers[ary] += 1
self.array_to_users[ary].append(expr)
self.rec(ary)

for dim in expr.shape:
if isinstance(dim, Array):
self.nusers[dim] += 1
self.array_to_users[dim].append(expr)
self.rec(dim)

def map_stack(self, expr: Stack) -> None:
for ary in expr.arrays:
self.nusers[ary] += 1
self.array_to_users[ary].append(expr)
self.rec(ary)

def map_concatenate(self, expr: Concatenate) -> None:
for ary in expr.arrays:
self.nusers[ary] += 1
self.array_to_users[ary].append(expr)
self.rec(ary)

def map_loopy_call(self, expr: LoopyCall) -> None:
for ary in expr.bindings.values():
if isinstance(ary, Array):
self.nusers[ary] += 1
self.array_to_users[ary].append(expr)
self.rec(ary)

def map_einsum(self, expr: Einsum) -> None:
for ary in expr.args:
self.nusers[ary] += 1
self.array_to_users[ary].append(expr)
self.rec(ary)

for dim in expr.shape:
if isinstance(dim, Array):
self.nusers[dim] += 1
self.array_to_users[dim].append(expr)
self.rec(dim)

def map_named_array(self, expr: NamedArray) -> None:
Expand All @@ -155,20 +156,20 @@ def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None:
self.rec(child)

def _map_index_base(self, expr: IndexBase) -> None:
self.nusers[expr.array] += 1
self.array_to_users[expr.array].append(expr)
self.rec(expr.array)

for idx in expr.indices:
if isinstance(idx, Array):
self.nusers[idx] += 1
self.array_to_users[idx].append(expr)
self.rec(idx)

map_basic_index = _map_index_base
map_contiguous_advanced_index = _map_index_base
map_non_contiguous_advanced_index = _map_index_base

def _map_index_remapping_base(self, expr: IndexRemappingBase) -> None:
self.nusers[expr.array] += 1
self.array_to_users[expr.array].append(expr)
self.rec(expr.array)

map_roll = _map_index_remapping_base
Expand All @@ -178,7 +179,7 @@ def _map_index_remapping_base(self, expr: IndexRemappingBase) -> None:
def _map_input_base(self, expr: InputArgumentBase) -> None:
for dim in expr.shape:
if isinstance(dim, Array):
self.nusers[dim] += 1
self.array_to_users[dim].append(expr)
self.rec(dim)

map_placeholder = _map_input_base
Expand All @@ -189,20 +190,20 @@ def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder
) -> None:
# Note: We do not consider 'expr.send.data' as a predecessor of *expr*,
# as there is no dataflow from *expr.send.data* to *expr*
self.nusers[expr.passthrough_data] += 1
self.array_to_users[expr.passthrough_data].append(expr)
self.rec(expr.passthrough_data)
self.rec(expr.send.data)

def map_distributed_recv(self, expr: DistributedRecv) -> None:
for dim in expr.shape:
if isinstance(dim, Array):
self.nusers[dim] += 1
self.array_to_users[dim].append(expr)
self.rec(dim)

def map_call(self, expr: Call) -> None:
for ary in expr.bindings.values():
if isinstance(ary, Array):
self.nusers[ary] += 1
self.array_to_users[ary].append(expr)
self.rec(ary)

def map_named_call_result(self, expr: NamedCallResult) -> None:
Expand All @@ -216,9 +217,21 @@ def get_nusers(outputs: ArrayOrNames) -> Mapping[Array, int]:
For the DAG *outputs*, returns the mapping from each array node to the number of
nodes using its value within the DAG given by *outputs*.
"""
nuser_collector = NUserCollector()
nuser_collector(outputs)
return nuser_collector.nusers
list_of_users_collector = ListOfUsersCollector()
list_of_users_collector(outputs)
return defaultdict(int, {
ary: len(users)
for ary, users in list_of_users_collector.array_to_users.items()})


def get_list_of_users(outputs: ArrayOrNames) -> Mapping[Array, list[ArrayOrNames]]:
"""
For the DAG *outputs*, returns the mapping from each array node to the list of
nodes using its value within the DAG given by *outputs*.
"""
list_of_users_collector = ListOfUsersCollector()
list_of_users_collector(outputs)
return list_of_users_collector.array_to_users


# {{{ is_einsum_similar_to_subscript
Expand Down Expand Up @@ -482,7 +495,6 @@ def __init__(
) -> None:
super().__init__(_visited_functions=_visited_functions)

from collections import defaultdict
self.expr_type_counts: dict[type[Any], int] = defaultdict(int)
self.count_duplicates = count_duplicates

Expand Down Expand Up @@ -562,7 +574,6 @@ class NodeMultiplicityMapper(CachedWalkMapper[[]]):
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
super().__init__(_visited_functions=_visited_functions)

from collections import defaultdict
self.expr_multiplicity_counts: dict[Array, int] = defaultdict(int)

def get_cache_key(self, expr: ArrayOrNames) -> int:
Expand Down
Loading
Loading