Skip to content
9 changes: 9 additions & 0 deletions brainstate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@
__version__ = "0.2.10"
__versio_info__ = tuple(map(int, __version__.split('.')))

# Register brainstate with JAX's traceback filter so that brainstate internal
# frames are hidden in user-facing error tracebacks. This is the same pattern
# used by Flax, Equinox, and other JAX ecosystem libraries. Users can still
# see the full traceback by setting JAX_TRACEBACK_FILTERING=off.
import os as _os
from jax._src import traceback_util as _traceback_util
_traceback_util.register_exclusion(_os.path.dirname(_os.path.abspath(__file__)))
del _os, _traceback_util

from . import environ
from . import graph
from . import mixin
Expand Down
Loading
Loading