Skip to content

Commit cc1b84d

Browse files
authored
Simplify JAX compat: use jax.make_jaxpr and aval helpers (#137)
* Refactor import statements for clarity and update jax version checks; enhance error callback functionality * Refactor to use compatible import for get_aval in _loop_collect_return.py * Add mapped_aval import and update references in loop_collect_return.py * Refactor _compatible_import.py and _make_jaxpr.py: remove unused imports and functions for clarity * chore: update jax dependency version to >=0.6.0 in pyproject.toml and requirements.txt * refactor(test): remove test_all_exports for clarity and maintainability * refactor(test): remove test_function_imports_availability for clarity
1 parent be6d356 commit cc1b84d

File tree

8 files changed

+38
-287
lines changed

8 files changed

+38
-287
lines changed

brainstate/_compatible_import.py

Lines changed: 14 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,8 @@
4242
>>> # These imports work across different JAX versions
4343
"""
4444

45-
from contextlib import contextmanager
4645
from functools import partial
47-
from typing import Iterable, Hashable, TypeVar, Callable
46+
from typing import Iterable, TypeVar, Callable
4847

4948
import jax
5049
from jax.core import Tracer
@@ -76,9 +75,9 @@
7675
# others
7776
'is_jit_primitive',
7877
'Primitive',
79-
'extend_axis_env_nd',
8078
'jaxpr_as_fun',
8179
'get_aval',
80+
'mapped_aval',
8281
'to_concrete_aval',
8382
'Device',
8483
'wrap_init',
@@ -105,12 +104,15 @@ def get_aval(x):
105104
else:
106105
from jax import Device
107106

107+
if jax.__version_info__ < (0, 8, 2):
108+
from jax.core import mapped_aval
109+
else:
110+
from jax.extend.core import mapped_aval
108111
if jax.__version_info__ < (0, 8, 0):
109112
from jax.lib.xla_bridge import get_backend
110113
else:
111114
from jax.extend.backend import get_backend
112115

113-
114116
if jax.__version_info__ < (0, 7, 1):
115117
from jax.interpreters.batching import make_iota, to_elt, BatchTracer, BatchTrace
116118
else:
@@ -119,40 +121,15 @@ def get_aval(x):
119121
from jax.core import DropVar
120122

121123
if jax.__version_info__ < (0, 4, 38):
122-
from jax.core import ClosedJaxpr, extend_axis_env_nd, Primitive, jaxpr_as_fun
123-
from jax.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
124+
from jax.core import (
125+
extend_axis_env_nd, jaxpr_as_fun,
126+
ClosedJaxpr, Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
127+
)
124128
else:
125-
from jax.extend.core import ClosedJaxpr, Primitive, jaxpr_as_fun
126-
from jax.extend.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
127-
from jax.core import trace_ctx
128-
129-
130-
@contextmanager
131-
def extend_axis_env_nd(name_size_pairs: Iterable[tuple[Hashable, int]]):
132-
"""
133-
Context manager to temporarily extend the JAX axis environment.
134-
135-
Extends the current JAX axis environment with new named axes for
136-
vectorized computations, then restores the previous environment.
137-
138-
Args:
139-
name_size_pairs: Iterable of (name, size) tuples specifying
140-
the named axes to add to the environment.
141-
142-
Yields:
143-
None: Context with extended axis environment.
144-
145-
Examples:
146-
>>> with extend_axis_env_nd([('batch', 32), ('seq', 128)]):
147-
... # Code using vectorized operations with named axes
148-
... pass
149-
"""
150-
prev = trace_ctx.axis_env
151-
try:
152-
trace_ctx.set_axis_env(prev.extend_pure(name_size_pairs))
153-
yield
154-
finally:
155-
trace_ctx.set_axis_env(prev)
129+
from jax.extend.core import (
130+
ClosedJaxpr, jaxpr_as_fun,
131+
Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
132+
)
156133

157134
if jax.__version_info__ < (0, 6, 0):
158135
from jax.util import safe_map, safe_zip, unzip2, wraps
@@ -257,30 +234,6 @@ def unzip2(xys: Iterable[tuple[T1, T2]]) -> tuple[tuple[T1, ...], tuple[T2, ...]
257234

258235

259236
def fun_name(fun: Callable):
260-
"""
261-
Extract the name of a function, handling special cases.
262-
263-
Attempts to get the name of a function, with special handling for
264-
partial functions and fallback for unnamed functions.
265-
266-
Args:
267-
fun: The function to get the name from.
268-
269-
Returns:
270-
str: The function name, or "<unnamed function>" if no name available.
271-
272-
Examples:
273-
>>> def my_function():
274-
... pass
275-
>>> fun_name(my_function)
276-
'my_function'
277-
278-
>>> from functools import partial
279-
>>> add = lambda x, y: x + y
280-
>>> add_one = partial(add, 1)
281-
>>> fun_name(add_one)
282-
'<lambda>'
283-
"""
284237
name = getattr(fun, "__name__", None)
285238
if name is not None:
286239
return name

brainstate/_compatible_import_test.py

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -71,34 +71,6 @@ def test_core_imports_availability(self):
7171
f"{type_name} should be available")
7272
self.assertIsNotNone(getattr(compat, type_name))
7373

74-
def test_function_imports_availability(self):
75-
"""Test function imports are available."""
76-
functions = [
77-
'jaxpr_as_fun', 'get_aval', 'to_concrete_aval',
78-
'extend_axis_env_nd'
79-
]
80-
81-
for func_name in functions:
82-
self.assertTrue(hasattr(compat, func_name),
83-
f"{func_name} should be available")
84-
self.assertTrue(callable(getattr(compat, func_name)),
85-
f"{func_name} should be callable")
86-
87-
def test_extend_axis_env_nd_functionality(self):
88-
"""Test extend_axis_env_nd context manager."""
89-
# Test basic functionality
90-
with compat.extend_axis_env_nd([('test_axis', 10)]):
91-
# Context should execute without error
92-
pass
93-
94-
# Test with multiple axes
95-
with compat.extend_axis_env_nd([('batch', 32), ('seq', 128)]):
96-
pass
97-
98-
# Test with empty axes
99-
with compat.extend_axis_env_nd([]):
100-
pass
101-
10274
def test_get_aval_functionality(self):
10375
"""Test get_aval function works correctly."""
10476
# Test with JAX array
@@ -650,21 +622,6 @@ def test_function(x):
650622
class TestModuleStructure(unittest.TestCase):
651623
"""Test module structure and __all__ exports."""
652624

653-
def test_all_exports(self):
654-
"""Test that __all__ contains expected exports."""
655-
expected_exports = [
656-
'ClosedJaxpr', 'Primitive', 'extend_axis_env_nd', 'jaxpr_as_fun',
657-
'get_aval', 'Tracer', 'to_concrete_aval', 'safe_map', 'safe_zip',
658-
'unzip2', 'wraps', 'Device', 'wrap_init', 'Var', 'JaxprEqn',
659-
'Jaxpr', 'Literal'
660-
]
661-
662-
for export in expected_exports:
663-
self.assertIn(export, compat.__all__,
664-
f"{export} should be in __all__")
665-
self.assertTrue(hasattr(compat, export),
666-
f"{export} should be available in module")
667-
668625
def test_no_unexpected_exports(self):
669626
"""Test that no private functions are exported."""
670627
for name in compat.__all__:

brainstate/transform/_error_if.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929

3030
def _err_jit_true_branch(err_fun, args, kwargs):
31-
jax.debug.callback(err_fun, *args, **kwargs)
31+
jax.debug.callback(err_fun, *args, **kwargs, ordered=True)
3232

3333

3434
def _err_jit_false_branch(args, kwargs):

brainstate/transform/_jit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import functools
1717
from collections.abc import Iterable, Sequence
18-
from typing import (Any, Callable, Union)
18+
from typing import Callable, Union
1919

2020
import jax
2121
from jax._src import sharding_impls

brainstate/transform/_loop_collect_return.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import jax
2121
import jax.numpy as jnp
2222

23+
from brainstate._compatible_import import get_aval, mapped_aval
2324
from brainstate._utils import set_module_as
2425
from ._make_jaxpr import StatefulFunction
2526
from ._progress_bar import ProgressBar
@@ -242,8 +243,8 @@ def scan(
242243

243244
# evaluate jaxpr, get all states #
244245
# ------------------------------ #
245-
xs_avals = [jax.core.get_aval(x) for x in xs_flat]
246-
x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
246+
xs_avals = [get_aval(x) for x in xs_flat]
247+
x_avals = [mapped_aval(length, 0, aval) for aval in xs_avals]
247248
args = [init, xs_tree.unflatten(x_avals)]
248249
stateful_fun = StatefulFunction(f, name='scan').make_jaxpr(*args)
249250
state_trace = stateful_fun.get_state_trace(*args)
@@ -381,8 +382,8 @@ def checkpointed_scan(
381382
pbar_runner = None
382383

383384
# evaluate jaxpr
384-
xs_avals = [jax.core.get_aval(x) for x in xs_flat]
385-
x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
385+
xs_avals = [get_aval(x) for x in xs_flat]
386+
x_avals = [mapped_aval(length, 0, aval) for aval in xs_avals]
386387
args = (init, xs_tree.unflatten(x_avals))
387388
stateful_fun = StatefulFunction(f, name='checkpoint_scan').make_jaxpr(*args)
388389
state_trace = stateful_fun.get_state_trace(*args)

0 commit comments

Comments
 (0)