Skip to content

Commit dbbcf4b

Browse files
authored
Refactor graph ops, update JAX/Python requirements, improve tests (#138)
* Refactor import statements for clarity and update jax version checks; enhance error callback functionality * Refactor to use compatible import for get_aval in _loop_collect_return.py * Add mapped_aval import and update references in loop_collect_return.py * Refactor _compatible_import.py and _make_jaxpr.py: remove unused imports and functions for clarity * chore: update jax dependency version to >=0.6.0 in pyproject.toml and requirements.txt * refactor(test): remove test_all_exports for clarity and maintainability * refactor(test): remove test_function_imports_availability for clarity * chore: update Python version requirements to >=3.11 and adjust documentation build settings * refactor: remove type hints for filters in treefy_split, nodes, and states functions
1 parent cc1b84d commit dbbcf4b

File tree

8 files changed

+506
-1522
lines changed

8 files changed

+506
-1522
lines changed

.readthedocs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ version: 2
88
build:
99
os: "ubuntu-20.04"
1010
tools:
11-
python: "3.10"
11+
python: "3.13"
1212

1313
# Build documentation in the docs/ directory with Sphinx
1414
sphinx:

brainstate/graph/_context.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,13 @@
1414
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
17-
# ==============================================================================
1817

1918
from __future__ import annotations
2019

2120
import contextlib
2221
import dataclasses
2322
import threading
24-
from typing import (Any, Tuple, List)
23+
from typing import Any
2524

2625
from typing_extensions import Unpack
2726

@@ -45,38 +44,37 @@
4544

4645
@dataclasses.dataclass
4746
class GraphContext(threading.local):
47+
"""Thread-local stacks of active split/merge contexts.
48+
49+
Inheriting from ``threading.local`` ensures each thread has its own
50+
independent context stacks, making nested transforms safe under parallelism.
4851
"""
49-
A context manager for handling complex state updates.
50-
"""
51-
ref_index_stack: List[SplitContext] = dataclasses.field(default_factory=list)
52-
index_ref_stack: List[MergeContext] = dataclasses.field(default_factory=list)
52+
53+
ref_index_stack: list[SplitContext] = dataclasses.field(default_factory=list)
54+
index_ref_stack: list[MergeContext] = dataclasses.field(default_factory=list)
5355

5456

5557
GRAPH_CONTEXT = GraphContext()
5658

5759

5860
@dataclasses.dataclass
5961
class SplitContext:
60-
"""
61-
A context manager for handling graph splitting.
62-
"""
62+
"""Context for splitting graph nodes, tracking shared references."""
63+
6364
ref_index: RefMap[Any, Index]
6465

65-
def treefy_split(self, node: A, *filters: Filter) -> Tuple[GraphDef[A], Unpack[Tuple[NestedDict, ...]]]:
66+
def treefy_split(self, node: A, *filters: Filter) -> tuple[GraphDef[A], Unpack[tuple[NestedDict, ...]]]:
6667
graphdef, statetree = flatten(node, self.ref_index)
6768
state_mappings = _split_state(statetree, filters)
6869
return graphdef, *state_mappings
6970

7071

7172
@contextlib.contextmanager
7273
def split_context():
73-
"""
74-
A context manager for handling graph splitting.
75-
"""
74+
"""Context manager for splitting multiple graph nodes sharing a reference index."""
7675
index_ref: RefMap[Any, Index] = RefMap()
7776
flatten_ctx = SplitContext(index_ref)
7877
GRAPH_CONTEXT.ref_index_stack.append(flatten_ctx)
79-
8078
try:
8179
yield flatten_ctx, index_ref
8280
finally:
@@ -86,32 +84,27 @@ def split_context():
8684

8785
@dataclasses.dataclass
8886
class MergeContext:
89-
"""
90-
A context manager for handling graph merging.
91-
"""
87+
"""Context for merging graph nodes, tracking shared references."""
88+
9289
index_ref: dict[Index, Any]
9390

9491
def treefy_merge(
9592
self,
9693
graphdef: GraphDef[A],
9794
state_mapping: NestedDict,
9895
/,
99-
*state_mappings: NestedDict
96+
*state_mappings: NestedDict,
10097
) -> A:
10198
state_mapping = NestedDict.merge(state_mapping, *state_mappings)
102-
node = unflatten(graphdef, state_mapping, index_ref=self.index_ref)
103-
return node
99+
return unflatten(graphdef, state_mapping, index_ref=self.index_ref)
104100

105101

106102
@contextlib.contextmanager
107103
def merge_context():
108-
"""
109-
A context manager for handling graph merging.
110-
"""
104+
"""Context manager for merging multiple graph nodes sharing a reference index."""
111105
index_ref: dict[Index, Any] = {}
112106
unflatten_ctx = MergeContext(index_ref)
113107
GRAPH_CONTEXT.index_ref_stack.append(unflatten_ctx)
114-
115108
try:
116109
yield unflatten_ctx, dict(unflatten_ctx.index_ref)
117110
finally:

brainstate/graph/_convert.py

Lines changed: 45 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,3 @@
1-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2-
#
3-
# Licensed under the Apache License, Version 2.0 (the "License");
4-
# you may not use this file except in compliance with the License.
5-
# You may obtain a copy of the License at
6-
#
7-
# http://www.apache.org/licenses/LICENSE-2.0
8-
#
9-
# Unless required by applicable law or agreed to in writing, software
10-
# distributed under the License is distributed on an "AS IS" BASIS,
11-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# See the License for the specific language governing permissions and
13-
# limitations under the License.
14-
# ==============================================================================
151
# The file is adapted from the Flax library (https://github.com/google/flax).
162
# The credit should go to the Flax authors.
173
#
@@ -28,9 +14,11 @@
2814
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2915
# See the License for the specific language governing permissions and
3016
# limitations under the License.
31-
# ==============================================================================
3217

33-
from typing import Any, Callable, Iterable, TypeVar, Hashable, Optional, Tuple, List, Dict
18+
from __future__ import annotations
19+
20+
from collections.abc import Callable, Iterable
21+
from typing import Any, TypeVar
3422

3523
import jax
3624

@@ -54,8 +42,7 @@
5442

5543
Node = TypeVar('Node')
5644
Leaf = TypeVar('Leaf')
57-
58-
KeyEntry = TypeVar('KeyEntry', bound=Hashable)
45+
KeyEntry = TypeVar('KeyEntry')
5946
KeyPath = tuple[KeyEntry, ...]
6047
Prefix = Any
6148
RandomState = None
@@ -70,15 +57,15 @@ def _get_rand_state() -> type:
7057

7158

7259
def check_consistent_aliasing(
73-
node: Tuple[Any, ...],
74-
prefix: Tuple[Any, ...],
60+
node: tuple[Any, ...],
61+
prefix: tuple[Any, ...],
7562
/,
7663
*,
77-
node_prefixes: Optional[RefMap[Any, List[Tuple[PathParts, Any]]]] = None,
78-
):
64+
node_prefixes: RefMap[Any, list[tuple[PathParts, Any]]] | None = None,
65+
) -> None:
66+
"""Check that shared nodes have consistent prefixes across all paths."""
7967
node_prefixes = RefMap() if node_prefixes is None else node_prefixes
8068

81-
# collect all paths and prefixes for each node
8269
for path, value in iter_graph(node):
8370
if _is_graph_node(value) or isinstance(value, State):
8471
if isinstance(value, GraphNode):
@@ -90,52 +77,34 @@ def check_consistent_aliasing(
9077
lambda: f'Trying to extract graph node from different trace level, got {value!r}'
9178
)
9279
if value in node_prefixes:
93-
paths_prefixes = node_prefixes[value]
94-
paths_prefixes.append((path, prefix))
80+
node_prefixes[value].append((path, prefix))
9581
else:
9682
node_prefixes[value] = [(path, prefix)]
9783

98-
# check for inconsistent aliasing
9984
node_msgs = []
10085
for node, paths_prefixes in node_prefixes.items():
10186
unique_prefixes = {prefix for _, prefix in paths_prefixes}
10287
if len(unique_prefixes) > 1:
103-
path_prefix_repr = '\n'.join([f' {"/".join(map(str, path)) if path else "<root>"}: {prefix}'
104-
for path, prefix in paths_prefixes])
105-
nodes_msg = f'Node: {type(node)}\n{path_prefix_repr}'
106-
node_msgs.append(nodes_msg)
88+
path_prefix_repr = '\n'.join([
89+
f' {"/".join(map(str, path)) if path else "<root>"}: {prefix}'
90+
for path, prefix in paths_prefixes
91+
])
92+
node_msgs.append(f'Node: {type(node)}\n{path_prefix_repr}')
10793

10894
if node_msgs:
109-
raise ValueError('Inconsistent aliasing detected. The '
110-
'following nodes have different prefixes:\n'
111-
+ '\n'.join(node_msgs))
95+
raise ValueError(
96+
'Inconsistent aliasing detected. The following nodes have different prefixes:\n'
97+
+ '\n'.join(node_msgs)
98+
)
11299

113100

114-
# -----------------------------
115-
# to_tree/from_tree
116-
# -----------------------------
117-
118101
def broadcast_prefix(
119102
prefix_tree: Any,
120103
full_tree: Any,
121-
prefix_is_leaf: Optional[Callable[[Any], bool]] = None,
122-
tree_is_leaf: Optional[Callable[[Any], bool]] = None,
123-
) -> List[Any]:
124-
"""
125-
Broadcasts a prefix tree to a full tree.
126-
127-
Args:
128-
prefix_tree: A prefix tree.
129-
full_tree: A full tree.
130-
prefix_is_leaf: A function that checks if a prefix is a leaf.
131-
tree_is_leaf: A function that checks if a tree is a leaf.
132-
133-
Returns:
134-
A list of prefixes.
135-
"""
136-
# If prefix_tree is not a tree prefix of full_tree, this code can raise a
137-
# ValueError; use prefix_errors to find disagreements and raise more precise
138-
# error messages.
104+
prefix_is_leaf: Callable[[Any], bool] | None = None,
105+
tree_is_leaf: Callable[[Any], bool] | None = None,
106+
) -> list[Any]:
107+
"""Broadcast a prefix tree to match the leaves of a full tree."""
139108
result = []
140109
num_leaves = lambda t: jax.tree_util.tree_structure(t, is_leaf=tree_is_leaf).num_leaves
141110
add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree))
@@ -144,6 +113,12 @@ def broadcast_prefix(
144113

145114

146115
class NodeStates(PyTreeNode):
116+
"""A JAX pytree wrapper that carries both a GraphDef and one or more state mappings.
117+
118+
Used by ``graph_to_tree`` / ``tree_to_graph`` to represent graph nodes as
119+
pure pytrees so that JAX transforms (vmap, jit, etc.) can operate on them.
120+
"""
121+
147122
_graphdef: GraphDef[Any] | None
148123
states: tuple[GraphStateMapping, ...]
149124
metadata: Any = field(pytree_node=False)
@@ -168,19 +143,19 @@ def from_split(
168143
/,
169144
*states: GraphStateMapping,
170145
metadata: Any = None,
171-
):
146+
) -> NodeStates:
172147
return cls(_graphdef=graphdef, states=(state, *states), metadata=metadata)
173148

174149
@classmethod
175-
def from_states(cls, state: GraphStateMapping, *states: GraphStateMapping):
150+
def from_states(cls, state: GraphStateMapping, *states: GraphStateMapping) -> NodeStates:
176151
return cls(_graphdef=None, states=(state, *states), metadata=None)
177152

178153
@classmethod
179-
def from_prefixes(cls, prefixes: Iterable[Any], /, *, metadata: Any = None):
154+
def from_prefixes(cls, prefixes: Iterable[Any], /, *, metadata: Any = None) -> NodeStates:
180155
return cls(_graphdef=None, states=tuple(prefixes), metadata=metadata)
181156

182157

183-
def _default_split_fn(ctx: SplitContext, path: KeyPath, prefix: Prefix, leaf: Leaf):
158+
def _default_split_fn(ctx: SplitContext, path: KeyPath, prefix: Prefix, leaf: Leaf) -> NodeStates:
184159
return NodeStates.from_split(*ctx.treefy_split(leaf))
185160

186161

@@ -192,20 +167,16 @@ def graph_to_tree(
192167
split_fn: Callable[[SplitContext, KeyPath, Prefix, Leaf], Any] = _default_split_fn,
193168
map_non_graph_nodes: bool = False,
194169
check_aliasing: bool = True,
195-
) -> Tuple[PyTree, Dict[KeyPath, SeedOrKey]]:
196-
"""
197-
Convert a tree of pytree objects to a tree of TreeNode objects.
198-
"""
170+
) -> tuple[PyTree, dict[KeyPath, SeedOrKey]]:
171+
"""Convert a pytree that may contain graph nodes into a pure pytree of NodeStates."""
199172
leaf_prefixes = broadcast_prefix(prefix, may_have_graph_nodes, prefix_is_leaf=lambda x: x is None)
200173
leaf_keys, treedef = jax.tree_util.tree_flatten_with_path(may_have_graph_nodes)
201174

202-
# Check that the number of keys and prefixes match
203175
assert len(leaf_keys) == len(leaf_prefixes)
204176

205-
# Split the tree
206177
with split_context() as (ctx, index_ref):
207178
leaves_out = []
208-
node_prefixes = RefMap[Any, list[tuple[PathParts, Any]]]()
179+
node_prefixes: RefMap[Any, list[tuple[PathParts, Any]]] = RefMap()
209180
for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes):
210181
if _is_graph_node(leaf):
211182
if check_aliasing:
@@ -215,15 +186,16 @@ def graph_to_tree(
215186
if map_non_graph_nodes:
216187
leaf = split_fn(ctx, keypath, leaf_prefix, leaf)
217188
leaves_out.append(leaf)
218-
pass
219189

220-
find_states = states(index_ref._mapping)
190+
# Build a dict mirroring RefMap's content via the public API, then extract
191+
# State objects from it. We must not access the private ._mapping attribute.
192+
public_map = {id(k): (k, v) for k, v in index_ref.items()}
193+
find_states = states(public_map)
221194
pytree_out = jax.tree.unflatten(treedef, leaves_out)
222195
return pytree_out, find_states
223196

224197

225-
def _is_tree_node(x):
226-
"""Check if x is a TreeNode."""
198+
def _is_tree_node(x: Any) -> bool:
227199
return isinstance(x, NodeStates)
228200

229201

@@ -243,20 +215,7 @@ def tree_to_graph(
243215
is_leaf: Callable[[Leaf], bool] = _is_tree_node,
244216
map_non_graph_nodes: bool = False,
245217
) -> Any:
246-
"""
247-
Convert a tree of TreeNode objects to a tree of pytree objects.
248-
249-
Args:
250-
tree: A tree of TreeNode objects.
251-
prefix: A tree of prefixes.
252-
merge_fn: A function that merges a TreeNode object.
253-
is_node_leaf: A function that checks if a leaf is a TreeNode.
254-
is_leaf: A function that checks if a leaf is a TreeNode.
255-
map_non_graph_nodes: A boolean indicating whether to map non-graph nodes.
256-
257-
Returns:
258-
A tree of pytree objects.
259-
"""
218+
"""Convert a pytree of NodeStates back into graph nodes."""
260219
_prefix_is_leaf = lambda x: x is None or is_leaf(x)
261220
leaf_prefixes = broadcast_prefix(prefix, tree, prefix_is_leaf=_prefix_is_leaf, tree_is_leaf=is_leaf)
262221
leaf_keys, treedef = jax.tree_util.tree_flatten_with_path(tree, is_leaf=is_leaf)
@@ -266,13 +225,10 @@ def tree_to_graph(
266225
leaves_out = []
267226
for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes):
268227
if is_node_leaf(leaf):
269-
leaf_out = merge_fn(ctx, keypath, leaf_prefix, leaf)
270-
leaves_out.append(leaf_out)
228+
leaves_out.append(merge_fn(ctx, keypath, leaf_prefix, leaf))
271229
else:
272230
if map_non_graph_nodes:
273231
leaf = merge_fn(ctx, keypath, leaf_prefix, leaf)
274232
leaves_out.append(leaf)
275233

276-
find_states = states(index_ref)
277-
pytree_out = jax.tree.unflatten(treedef, leaves_out)
278-
return pytree_out
234+
return jax.tree.unflatten(treedef, leaves_out)

0 commit comments

Comments
 (0)