1- # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2- #
3- # Licensed under the Apache License, Version 2.0 (the "License");
4- # you may not use this file except in compliance with the License.
5- # You may obtain a copy of the License at
6- #
7- # http://www.apache.org/licenses/LICENSE-2.0
8- #
9- # Unless required by applicable law or agreed to in writing, software
10- # distributed under the License is distributed on an "AS IS" BASIS,
11- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12- # See the License for the specific language governing permissions and
13- # limitations under the License.
14- # ==============================================================================
151# The file is adapted from the Flax library (https://github.com/google/flax).
162# The credit should go to the Flax authors.
173#
2814# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2915# See the License for the specific language governing permissions and
3016# limitations under the License.
31- # ==============================================================================
3217
33- from typing import Any , Callable , Iterable , TypeVar , Hashable , Optional , Tuple , List , Dict
18+ from __future__ import annotations
19+
20+ from collections .abc import Callable , Iterable
21+ from typing import Any , TypeVar
3422
3523import jax
3624
5442
5543Node = TypeVar ('Node' )
5644Leaf = TypeVar ('Leaf' )
57-
58- KeyEntry = TypeVar ('KeyEntry' , bound = Hashable )
45+ KeyEntry = TypeVar ('KeyEntry' )
5946KeyPath = tuple [KeyEntry , ...]
6047Prefix = Any
6148RandomState = None
@@ -70,15 +57,15 @@ def _get_rand_state() -> type:
7057
7158
7259def check_consistent_aliasing (
73- node : Tuple [Any , ...],
74- prefix : Tuple [Any , ...],
60+ node : tuple [Any , ...],
61+ prefix : tuple [Any , ...],
7562 / ,
7663 * ,
77- node_prefixes : Optional [RefMap [Any , List [Tuple [PathParts , Any ]]]] = None ,
78- ):
64+ node_prefixes : RefMap [Any , list [tuple [PathParts , Any ]]] | None = None ,
65+ ) -> None :
66+ """Check that shared nodes have consistent prefixes across all paths."""
7967 node_prefixes = RefMap () if node_prefixes is None else node_prefixes
8068
81- # collect all paths and prefixes for each node
8269 for path , value in iter_graph (node ):
8370 if _is_graph_node (value ) or isinstance (value , State ):
8471 if isinstance (value , GraphNode ):
@@ -90,52 +77,34 @@ def check_consistent_aliasing(
9077 lambda : f'Trying to extract graph node from different trace level, got { value !r} '
9178 )
9279 if value in node_prefixes :
93- paths_prefixes = node_prefixes [value ]
94- paths_prefixes .append ((path , prefix ))
80+ node_prefixes [value ].append ((path , prefix ))
9581 else :
9682 node_prefixes [value ] = [(path , prefix )]
9783
98- # check for inconsistent aliasing
9984 node_msgs = []
10085 for node , paths_prefixes in node_prefixes .items ():
10186 unique_prefixes = {prefix for _ , prefix in paths_prefixes }
10287 if len (unique_prefixes ) > 1 :
103- path_prefix_repr = '\n ' .join ([f' { "/" .join (map (str , path )) if path else "<root>" } : { prefix } '
104- for path , prefix in paths_prefixes ])
105- nodes_msg = f'Node: { type (node )} \n { path_prefix_repr } '
106- node_msgs .append (nodes_msg )
88+ path_prefix_repr = '\n ' .join ([
89+ f' { "/" .join (map (str , path )) if path else "<root>" } : { prefix } '
90+ for path , prefix in paths_prefixes
91+ ])
92+ node_msgs .append (f'Node: { type (node )} \n { path_prefix_repr } ' )
10793
10894 if node_msgs :
109- raise ValueError ('Inconsistent aliasing detected. The '
110- 'following nodes have different prefixes:\n '
111- + '\n ' .join (node_msgs ))
95+ raise ValueError (
96+ 'Inconsistent aliasing detected. The following nodes have different prefixes:\n '
97+ + '\n ' .join (node_msgs )
98+ )
11299
113100
114- # -----------------------------
115- # to_tree/from_tree
116- # -----------------------------
117-
118101def broadcast_prefix (
119102 prefix_tree : Any ,
120103 full_tree : Any ,
121- prefix_is_leaf : Optional [Callable [[Any ], bool ]] = None ,
122- tree_is_leaf : Optional [Callable [[Any ], bool ]] = None ,
123- ) -> List [Any ]:
124- """
125- Broadcasts a prefix tree to a full tree.
126-
127- Args:
128- prefix_tree: A prefix tree.
129- full_tree: A full tree.
130- prefix_is_leaf: A function that checks if a prefix is a leaf.
131- tree_is_leaf: A function that checks if a tree is a leaf.
132-
133- Returns:
134- A list of prefixes.
135- """
136- # If prefix_tree is not a tree prefix of full_tree, this code can raise a
137- # ValueError; use prefix_errors to find disagreements and raise more precise
138- # error messages.
104+ prefix_is_leaf : Callable [[Any ], bool ] | None = None ,
105+ tree_is_leaf : Callable [[Any ], bool ] | None = None ,
106+ ) -> list [Any ]:
107+ """Broadcast a prefix tree to match the leaves of a full tree."""
139108 result = []
140109 num_leaves = lambda t : jax .tree_util .tree_structure (t , is_leaf = tree_is_leaf ).num_leaves
141110 add_leaves = lambda x , subtree : result .extend ([x ] * num_leaves (subtree ))
@@ -144,6 +113,12 @@ def broadcast_prefix(
144113
145114
146115class NodeStates (PyTreeNode ):
116+ """A JAX pytree wrapper that carries both a GraphDef and one or more state mappings.
117+
118+ Used by ``graph_to_tree`` / ``tree_to_graph`` to represent graph nodes as
119+ pure pytrees so that JAX transforms (vmap, jit, etc.) can operate on them.
120+ """
121+
147122 _graphdef : GraphDef [Any ] | None
148123 states : tuple [GraphStateMapping , ...]
149124 metadata : Any = field (pytree_node = False )
@@ -168,19 +143,19 @@ def from_split(
168143 / ,
169144 * states : GraphStateMapping ,
170145 metadata : Any = None ,
171- ):
146+ ) -> NodeStates :
172147 return cls (_graphdef = graphdef , states = (state , * states ), metadata = metadata )
173148
174149 @classmethod
175- def from_states (cls , state : GraphStateMapping , * states : GraphStateMapping ):
150+ def from_states (cls , state : GraphStateMapping , * states : GraphStateMapping ) -> NodeStates :
176151 return cls (_graphdef = None , states = (state , * states ), metadata = None )
177152
178153 @classmethod
179- def from_prefixes (cls , prefixes : Iterable [Any ], / , * , metadata : Any = None ):
154+ def from_prefixes (cls , prefixes : Iterable [Any ], / , * , metadata : Any = None ) -> NodeStates :
180155 return cls (_graphdef = None , states = tuple (prefixes ), metadata = metadata )
181156
182157
183- def _default_split_fn (ctx : SplitContext , path : KeyPath , prefix : Prefix , leaf : Leaf ):
158+ def _default_split_fn (ctx : SplitContext , path : KeyPath , prefix : Prefix , leaf : Leaf ) -> NodeStates :
184159 return NodeStates .from_split (* ctx .treefy_split (leaf ))
185160
186161
@@ -192,20 +167,16 @@ def graph_to_tree(
192167 split_fn : Callable [[SplitContext , KeyPath , Prefix , Leaf ], Any ] = _default_split_fn ,
193168 map_non_graph_nodes : bool = False ,
194169 check_aliasing : bool = True ,
195- ) -> Tuple [PyTree , Dict [KeyPath , SeedOrKey ]]:
196- """
197- Convert a tree of pytree objects to a tree of TreeNode objects.
198- """
170+ ) -> tuple [PyTree , dict [KeyPath , SeedOrKey ]]:
171+ """Convert a pytree that may contain graph nodes into a pure pytree of NodeStates."""
199172 leaf_prefixes = broadcast_prefix (prefix , may_have_graph_nodes , prefix_is_leaf = lambda x : x is None )
200173 leaf_keys , treedef = jax .tree_util .tree_flatten_with_path (may_have_graph_nodes )
201174
202- # Check that the number of keys and prefixes match
203175 assert len (leaf_keys ) == len (leaf_prefixes )
204176
205- # Split the tree
206177 with split_context () as (ctx , index_ref ):
207178 leaves_out = []
208- node_prefixes = RefMap [Any , list [tuple [PathParts , Any ]]]()
179+ node_prefixes : RefMap [Any , list [tuple [PathParts , Any ]]] = RefMap ()
209180 for (keypath , leaf ), leaf_prefix in zip (leaf_keys , leaf_prefixes ):
210181 if _is_graph_node (leaf ):
211182 if check_aliasing :
@@ -215,15 +186,16 @@ def graph_to_tree(
215186 if map_non_graph_nodes :
216187 leaf = split_fn (ctx , keypath , leaf_prefix , leaf )
217188 leaves_out .append (leaf )
218- pass
219189
220- find_states = states (index_ref ._mapping )
190+ # Build a dict mirroring RefMap's content via the public API, then extract
191+ # State objects from it. We must not access the private ._mapping attribute.
192+ public_map = {id (k ): (k , v ) for k , v in index_ref .items ()}
193+ find_states = states (public_map )
221194 pytree_out = jax .tree .unflatten (treedef , leaves_out )
222195 return pytree_out , find_states
223196
224197
225- def _is_tree_node (x ):
226- """Check if x is a TreeNode."""
198+ def _is_tree_node (x : Any ) -> bool :
227199 return isinstance (x , NodeStates )
228200
229201
@@ -243,20 +215,7 @@ def tree_to_graph(
243215 is_leaf : Callable [[Leaf ], bool ] = _is_tree_node ,
244216 map_non_graph_nodes : bool = False ,
245217) -> Any :
246- """
247- Convert a tree of TreeNode objects to a tree of pytree objects.
248-
249- Args:
250- tree: A tree of TreeNode objects.
251- prefix: A tree of prefixes.
252- merge_fn: A function that merges a TreeNode object.
253- is_node_leaf: A function that checks if a leaf is a TreeNode.
254- is_leaf: A function that checks if a leaf is a TreeNode.
255- map_non_graph_nodes: A boolean indicating whether to map non-graph nodes.
256-
257- Returns:
258- A tree of pytree objects.
259- """
218+ """Convert a pytree of NodeStates back into graph nodes."""
260219 _prefix_is_leaf = lambda x : x is None or is_leaf (x )
261220 leaf_prefixes = broadcast_prefix (prefix , tree , prefix_is_leaf = _prefix_is_leaf , tree_is_leaf = is_leaf )
262221 leaf_keys , treedef = jax .tree_util .tree_flatten_with_path (tree , is_leaf = is_leaf )
@@ -266,13 +225,10 @@ def tree_to_graph(
266225 leaves_out = []
267226 for (keypath , leaf ), leaf_prefix in zip (leaf_keys , leaf_prefixes ):
268227 if is_node_leaf (leaf ):
269- leaf_out = merge_fn (ctx , keypath , leaf_prefix , leaf )
270- leaves_out .append (leaf_out )
228+ leaves_out .append (merge_fn (ctx , keypath , leaf_prefix , leaf ))
271229 else :
272230 if map_non_graph_nodes :
273231 leaf = merge_fn (ctx , keypath , leaf_prefix , leaf )
274232 leaves_out .append (leaf )
275233
276- find_states = states (index_ref )
277- pytree_out = jax .tree .unflatten (treedef , leaves_out )
278- return pytree_out
234+ return jax .tree .unflatten (treedef , leaves_out )
0 commit comments