|
8 | 8 | from __future__ import annotations |
9 | 9 |
|
10 | 10 | import inspect |
11 | | -from collections import defaultdict |
| 11 | +from collections import Counter, defaultdict |
12 | 12 | from collections.abc import Callable, Iterable |
13 | | -from dataclasses import dataclass |
| 13 | +from copy import deepcopy |
| 14 | +from dataclasses import dataclass, replace |
14 | 15 | from types import NoneType |
15 | 16 | from typing import Any, Generic, cast, get_origin, get_type_hints, overload |
16 | 17 |
|
|
20 | 21 | from pydantic_graph._utils import UNSET, Unset |
21 | 22 | from pydantic_graph.beta.decision import Decision, DecisionBranch, DecisionBranchBuilder |
22 | 23 | from pydantic_graph.beta.graph import Graph |
23 | | -from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID |
| 24 | +from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID, is_placeholder_node_id |
24 | 25 | from pydantic_graph.beta.join import Join, JoinNode, ReducerFunction |
25 | 26 | from pydantic_graph.beta.node import ( |
26 | 27 | EndNode, |
|
59 | 60 | T = TypeVar('T', infer_variance=True) |
60 | 61 |
|
61 | 62 |
|
| 63 | +# TODO: Make this kw-only and drop init=False..? |
62 | 64 | @dataclass(init=False) |
63 | 65 | class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]): |
64 | 66 | """A builder for constructing executable graph definitions. |
@@ -440,7 +442,9 @@ def match( |
440 | 442 | node_id = NodeID(self._get_new_decision_id()) |
441 | 443 | decision = Decision[StateT, DepsT, Never](node_id, branches=[], note=None) |
442 | 444 | new_path_builder = PathBuilder[StateT, DepsT, SourceT](working_items=[]) |
443 | | - return DecisionBranchBuilder(decision=decision, source=source, matches=matches, path_builder=new_path_builder) |
| 445 | + return DecisionBranchBuilder( |
| 446 | + _decision=decision, _source=source, _matches=matches, _path_builder=new_path_builder |
| 447 | + ) |
444 | 448 |
|
445 | 449 | def match_node( |
446 | 450 | self, |
@@ -663,8 +667,11 @@ def build(self) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]: |
663 | 667 | # TODO(P2): Allow the user to specify the parent forks; only infer them if _not_ specified |
664 | 668 | # TODO(P2): Verify that any user-specified parent forks are _actually_ valid parent forks, and if not, generate a helpful error message |
665 | 669 | # TODO(P3): Consider doing a deepcopy here to prevent modifications to the underlying nodes and edges |
666 | | - nodes = self._nodes |
667 | | - edges_by_source = self._edges_by_source |
| 670 | + |
| 671 | + nodes = deepcopy(self._nodes) |
| 672 | + edges_by_source = deepcopy(self._edges_by_source) |
| 673 | + |
| 674 | + nodes, edges_by_source = _replace_placeholder_node_ids(nodes, edges_by_source) |
668 | 675 | nodes, edges_by_source = _flatten_paths(nodes, edges_by_source) |
669 | 676 | nodes, edges_by_source = _normalize_forks(nodes, edges_by_source) |
670 | 677 | parent_forks = _collect_dominating_forks(nodes, edges_by_source) |
@@ -843,3 +850,81 @@ def _handle_path(path: Path, last_source_id: NodeID): |
843 | 850 | dominating_forks[join_id] = dominating_fork |
844 | 851 |
|
845 | 852 | return dominating_forks |
| 853 | + |
| 854 | + |
| 855 | +def _replace_placeholder_node_ids(nodes: dict[NodeID, AnyNode], edges_by_source: dict[NodeID, list[Path]]): |
| 856 | + node_id_remapping = _build_placeholder_node_id_remapping(nodes) |
| 857 | + replaced_nodes = { |
| 858 | + node_id_remapping.get(name, name): _update_node_with_id_remapping(node, node_id_remapping) |
| 859 | + for name, node in nodes.items() |
| 860 | + } |
| 861 | + replaced_edges_by_source = { |
| 862 | + node_id_remapping.get(source, source): [_update_path_with_id_remapping(p, node_id_remapping) for p in paths] |
| 863 | + for source, paths in edges_by_source.items() |
| 864 | + } |
| 865 | + return replaced_nodes, replaced_edges_by_source |
| 866 | + |
| 867 | + |
| 868 | +def _build_placeholder_node_id_remapping(nodes: dict[NodeID, AnyNode]) -> dict[NodeID, NodeID]: |
| 869 | + """The determinism of the generated remapping here is dependent on the determinism of the ordering of the `nodes` dict. |
| 870 | +
|
| 871 | + Note: If we want to generate more interesting names, we could try to make use of information about the edges |
| 872 | + into/out of the relevant nodes. I'm not sure if there's a good use case for that though so I didn't bother for now. |
| 873 | + """ |
| 874 | + counter = Counter[str]() |
| 875 | + remapping: dict[NodeID, NodeID] = {} |
| 876 | + for node_id, node in nodes.items(): |
| 877 | + if not is_placeholder_node_id(node_id): |
| 878 | + continue |
| 879 | + label = type(node).__name__.lower() |
| 880 | + counter[label] = count = counter[label] + 1 |
| 881 | + remapping[node_id] = NodeID(f'{label}_{count}') |
| 882 | + return remapping |
| 883 | + |
| 884 | + |
| 885 | +def _update_node_with_id_remapping(node: AnyNode, node_id_remapping: dict[NodeID, NodeID]) -> AnyNode: |
| 886 | + if isinstance(node, Step): |
| 887 | + # Even though steps are frozen, we use object.__setattr__ to overrule that and change the id value to make it |
| 888 | + # work with NodeStep. |
| 889 | + # Note: we have already deepcopied the inputs to this function so it should be okay to make mutations, |
| 890 | + # this could change if we change the code surrounding the code paths leading to this function call though. |
| 891 | + object.__setattr__(node, 'id', node_id_remapping.get(node.id, node.id)) |
| 892 | + elif isinstance(node, Join): |
| 893 | + node = replace(node, id=JoinID(node_id_remapping.get(node.id, node.id))) |
| 894 | + elif isinstance(node, Fork): |
| 895 | + node = replace(node, id=ForkID(node_id_remapping.get(node.id, node.id))) |
| 896 | + elif isinstance(node, Decision): |
| 897 | + node = replace( |
| 898 | + node, |
| 899 | + id=node_id_remapping.get(node.id, node.id), |
| 900 | + branches=[ |
| 901 | + replace(branch, path=_update_path_with_id_remapping(branch.path, node_id_remapping)) |
| 902 | + for branch in node.branches |
| 903 | + ], |
| 904 | + ) |
| 905 | + return node |
| 906 | + |
| 907 | + |
| 908 | +def _update_path_with_id_remapping(path: Path, node_id_remapping: dict[NodeID, NodeID]) -> Path: |
| 909 | + path = replace(path) # prevent mutating the input; not technically necessary but could make debugging easier later |
| 910 | + for i, item in enumerate(path.items): |
| 911 | + if isinstance(item, MapMarker): |
| 912 | + downstream_join_id = item.downstream_join_id |
| 913 | + if downstream_join_id is not None: |
| 914 | + downstream_join_id = JoinID(node_id_remapping.get(downstream_join_id, downstream_join_id)) |
| 915 | + path.items[i] = replace( |
| 916 | + item, |
| 917 | + fork_id=ForkID(node_id_remapping.get(item.fork_id, item.fork_id)), |
| 918 | + downstream_join_id=downstream_join_id, |
| 919 | + ) |
| 920 | + elif isinstance(item, BroadcastMarker): |
| 921 | + path.items[i] = replace( |
| 922 | + item, |
| 923 | + fork_id=ForkID(node_id_remapping.get(item.fork_id, item.fork_id)), |
| 924 | + paths=[_update_path_with_id_remapping(p, node_id_remapping) for p in item.paths], |
| 925 | + ) |
| 926 | + elif isinstance(item, DestinationMarker): |
| 927 | + path.items[i] = replace( |
| 928 | + item, destination_id=node_id_remapping.get(item.destination_id, item.destination_id) |
| 929 | + ) |
| 930 | + return path |
0 commit comments