|
83 | 83 | .. autofunction:: get_num_node_instances_of |
84 | 84 | .. autofunction:: collect_node_instances_of |
85 | 85 | .. autofunction:: get_num_tags_of_type |
| 86 | +.. autofunction:: collect_materialized_nodes |
86 | 87 | """ |
87 | 88 |
|
88 | 89 |
|
@@ -995,6 +996,42 @@ def get_num_tags_of_type( |
995 | 996 | # }}} |
996 | 997 |
|
997 | 998 |
|
| 999 | +# {{{ collect_materialized_nodes |
| 1000 | + |
| 1001 | +# FIXME: optimize_mapper? |
| 1002 | +class MaterializedNodeCollector(NodeCollector): |
| 1003 | + """Return the nodes in a DAG that are materialized.""" |
| 1004 | + def __init__( |
| 1005 | + self, |
| 1006 | + traverse_functions: bool = True) -> None: |
| 1007 | + def collect_fn(expr: ArrayOrNames | FunctionDefinition) -> bool: |
| 1008 | + # FIXME: This isn't right; need is_materialized() function from |
| 1009 | + # https://github.com/inducer/pytato/pull/623 |
| 1010 | + from pytato.tags import ImplStored |
| 1011 | + return bool(expr.tags_of_type(ImplStored)) |
| 1012 | + |
| 1013 | + super().__init__( |
| 1014 | + collect_fn=collect_fn, |
| 1015 | + traverse_functions=traverse_functions) |
| 1016 | + |
| 1017 | + @override |
| 1018 | + def clone_for_callee(self, function: FunctionDefinition) -> Self: |
| 1019 | + return type(self)( |
| 1020 | + traverse_functions=self.traverse_functions) |
| 1021 | + |
| 1022 | + |
| 1023 | +def collect_materialized_nodes( |
| 1024 | + outputs: ArrayOrNames | FunctionDefinition, |
| 1025 | + node_type: type[ArrayOrNames | FunctionDefinition], |
| 1026 | + traverse_functions: bool = True) -> NodeSet: |
| 1027 | + """Return the nodes in DAG *outputs* that are materialized.""" |
| 1028 | + mnc = MaterializedNodeCollector( |
| 1029 | + traverse_functions=traverse_functions) |
| 1030 | + return mnc(outputs) |
| 1031 | + |
| 1032 | +# }}} |
| 1033 | + |
| 1034 | + |
998 | 1035 | # {{{ PytatoKeyBuilder |
999 | 1036 |
|
1000 | 1037 | class PytatoKeyBuilder(LoopyKeyBuilder): |
|
0 commit comments