Skip to content

Commit 2d805a3

Browse files
committed
Bump version to 0.3.0 and update changelog with new features, breaking changes, and improvements
1 parent 3a4a731 commit 2d805a3

File tree

2 files changed

+69
-1
lines changed

2 files changed

+69
-1
lines changed

brainstate/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
A ``State``-based Transformation System for Program Compilation and Augmentation
1818
"""
1919

20-
__version__ = "0.2.10"
20+
__version__ = "0.3.0"
2121
__versio_info__ = tuple(map(int, __version__.split('.')))
2222

2323
# Register brainstate with JAX's traceback filter so that brainstate internal

changelog.md

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,74 @@
11
# Release Notes
22

33

4+
## Version 0.3.0
5+
6+
This release delivers on-device NaN debugging, a unified compilation cache, simplified JAX compatibility, and major internal cleanup — with a net reduction of ~1,800 lines of code. It raises the minimum requirements to Python 3.11 and JAX 0.6.0.
7+
8+
### Breaking Changes
9+
10+
- **Python >= 3.11 required**: Dropped support for Python 3.10. The `requires-python` field and classifiers now start at 3.11.
11+
- **JAX >= 0.6.0 required**: All dependency groups (`cpu`, `cuda12`, `cuda13`, `tpu`, `testing`) now mandate `jax>=0.6.0`.
12+
- **Unified compilation cache in `StatefulFunction`**: The four separate internal caches (`_cached_jaxpr`, `_cached_out_shapes`, `_cached_jaxpr_out_tree`, `_cached_state_trace`) have been consolidated into a single `_compilation_cache` storing `_CachedCompilation` objects. `get_cache_stats()` now returns `{'compilation_cache': {...}}` instead of four individual entries.
13+
- **Immutable `CacheKey` replaces `hashabledict`**: `get_arg_cache_key()` now returns a `CacheKey` (NamedTuple) instead of the mutable `hashabledict`. Code that directly inspected or constructed cache keys must be updated.
14+
- **Removed internal `_make_jaxpr` function**: The custom tracing implementation has been deleted in favor of using `jax.make_jaxpr()` directly (available in JAX >= 0.6.0).
15+
- **Removed `debug_depth` and `debug_context` from `GradientTransform`**: The `depth` and `context` parameters for NaN debugging no longer exist following the debug module rewrite.
16+
- **Removed `breakpoint_if` function**: The conditional breakpoint helper has been removed from `brainstate.transform._debug`.
17+
- **Removed `extend_axis_env_nd` from compatible imports**: This compatibility shim is no longer exported.
18+
19+
### New Features
20+
21+
#### On-Device NaN/Inf Detection
22+
23+
- Complete rewrite of the NaN debugging system (`brainstate.transform._debug`). NaN checking now runs **on-device** via JAX primitives rather than pulling data to the host, providing significantly better performance.
24+
- Uses `jax.debug.callback` with thread-local storage to collect and report NaN findings.
25+
- Error tracebacks now point to the **user's source code** via `source_info_util.user_context`, producing IDE-clickable source locations extracted from jaxpr equations.
26+
- Recursive instrumentation of nested primitives (`jit`, `cond`, `while`, `scan`) for comprehensive NaN detection throughout the computation graph.
27+
- More compact and informative error messages via `_format_nan_message()`.
28+
29+
#### JAX Traceback Filtering
30+
31+
- Registered brainstate with JAX's `traceback_util.register_exclusion()` so internal frames are hidden in user-facing error tracebacks. Follows the same pattern as Flax, Equinox, and other JAX ecosystem libraries.
32+
- Users can still see full tracebacks via `JAX_TRACEBACK_FILTERING=off`.
33+
34+
#### State Validation at Call Time
35+
36+
- New `_validate_state_shapes()` method checks that current state shapes and dtypes match those recorded at compile time.
37+
- `StatefulFunction.__call__()` automatically validates before execution, catching state shape mismatches early with clear error messages.
38+
- Added `static_argnums` bounds validation — `make_jaxpr()` now raises `ValueError` if indices exceed the number of positional arguments.
39+
40+
#### New Compatible Import
41+
42+
- Added `mapped_aval` import with version-based routing: `jax.core.mapped_aval` for JAX < 0.8.2, `jax.extend.core.mapped_aval` for >= 0.8.2.
43+
44+
### Improvements
45+
46+
- **Atomic cache writes**: Compilation results are only stored on success, eliminating partial cache entries on error. Uses a double-checked locking pattern for thread safety during compilation.
47+
- **Better cache key hashing**: Dynamic args/kwargs are now flattened via `jax.tree.flatten()` before hashing, fixing non-deterministic hashing issues with custom pytree nodes (e.g., `Quantity`).
48+
- **Modern Python type annotations**: Migrated from `typing.Tuple`, `typing.List`, `typing.Dict`, `typing.Optional`, `typing.Union` to built-in `tuple`, `list`, `dict`, `X | None`, `X | Y` syntax across the codebase.
49+
- **IR visualization compatibility**: Replaced direct `jax.core.X` references with compatible imports (`Var`, `ClosedJaxpr`, `Jaxpr`, `JaxprEqn`, `Literal`, `DropVar`) in the IR visualizer.
50+
- **Deterministic error reporting**: `jax.debug.callback` in `_error_if.py` now uses `ordered=True` for deterministic error callback ordering.
51+
- **Graph operations cleanup**: Major refactoring of `_operation.py`, `_node.py`, `_convert.py`, and `_context.py` with streamlined docstrings, better thread-safety documentation, and cleaner context managers.
52+
53+
### Bug Fixes
54+
55+
- **Fixed `Delay.__init__` initialization order**: `update_every` is now initialized before `register_entry` is called, preventing attribute errors during entry registration (#135).
56+
- **Fixed `graph_to_tree` private attribute access**: Replaced internal `_mapping` access with public API usage in `_convert.py`.
57+
58+
### Internal Changes
59+
60+
- Massive docstring reduction across the graph module (~1,000+ lines removed), replacing verbose multi-paragraph docstrings with concise descriptions.
61+
- Cleaned up TypeVar usage: removed unused `C` and `Names` aliases, renamed `Node` TypeVar to `N`, removed `Hashable` bound from type variables.
62+
- Removed unused tests (`test_all_exports`, `test_function_imports_availability`) from compatible import tests.
63+
- Rewrote debug and make_jaxpr test suites to match the new APIs.
64+
- IR optimization imports are now lazy-loaded inside `make_jaxpr()` only when `ir_optimizations` is configured.
65+
66+
### CI/CD
67+
68+
- Bumped `actions/upload-artifact` from v6 to v7.
69+
- Bumped `actions/download-artifact` from v7 to v8.
70+
71+
472
## Version 0.2.10
573

674
This release introduces a comprehensive NaN debugging system for gradient computations, refactors the module mapping API for improved clarity, and adds graph context utilities for advanced state management.

0 commit comments

Comments
 (0)