|
1 | 1 | # Release Notes |
2 | 2 |
|
3 | 3 |
|
| 4 | +## Version 0.3.0 |
| 5 | + |
| 6 | +This release delivers on-device NaN debugging, a unified compilation cache, simplified JAX compatibility, and major internal cleanup — with a net reduction of ~1,800 lines of code. It raises the minimum requirements to Python 3.11 and JAX 0.6.0. |
| 7 | + |
| 8 | +### Breaking Changes |
| 9 | + |
| 10 | +- **Python >= 3.11 required**: Dropped support for Python 3.10. The `requires-python` field and classifiers now start at 3.11. |
| 11 | +- **JAX >= 0.6.0 required**: All dependency groups (`cpu`, `cuda12`, `cuda13`, `tpu`, `testing`) now mandate `jax>=0.6.0`. |
| 12 | +- **Unified compilation cache in `StatefulFunction`**: The four separate internal caches (`_cached_jaxpr`, `_cached_out_shapes`, `_cached_jaxpr_out_tree`, `_cached_state_trace`) have been consolidated into a single `_compilation_cache` storing `_CachedCompilation` objects. `get_cache_stats()` now returns `{'compilation_cache': {...}}` instead of four individual entries. |
| 13 | +- **Immutable `CacheKey` replaces `hashabledict`**: `get_arg_cache_key()` now returns a `CacheKey` (NamedTuple) instead of the mutable `hashabledict`. Code that directly inspected or constructed cache keys must be updated. |
| 14 | +- **Removed internal `_make_jaxpr` function**: The custom tracing implementation has been deleted in favor of using `jax.make_jaxpr()` directly (available in JAX >= 0.6.0). |
| 15 | +- **Removed `debug_depth` and `debug_context` from `GradientTransform`**: The `depth` and `context` parameters for NaN debugging no longer exist following the debug module rewrite. |
| 16 | +- **Removed `breakpoint_if` function**: The conditional breakpoint helper has been removed from `brainstate.transform._debug`. |
| 17 | +- **Removed `extend_axis_env_nd` from compatible imports**: This compatibility shim is no longer exported. |
| 18 | + |
| 19 | +### New Features |
| 20 | + |
| 21 | +#### On-Device NaN/Inf Detection |
| 22 | + |
| 23 | +- Complete rewrite of the NaN debugging system (`brainstate.transform._debug`). NaN checking now runs **on-device** via JAX primitives rather than pulling data to the host, providing significantly better performance. |
| 24 | +- Uses `jax.debug.callback` with thread-local storage to collect and report NaN findings. |
| 25 | +- Error tracebacks now point to the **user's source code** via `source_info_util.user_context`, producing IDE-clickable source locations extracted from jaxpr equations. |
| 26 | +- Recursive instrumentation of nested primitives (`jit`, `cond`, `while`, `scan`) for comprehensive NaN detection throughout the computation graph. |
| 27 | +- More compact and informative error messages via `_format_nan_message()`. |
| 28 | + |
| 29 | +#### JAX Traceback Filtering |
| 30 | + |
| 31 | +- Registered brainstate with JAX's `traceback_util.register_exclusion()` so internal frames are hidden in user-facing error tracebacks. Follows the same pattern as Flax, Equinox, and other JAX ecosystem libraries. |
| 32 | +- Users can still see full tracebacks via `JAX_TRACEBACK_FILTERING=off`. |
| 33 | + |
| 34 | +#### State Validation at Call Time |
| 35 | + |
| 36 | +- New `_validate_state_shapes()` method checks that current state shapes and dtypes match those recorded at compile time. |
| 37 | +- `StatefulFunction.__call__()` automatically validates before execution, catching state shape mismatches early with clear error messages. |
| 38 | +- Added `static_argnums` bounds validation — `make_jaxpr()` now raises `ValueError` if indices exceed the number of positional arguments. |
| 39 | + |
| 40 | +#### New Compatible Import |
| 41 | + |
| 42 | +- Added `mapped_aval` import with version-based routing: `jax.core.mapped_aval` for JAX < 0.8.2, `jax.extend.core.mapped_aval` for >= 0.8.2. |
| 43 | + |
| 44 | +### Improvements |
| 45 | + |
| 46 | +- **Atomic cache writes**: Compilation results are only stored on success, eliminating partial cache entries on error. Uses a double-checked locking pattern for thread safety during compilation. |
| 47 | +- **Better cache key hashing**: Dynamic args/kwargs are now flattened via `jax.tree.flatten()` before hashing, fixing non-deterministic hashing issues with custom pytree nodes (e.g., `Quantity`). |
| 48 | +- **Modern Python type annotations**: Migrated from `typing.Tuple`, `typing.List`, `typing.Dict`, `typing.Optional`, `typing.Union` to built-in `tuple`, `list`, `dict`, `X | None`, `X | Y` syntax across the codebase. |
| 49 | +- **IR visualization compatibility**: Replaced direct `jax.core.X` references with compatible imports (`Var`, `ClosedJaxpr`, `Jaxpr`, `JaxprEqn`, `Literal`, `DropVar`) in the IR visualizer. |
| 50 | +- **Deterministic error reporting**: `jax.debug.callback` in `_error_if.py` now uses `ordered=True` for deterministic error callback ordering. |
| 51 | +- **Graph operations cleanup**: Major refactoring of `_operation.py`, `_node.py`, `_convert.py`, and `_context.py` with streamlined docstrings, better thread-safety documentation, and cleaner context managers. |
| 52 | + |
| 53 | +### Bug Fixes |
| 54 | + |
| 55 | +- **Fixed `Delay.__init__` initialization order**: `update_every` is now initialized before `register_entry` is called, preventing attribute errors during entry registration (#135). |
| 56 | +- **Fixed `graph_to_tree` private attribute access**: Replaced internal `_mapping` access with public API usage in `_convert.py`. |
| 57 | + |
| 58 | +### Internal Changes |
| 59 | + |
| 60 | +- Massive docstring reduction across the graph module (~1,000+ lines removed), replacing verbose multi-paragraph docstrings with concise descriptions. |
| 61 | +- Cleaned up TypeVar usage: removed unused `C` and `Names` aliases, renamed `Node` TypeVar to `N`, removed `Hashable` bound from type variables. |
| 62 | +- Removed unused tests (`test_all_exports`, `test_function_imports_availability`) from compatible import tests. |
| 63 | +- Rewrote debug and make_jaxpr test suites to match the new APIs. |
| 64 | +- IR optimization imports are now lazy-loaded inside `make_jaxpr()` only when `ir_optimizations` is configured. |
| 65 | + |
| 66 | +### CI/CD |
| 67 | + |
| 68 | +- Bumped `actions/upload-artifact` from v6 to v7. |
| 69 | +- Bumped `actions/download-artifact` from v7 to v8. |
| 70 | + |
| 71 | + |
4 | 72 | ## Version 0.2.10 |
5 | 73 |
|
6 | 74 | This release introduces a comprehensive NaN debugging system for gradient computations, refactors the module mapping API for improved clarity, and adds graph context utilities for advanced state management. |
|
0 commit comments