Skip to content

Commit 0c70b89

Browse files
committed
add collect_materialized_nodes
1 parent 8ba052c commit 0c70b89

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

pytato/analysis/__init__.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
.. autofunction:: get_num_node_instances_of
8484
.. autofunction:: collect_node_instances_of
8585
.. autofunction:: get_num_tags_of_type
86+
.. autofunction:: collect_materialized_nodes
8687
"""
8788

8889

@@ -995,6 +996,42 @@ def get_num_tags_of_type(
995996
# }}}
996997

997998

999+
# {{{ collect_materialized_nodes
1000+
1001+
# FIXME: optimize_mapper?
1002+
class MaterializedNodeCollector(NodeCollector):
1003+
"""Return the nodes in a DAG that are materialized."""
1004+
def __init__(
1005+
self,
1006+
traverse_functions: bool = True) -> None:
1007+
def collect_fn(expr: ArrayOrNames | FunctionDefinition) -> bool:
1008+
# FIXME: This isn't right; need is_materialized() function from
1009+
# https://github.com/inducer/pytato/pull/623
1010+
from pytato.tags import ImplStored
1011+
return bool(expr.tags_of_type(ImplStored))
1012+
1013+
super().__init__(
1014+
collect_fn=collect_fn,
1015+
traverse_functions=traverse_functions)
1016+
1017+
@override
1018+
def clone_for_callee(self, function: FunctionDefinition) -> Self:
1019+
return type(self)(
1020+
traverse_functions=self.traverse_functions)
1021+
1022+
1023+
def collect_materialized_nodes(
1024+
outputs: ArrayOrNames | FunctionDefinition,
1025+
node_type: type[ArrayOrNames | FunctionDefinition],
1026+
traverse_functions: bool = True) -> NodeSet:
1027+
"""Return the nodes in DAG *outputs* that are materialized."""
1028+
mnc = MaterializedNodeCollector(
1029+
traverse_functions=traverse_functions)
1030+
return mnc(outputs)
1031+
1032+
# }}}
1033+
1034+
9981035
# {{{ PytatoKeyBuilder
9991036

10001037
class PytatoKeyBuilder(LoopyKeyBuilder):

test/test_pytato.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -993,6 +993,12 @@ def f(a):
993993
== frozenset([y, b]))
994994

995995

996+
def test_collect_materialized_nodes():
997+
# FIXME: Add tests after fixing collect_materialized_nodes to use
998+
# is_materialized() function from https://github.com/inducer/pytato/pull/623
999+
pytest.fail("Not implemented yet.")
1000+
1001+
9961002
def test_rec_get_user_nodes():
9971003
x1 = pt.make_placeholder("x1", shape=(10, 4))
9981004
x2 = pt.make_placeholder("x2", shape=(10, 4))

0 commit comments

Comments
 (0)