Add on-device NaN debugging and unify StatefulFunction cache #139

Merged
chaoming0625 merged 7 commits into main from update
Mar 9, 2026

Conversation

@chaoming0625
Member

@chaoming0625 chaoming0625 commented Mar 9, 2026

Summary by Sourcery

Introduce a new on-device NaN/Inf debugging pipeline integrated with stateful JAX execution and simplify caching, while updating tests and examples accordingly.

New Features:

  • Add a JIT-compatible DebugNan utility and debug_nan/debug_nan_if helpers that detect NaN/Inf on-device and report precise, IDE-clickable source locations.
  • Provide conditional breakpoint_if helper that gates jax.debug.breakpoint on a predicate.

Enhancements:

  • Refactor NaN debugging implementation to avoid CPU fallback, leverage thread-local storage and source_info_util for clean error reporting, and support nested primitives including jit, cond, while, and scan.
  • Redesign StatefulFunction compilation caching into a single unified compilation cache keyed by an immutable CacheKey, with improved cache statistics and error messages.
  • Add state shape/dtype validation at call time to catch mismatches since compilation, and support optional IR optimizations in StatefulFunction.make_jaxpr.
  • Strengthen _make_jaxpr hashing utilities to use a safer CacheKey/_make_hashable implementation and improve handling of unhashable objects.
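
The immutable cache key and the stricter hashing helper described above can be sketched in plain Python. This is an illustrative approximation, not the PR's code: the field names follow the CacheKey class diagram in the reviewer's guide, while `make_hashable` here is a hypothetical stand-in that only handles built-in containers (the real `_make_hashable` is described as working with pytrees).

```python
from typing import Any, NamedTuple, Tuple


class CacheKey(NamedTuple):
    """Immutable key; field names taken from the class diagram."""
    static_args: Tuple
    dyn_args: Tuple
    static_kwargs: Tuple
    dyn_kwargs: Tuple


def make_hashable(obj: Any) -> Any:
    """Hypothetical stand-in for _make_hashable (built-in containers only)."""
    try:
        hash(obj)          # fast path: already hashable, return unchanged
        return obj
    except TypeError:
        pass
    if isinstance(obj, (list, tuple)):
        return tuple(make_hashable(o) for o in obj)
    if isinstance(obj, dict):
        # Sort by repr of the key so insertion order does not change the result.
        items = sorted(obj.items(), key=lambda kv: repr(kv[0]))
        return tuple((k, make_hashable(v)) for k, v in items)
    if isinstance(obj, set):
        return frozenset(obj)
    raise TypeError(f"Cannot build a cache key from {type(obj).__name__!r}")


key = CacheKey(
    static_args=make_hashable([1, [2, 3]]),
    dyn_args=(),
    static_kwargs=make_hashable({"b": [1], "a": 2}),
    dyn_kwargs=(),
)
```

Because a NamedTuple rejects attribute assignment, a key built this way cannot be mutated after it has been stored in a cache, and an object that is neither hashable nor a supported container fails loudly with TypeError instead of being silently stringified.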

Documentation:

  • Clarify StatefulFunction and make_jaxpr docstrings around cache behavior, return_only_write semantics, IR optimizations, and state validation.
  • Update debug_nan-related docstrings to describe the new on-device behavior and error reporting semantics.

Tests:

  • Replace and expand _debug tests to cover the new NaN detection pipeline, source extraction, nested control-flow primitives, and DebugNan class behavior.
  • Update _make_jaxpr tests for the unified compilation cache, CacheKey immutability, error reporting, state shape validation, and IR optimization paths.

Chores:

  • Simplify the gradient transform debug_nan integration by removing unused depth/context parameters.
  • Fix imports in mapping utilities to use BoundedCache from its new location and clean up an unused brainstate __init__ export.

@sourcery-ai
Contributor

sourcery-ai bot commented Mar 9, 2026

Reviewer's Guide

Refactors NaN/Inf debugging to run entirely on-device with a new jaxpr interpreter and callback strategy, and overhauls StatefulFunction’s compilation caching into a single immutable-keyed cache with state-shape validation and optional IR optimizations, updating tests, gradient transform, mapping utilities, and examples accordingly.

Sequence diagram for on-device NaN detection with DebugNan.check

sequenceDiagram
    actor User
    participant DebugNan
    participant StatefulFunction
    participant _interpret_jaxpr_with_nan_check
    participant PrimitiveEqn
    participant JaxCallback as jax_debug_callback
    participant NanStore as _nan_store
    participant SourceInfoUtil as source_info_util

    User->>DebugNan: debug_nan(fn, *args, phase)
    DebugNan->>StatefulFunction: __init__(fn, *args)
    StatefulFunction-->>DebugNan: ClosedJaxpr(jaxpr, consts), states

    User->>DebugNan: check()
    DebugNan->>DebugNan: _build_flat_args()
    DebugNan->>NanStore: store_snapshot = len(_nan_store_get())
    DebugNan->>_interpret_jaxpr_with_nan_check: (jaxpr, consts, *flat_args, phase)

    loop for each eqn in jaxpr.eqns
        _interpret_jaxpr_with_nan_check->>PrimitiveEqn: execute eqn (may recurse for jit/cond/while/scan)
        PrimitiveEqn-->>_interpret_jaxpr_with_nan_check: invals, outvals
        _interpret_jaxpr_with_nan_check->>_interpret_jaxpr_with_nan_check: input_has_nan = _has_nan_flag(invals)
        _interpret_jaxpr_with_nan_check->>_interpret_jaxpr_with_nan_check: output_has_nan = _has_nan_flag(outvals)
        _interpret_jaxpr_with_nan_check->>_interpret_jaxpr_with_nan_check: nan_introduced = output_has_nan & ~input_has_nan

        alt nan_introduced is True
            _interpret_jaxpr_with_nan_check->>_interpret_jaxpr_with_nan_check: cb = _make_nan_callback(eqn_meta, phase)
            _interpret_jaxpr_with_nan_check->>JaxCallback: jax.debug.callback(cb, *float_invals)
            JaxCallback->>NanStore: _nan_store_get().append((msg, raw_source_info))
        end
    end

    _interpret_jaxpr_with_nan_check-->>DebugNan: outvars

    DebugNan->>DebugNan: _raise_if_nan_detected(store_snapshot)
    DebugNan->>NanStore: records = _nan_store_get()
    alt new NaN records
        NanStore-->>DebugNan: msg, raw_src
        DebugNan->>SourceInfoUtil: user_context(tb, name_stack)
        SourceInfoUtil-->>User: raise RuntimeError(msg)
    else no new records
        DebugNan-->>User: return None (no NaN)
    end
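
The per-equation rule in the loop above (report an equation only when its outputs contain NaN/Inf and its inputs did not) can be illustrated on the host in plain Python. This is a minimal sketch under stated assumptions, not the on-device implementation: `has_nonfinite`, `ieee_log`, and `run_with_nan_check` are hypothetical names, and the real code computes the flags as JAX arrays and reports through jax.debug.callback.

```python
import math
from typing import Callable, List, Sequence, Tuple


def has_nonfinite(vals: Sequence[float]) -> bool:
    """True if any value is NaN or +/-Inf."""
    return any(math.isnan(v) or math.isinf(v) for v in vals)


def ieee_log(x: float) -> float:
    """log that returns NaN/-Inf instead of raising, like array libraries do."""
    if math.isnan(x) or x < 0:
        return math.nan
    if x == 0:
        return -math.inf
    return math.log(x)


def run_with_nan_check(
    ops: List[Tuple[str, Callable[[float], float]]], x: float
) -> Tuple[float, List[str]]:
    """Run ops in order and record only the op that introduces NaN/Inf."""
    reports: List[str] = []
    for idx, (name, fn) in enumerate(ops):
        input_has_nan = has_nonfinite([x])
        y = fn(x)
        output_has_nan = has_nonfinite([y])
        # Same rule as the diagram: output_has_nan & ~input_has_nan
        if output_has_nan and not input_has_nan:
            reports.append(f"eqn {idx} ({name})")
        x = y
    return x, reports


ops = [("sub", lambda v: v - 2.0), ("log", ieee_log), ("mul", lambda v: v * 3.0)]
result, reports = run_with_nan_check(ops, 1.0)
```

With the input 1.0, `sub` yields -1.0, `log` turns it into NaN, and `mul` merely propagates it, so only the `log` step is reported.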

Class diagram for DebugNan and NaN debugging helpers

classDiagram
    class DebugNan {
        - fn : Callable
        - phase : str
        - _stateful_fn : StatefulFunction
        - _jaxpr
        - _consts
        - _flat_user_args : list
        - _states : list
        + DebugNan(fn, *args, phase)
        + check() None
        + check_if(has_nan) None
        - _build_flat_args() list
    }

    class _nan_store {
        <<thread_local>>
        + records : list
    }

    class CacheKey {
        <<NamedTuple>>
        + static_args : tuple
        + dyn_args : tuple
        + static_kwargs : tuple
        + dyn_kwargs : tuple
    }

    class _CachedCompilation {
        <<internal>>
        - jaxpr
        - out_shapes
        - out_treedef
        - state_trace : StateTraceStack
        - state_avals : tuple
        + _CachedCompilation(jaxpr, out_shapes, out_treedef, state_trace, state_avals)
    }

    class StatefulFunction {
        - fun : Callable
        - static_argnums : tuple
        - axis_env
        - name : str
        - return_only_write : bool
        - ir_optimizations : tuple
        - _compilation_cache : BoundedCache
        - _cache_lock : RLock
        + StatefulFunction(fun, static_argnums, axis_env, name, return_only_write, ir_optimizations)
        + get_arg_cache_key(*args, **kwargs) CacheKey
        + get_jaxpr_by_cache(cache_key) ClosedJaxpr
        + get_out_shapes_by_cache(cache_key) PyTree
        + get_out_treedef_by_cache(cache_key) PyTree
        + get_state_trace_by_cache(cache_key) StateTraceStack
        + get_jaxpr(*args, **kwargs) ClosedJaxpr
        + get_out_shapes(*args, **kwargs) PyTree
        + get_out_treedef(*args, **kwargs) PyTree
        + get_state_trace(*args, **kwargs) StateTraceStack
        + make_jaxpr(*args, **kwargs) StatefulFunction
        + jaxpr_call(state_vals, *args, **kwargs) Any
        + jaxpr_call_auto(*args, **kwargs) Any
        + debug_call(state_vals, *args, **kwargs) Any
        + validate_states(cache_key) bool
        + validate_all_states() Dict
        + clear_cache() None
        + get_cache_stats() Dict
        + __call__(*args, **kwargs) Any
        - _make_new_arg() Callable
        - _wrapped_fun_to_eval(_result_holder, static_kwargs, *args, **dyn_kwargs)
        - _get_compilation(cache_key) _CachedCompilation
        - _validate_state_shapes(cache_key) None
    }

    class BoundedCache {
        + maxsize : int
        + get(key, raise_on_miss, error_context)
        + set(key, value)
        + clear() None
        + keys() Iterable
        + get_stats() Dict
    }

    class StateTraceStack {
        + name : str
        + states : list
        + original_state_values : list
        + set_new_arg(fn) None
        + get_write_state_values(strict) list
        + get_state_values() list
    }

    class brainstate_transform_debug_module {
        <<module-level functions>>
        + debug_nan(fn, *args, phase) None
        + debug_nan_if(has_nan, fn, *args, phase) None
        + breakpoint_if(pred, **breakpoint_kwargs)
        + _interpret_jaxpr_with_nan_check(jaxpr, consts, *flat_args, phase, raise_in_callback) list
        + _execute_eqn(eqn, invals, phase, raise_in_callback) list
        + _execute_jit_eqn(eqn, invals, phase, raise_in_callback) list
        + _execute_cond_eqn(eqn, invals, phase, raise_in_callback) list
        + _execute_while_eqn(eqn, invals, phase, raise_in_callback) list
        + _execute_scan_eqn(eqn, invals, phase, raise_in_callback) list
        + _make_nan_callback(eqn_idx, total_eqns, prim_name, eqn_str, source_loc, raw_source_info, phase, raise_in_callback)
        + _has_nan_flag(vals) jax.Array
        + _is_float_array(x) bool
        + _raise_if_nan_detected(store_snapshot) None
        + _extract_user_source(source_info) str
    }

    DebugNan --> StatefulFunction : uses for jaxpr compilation
    DebugNan --> _nan_store : writes NaN records via callbacks
    DebugNan --> brainstate_transform_debug_module : calls _interpret_jaxpr_with_nan_check

    StatefulFunction --> CacheKey : returns from get_arg_cache_key
    StatefulFunction --> _CachedCompilation : stores in _compilation_cache
    StatefulFunction --> BoundedCache : uses for compilation cache
    StatefulFunction --> StateTraceStack : tracks State accesses

    _CachedCompilation --> StateTraceStack : holds

    brainstate_transform_debug_module --> _nan_store : appends records
    brainstate_transform_debug_module --> StatefulFunction : via DebugNan
    brainstate_transform_debug_module --> CacheKey : indirectly via StatefulFunction

File-Level Changes

Change Details Files
Replace previous host-side NaN/Inf analysis with an on-device, jaxpr-interpreting NaN debugger that works under JIT and high-level control-flow primitives.
  • Introduce a thread-local store for NaN reports and helpers to extract user-facing source locations from JAX source_info with traceback filtering.
  • Add helpers for naming and formatting jaxpr vars/equations and for computing an on-device scalar NaN/Inf flag across float arrays.
  • Implement _interpret_jaxpr_with_nan_check and specialized executors for jit, cond, while, and scan primitives that execute via real JAX primitives while inserting NaN checks and debug callbacks.
  • Redesign DebugNan to use StatefulFunction’s jaxpr, track current State values, and expose check/check_if methods that raise clean RuntimeErrors, plus new debug_nan/debug_nan_if convenience wrappers.
  • Move breakpoint_if to the bottom, still implemented via cond/unvmap around jax.debug.breakpoint.
brainstate/transform/_debug.py
Rewrite NaN debugging tests to target the new on-device interpreter, helper functions, and behavior across nested control-flow and stateful models.
  • Replace tests for removed internal helpers with tests for _has_nan_flag, _extract_user_source, and _interpret_jaxpr_with_nan_check.
  • Add coverage for debug_nan/debug_nan_if basic behavior, interaction with jit, cond, while, scan, NaN propagation rules, and DebugNan’s class API.
  • Ensure source extraction omits internal brainstate tracing frames and produces usable locations.
  • Adapt tests to new error messaging and removal of depth/context knobs.
brainstate/transform/_debug_test.py
Unify StatefulFunction compilation artifacts behind an immutable CacheKey and a single bounded cache entry per key, adding IR optimizations and state-shape validation.
  • Introduce CacheKey NamedTuple and _CachedCompilation container, replacing the old hashabledict-based multi-cache design.
  • Refactor StatefulFunction to use a single BoundedCache mapping CacheKey to _CachedCompilation, with helper accessors for jaxpr, output shapes/tree, and state trace.
  • Replace legacy JAX v0.4 tracing helpers with _make_new_arg based on jax.core.trace_ctx, and use a result_holder to capture StateTraceStack without caching it pre-compilation.
  • Extend make_jaxpr to validate static_argnums bounds, optionally optimize jaxprs via brainstate.transform._ir_optim, compute abstract state avals, and store them in the compilation cache.
  • Add state-shape/dtype validation via _validate_state_shapes in __call__, leaving jaxpr_call_auto unvalidated for internal transforms.
  • Tighten _make_hashable to fast-path already-hashable objects and raise TypeError when an object is neither hashable nor a valid pytree.
  • Update get_cache_stats, clear_cache, validate_states, and validate_all_states to work with the unified compilation cache.
brainstate/transform/_make_jaxpr.py
Update make_jaxpr / StatefulFunction tests for the unified cache, new CacheKey type, stricter hashability rules, state validation, and IR optimization options.
  • Adjust cache-related assertions to use the single compilation_cache stats and to construct fake keys via CacheKey.
  • Add tests verifying CacheKey immutability, state shape/dtype mismatch detection on call, and that unchanged states pass validation.
  • Add tests for ir_optimizations argument handling (list, string, None) and for static_argnums out-of-bounds raising ValueError.
  • Add tests ensuring _make_hashable raises TypeError for non-pytree unhashable objects and that cache remains empty after failed compilation.
brainstate/transform/_make_jaxpr_test.py
Simplify GradientTransform’s NaN debugging configuration to rely on the new DebugNan behavior without depth/context controls.
  • Remove debug_depth and debug_context parameters and attributes from GradientTransform.
  • Adjust grad_fn’s debug_nan integration to only pass phase, relying on the new debug_nan signature.
brainstate/transform/_grad_transform.py
Align mapping utilities and examples with the new infrastructure and imports.
  • Change StatefulMapping imports to use BoundedCache from brainstate.util._cache instead of from _make_jaxpr, and keep get_arg_cache_key from _make_jaxpr.
  • Update the NaN gradient example main block to run the standard examples and remove the exp_exprel helper that used the old decorator-style debug_nan behavior.
brainstate/transform/_mapping2.py
examples/011_debug_nan_gradient.py
brainstate/__init__.py


@chaoming0625
Member Author

@sourcery-ai title

@sourcery-ai sourcery-ai bot changed the title from "Update" to "Add on-device NaN debugging and unify StatefulFunction cache" on Mar 9, 2026
Contributor

@sourcery-ai sourcery-ai bot left a comment

Hey - I've left some high level feedback:

  • The thread-local _nan_store list is never cleared and only grows over time; consider truncating it or resetting it after _raise_if_nan_detected (or exposing an explicit reset) to avoid unbounded memory growth in long-running processes.
  • In _extract_user_source, filtering traceback lines with the hard-coded '/site-packages/' substring may fail on Windows paths or nonstandard install layouts; it would be more robust to detect third-party frames via site/sysconfig paths or pathlib-based checks instead of a literal string.
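
One way to act on the second point is to compare frame paths against the interpreter's own install directories instead of a literal substring. The sketch below is an assumption about how that could look, not code from the PR: `third_party_roots` and `is_third_party` are hypothetical helpers built only on the standard `sysconfig`, `site`, and `pathlib` modules.

```python
import site
import sysconfig
from pathlib import Path
from typing import Iterable, Optional, Set


def third_party_roots() -> Set[Path]:
    """Collect the directories where installed packages live."""
    roots: Set[Path] = set()
    paths = sysconfig.get_paths()
    for name in ("purelib", "platlib"):
        if paths.get(name):
            roots.add(Path(paths[name]).resolve())
    # Some virtualenv-provided site modules lack these functions.
    for getter in ("getsitepackages", "getusersitepackages"):
        fn = getattr(site, getter, None)
        if fn is None:
            continue
        found = fn()
        for p in ([found] if isinstance(found, str) else found):
            roots.add(Path(p).resolve())
    return roots


def is_third_party(filename: str, roots: Optional[Iterable[Path]] = None) -> bool:
    """True if filename lies under one of the install roots."""
    path = Path(filename).resolve()
    candidates = third_party_roots() if roots is None else set(roots)
    # Path comparison is separator-agnostic, so it also works on Windows.
    return any(root == path or root in path.parents for root in candidates)
```

A traceback filter could then keep a frame when `is_third_party(frame.filename)` is false, which avoids depending on forward slashes or on the directory being named site-packages.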


Copilot AI left a comment

Pull request overview

This PR refactors two key subsystems in brainstate: the NaN debugging pipeline and the StatefulFunction compilation cache. The NaN debugging is rewritten to perform on-device detection using an instrumented jaxpr interpreter (eliminating CPU fallback), while the StatefulFunction cache is unified from four separate caches into a single CacheKey → _CachedCompilation cache with added state shape validation.

Changes:

  • Replaces the CPU-based NaN debugging with an on-device instrumented jaxpr interpreter that detects NaN/Inf introduction per-primitive, reports IDE-clickable source locations, and integrates with JAX's traceback filtering.
  • Unifies StatefulFunction's four separate bounded caches (jaxpr_cache, out_shapes_cache, jaxpr_out_tree_cache, state_trace_cache) into a single _compilation_cache keyed by an immutable CacheKey NamedTuple, with double-check locking and state shape/dtype validation.
  • Simplifies the gradient transform's debug_nan integration by removing unused depth/context parameters and updating imports.
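
The double-check locking mentioned for the unified cache follows a common pattern, sketched here in plain Python. Everything in this block is illustrative: `UnifiedCache` and `get_or_compile` are hypothetical names, a plain dict stands in for BoundedCache, and the cached entry is a placeholder dict rather than a real _CachedCompilation.

```python
import threading
import time
from typing import Any, Callable, Dict, Hashable

_MISSING = object()


class UnifiedCache:
    """One entry per key, compiled at most once even under contention."""

    def __init__(self) -> None:
        self._entries: Dict[Hashable, Any] = {}
        self._lock = threading.RLock()
        self.compilations = 0

    def get_or_compile(self, key: Hashable, compile_fn: Callable[[], Any]) -> Any:
        entry = self._entries.get(key, _MISSING)   # first check, lock-free
        if entry is not _MISSING:
            return entry
        with self._lock:
            entry = self._entries.get(key, _MISSING)   # second check, under lock
            if entry is _MISSING:
                entry = compile_fn()
                self._entries[key] = entry
                self.compilations += 1
            return entry


def slow_compile() -> Dict[str, str]:
    time.sleep(0.01)   # simulate tracing work so threads overlap
    return {"jaxpr": "<placeholder>", "out_shapes": "<placeholder>"}


cache = UnifiedCache()
results = []


def worker() -> None:
    results.append(cache.get_or_compile(("f", (2, 3)), slow_compile))


threads = [threading.Thread(target=worker) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Storing all artifacts in one entry means a key is either fully present or absent, so the four lookups that previously hit separate caches cannot disagree with each other.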

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| brainstate/transform/_debug.py | Complete rewrite of NaN debugging: on-device detection, thread-local store, source info extraction, instrumented jaxpr interpreter with nested primitive support |
| brainstate/transform/_make_jaxpr.py | Unified compilation cache with CacheKey/_CachedCompilation, state shape validation, IR optimization integration, improved _make_hashable |
| brainstate/transform/_debug_test.py | Replaced and expanded tests for the new NaN detection pipeline |
| brainstate/transform/_make_jaxpr_test.py | Updated tests for unified cache, CacheKey immutability, state shape validation, IR optimizations |
| brainstate/transform/_grad_transform.py | Removed unused debug_depth/debug_context parameters |
| brainstate/transform/_mapping2.py | Fixed BoundedCache import to use brainstate.util._cache |
| brainstate/__init__.py | Registered brainstate with JAX's traceback filter |
| examples/011_debug_nan_gradient.py | Removed exp_exprel example, uncommented all standard examples |


Comment on lines +296 to +307
for eqn, meta in zip(jaxpr.eqns, eqn_meta):
    invals = _get_invals(eqn, env)

    # On-device: does the input already carry NaN?
    input_has_nan = _has_nan_flag(invals)

    # Execute the primitive (with recursive instrumentation for nested ones)
    outvals = _execute_eqn(eqn, invals, phase, raise_in_callback)

    # On-device: did this op introduce new NaN?
    output_has_nan = _has_nan_flag(outvals)
    nan_introduced = output_has_nan & ~input_has_nan

Copilot AI Mar 9, 2026

For nested high-level primitives (jit, cond, while, scan), NaN is detected both inside the recursively-instrumented inner jaxpr and at the outer level in _interpret_jaxpr_with_nan_check. This happens because _execute_eqn dispatches to _execute_jit_eqn etc. which instruments the inner jaxpr (and fires a callback if NaN is introduced inside), but the outer loop also checks nan_introduced = output_has_nan & ~input_has_nan for the containing equation — which will be True because the inner computation introduced NaN from clean inputs.

This leads to double-reporting of the same NaN: once for the inner primitive (e.g., log) and once for the outer primitive (e.g., pjit). Consider skipping the outer NaN check for expandable/nested primitives, since the inner instrumentation already handles them.
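
The double-reporting and the suggested fix can be reproduced with a toy nested interpreter in plain Python. This is a sketch under assumptions, not the PR's interpreter: `Eqn`, `interpret`, and the `skip_nested_outer_check` flag are hypothetical, and a nested equation with a `body` stands in for primitives such as pjit.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class Eqn:
    name: str
    fn: Optional[Callable[[float], float]] = None   # leaf primitive
    body: Optional[List["Eqn"]] = None               # nested primitive (e.g. pjit)


def is_bad(v: float) -> bool:
    return math.isnan(v) or math.isinf(v)


def interpret(eqns: List[Eqn], x: float, reports: List[str],
              skip_nested_outer_check: bool) -> float:
    for eqn in eqns:
        input_bad = is_bad(x)
        if eqn.body is not None:
            # Recurse: the inner equations are instrumented themselves.
            y = interpret(eqn.body, x, reports, skip_nested_outer_check)
            check_here = not skip_nested_outer_check
        else:
            y = eqn.fn(x)
            check_here = True
        if check_here and is_bad(y) and not input_bad:
            reports.append(eqn.name)
        x = y
    return x


program = [Eqn("pjit", body=[
    Eqn("sub", fn=lambda v: v - 2.0),
    Eqn("log", fn=lambda v: math.log(v) if v > 0 else math.nan),
])]

double: List[str] = []
interpret(program, 1.0, double, skip_nested_outer_check=False)

single: List[str] = []
interpret(program, 1.0, single, skip_nested_outer_check=True)
```

Without the skip, both the inner `log` and the enclosing `pjit` are reported for the same NaN; with it, only the primitive that actually produced the value remains.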

Comment on lines +1154 to 1156
cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=True)
self._validate_state_shapes(cache_key)
return self.jaxpr_call_auto(*args, **kwargs)

Copilot AI Mar 9, 2026

The __call__ method computes cache_key via get_arg_cache_key(*args, **kwargs, compile_if_miss=True) and then immediately calls jaxpr_call_auto(*args, **kwargs), which internally computes the exact same cache key again (line 1132). This means every user-facing call redundantly computes the cache key twice. Consider passing the already-computed cache_key to jaxpr_call_auto (or extracting a shared internal method) to avoid the redundant work.
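
A shared internal method that accepts a precomputed key is one way to implement this suggestion. The toy class below only illustrates the call structure: `StatefulFunctionSketch` and `_call_with_key` are hypothetical, the key computation is a stand-in, and no JAX tracing takes place.

```python
from typing import Any, Callable, Dict, Hashable


class StatefulFunctionSketch:
    """Toy stand-in that counts how often the cache key is computed."""

    def __init__(self, fun: Callable[..., Any]) -> None:
        self.fun = fun
        self.key_computations = 0
        self._cache: Dict[Hashable, Callable[..., Any]] = {}

    def get_arg_cache_key(self, *args: Any, **kwargs: Any) -> Hashable:
        self.key_computations += 1
        # Stand-in key: argument type names plus sorted keyword names.
        return (tuple(type(a).__name__ for a in args), tuple(sorted(kwargs)))

    def _validate_state_shapes(self, cache_key: Hashable) -> None:
        pass   # placeholder for the shape/dtype check

    def _call_with_key(self, cache_key: Hashable, *args: Any, **kwargs: Any) -> Any:
        # Hypothetical shared path used by both public entry points.
        compiled = self._cache.setdefault(cache_key, self.fun)
        return compiled(*args, **kwargs)

    def jaxpr_call_auto(self, *args: Any, **kwargs: Any) -> Any:
        key = self.get_arg_cache_key(*args, **kwargs)
        return self._call_with_key(key, *args, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        key = self.get_arg_cache_key(*args, **kwargs)   # computed once
        self._validate_state_shapes(key)
        return self._call_with_key(key, *args, **kwargs)


f = StatefulFunctionSketch(lambda a, b=1: a + b)
first = f(2, b=3)
calls_after_first = f.key_computations
```

Here `__call__` validates and dispatches with a single key computation, while `jaxpr_call_auto` keeps its existing signature for internal transforms.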

Comment on lines +233 to +254
def _raise_if_nan_detected(store_snapshot: int) -> None:
    """
    After an instrumented run, check the thread-local store for new NaN records.

    Uses ``source_info_util.user_context`` — the same mechanism as
    ``State.raise_error_with_source_info`` — so that JAX's traceback filtering
    shows the *user code* that introduced the NaN rather than library internals.
    """
    records = _nan_store_get()
    new = records[store_snapshot:]
    if not new:
        return
    msg, raw_src = new[0]
    if raw_src is not None:
        tb = getattr(raw_src, 'traceback', None)
        name_stack = (
            source_info_util.current_name_stack()
            + getattr(raw_src, 'name_stack', source_info_util.NameStack())
        )
        with source_info_util.user_context(tb, name_stack=name_stack):
            raise RuntimeError(msg)
    raise RuntimeError(msg)

Copilot AI Mar 9, 2026

The thread-local _nan_store.records list is only ever appended to (line 214) but never cleared or trimmed after records are consumed by _raise_if_nan_detected. In long-running applications or test suites, this list will grow without bound, leaking memory. After raising (or deciding not to raise), the consumed records should be removed, e.g. del records[store_snapshot:] at the end of _raise_if_nan_detected, or after each check() / check_if() call.

@chaoming0625 chaoming0625 merged commit 3a4a731 into main Mar 9, 2026
4 of 7 checks passed
@chaoming0625 chaoming0625 deleted the update branch March 9, 2026 02:25