Skip to content

Commit 3a4a731

Browse files
authored
Add on-device NaN debugging and unify StatefulFunction cache (#139)
* refactor(debug_nan_gradient): remove unused exp_exprel function and streamline main examples * refactor(debug_nan): remove unused parameters and streamline NaN detection tests * refactor(make_jaxpr): replace mutable hashabledict with immutable CacheKey and unify compilation cache * refactor(traceback): register brainstate with JAX's traceback filter to exclude internal frames * refactor(debug): enhance traceback filtering and improve dynamic argument hashing * refactor(debug_nan_gradient): improve NaN detection output and streamline gradient examples * refactor(debug_test): update NaN detection test to reference the primitive that introduced NaN
1 parent dbbcf4b commit 3a4a731

File tree

8 files changed

+1293
-1814
lines changed

8 files changed

+1293
-1814
lines changed

brainstate/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@
2020
__version__ = "0.2.10"
2121
__versio_info__ = tuple(map(int, __version__.split('.')))
2222

23+
# Register brainstate with JAX's traceback filter so that brainstate internal
24+
# frames are hidden in user-facing error tracebacks. This is the same pattern
25+
# used by Flax, Equinox, and other JAX ecosystem libraries. Users can still
26+
# see the full traceback by setting JAX_TRACEBACK_FILTERING=off.
27+
import os as _os
28+
from jax._src import traceback_util as _traceback_util
29+
_traceback_util.register_exclusion(_os.path.dirname(_os.path.abspath(__file__)))
30+
del _os, _traceback_util
31+
2332
from . import environ
2433
from . import graph
2534
from . import mixin

0 commit comments

Comments
 (0)