Add on-device NaN debugging and unify StatefulFunction cache#139
chaoming0625 merged 7 commits into main
…treamline main examples
…heKey and unify compilation cache
…to exclude internal frames
Reviewer's Guide

Refactors NaN/Inf debugging to run entirely on-device with a new jaxpr interpreter and callback strategy, and overhauls StatefulFunction's compilation caching into a single immutable-keyed cache with state-shape validation and optional IR optimizations, updating tests, gradient transform, mapping utilities, and examples accordingly.

Sequence diagram for on-device NaN detection with DebugNan.check

sequenceDiagram
actor User
participant DebugNan
participant StatefulFunction
participant _interpret_jaxpr_with_nan_check
participant JaxPrimitive as PrimitiveEqn
participant JaxCallback as jax_debug_callback
participant NanStore as _nan_store
participant SourceInfoUtil as source_info_util
User->>DebugNan: debug_nan(fn, *args, phase)
DebugNan->>StatefulFunction: __init__(fn, *args)
StatefulFunction-->>DebugNan: ClosedJaxpr(jaxpr, consts), states
User->>DebugNan: check()
DebugNan->>DebugNan: _build_flat_args()
DebugNan->>NanStore: store_snapshot = len(_nan_store_get())
DebugNan->>_interpret_jaxpr_with_nan_check: (jaxpr, consts, *flat_args, phase)
loop for each eqn in jaxpr.eqns
_interpret_jaxpr_with_nan_check->>PrimitiveEqn: execute eqn (may recurse for jit/cond/while/scan)
PrimitiveEqn-->>_interpret_jaxpr_with_nan_check: invals, outvals
_interpret_jaxpr_with_nan_check->>_interpret_jaxpr_with_nan_check: input_has_nan = _has_nan_flag(invals)
_interpret_jaxpr_with_nan_check->>_interpret_jaxpr_with_nan_check: output_has_nan = _has_nan_flag(outvals)
_interpret_jaxpr_with_nan_check->>_interpret_jaxpr_with_nan_check: nan_introduced = output_has_nan & ~input_has_nan
alt nan_introduced is True
_interpret_jaxpr_with_nan_check->>_interpret_jaxpr_with_nan_check: cb = _make_nan_callback(eqn_meta, phase)
_interpret_jaxpr_with_nan_check->>JaxCallback: jax.debug.callback(cb, *float_invals)
JaxCallback->>NanStore: _nan_store_get().append((msg, raw_source_info))
end
end
_interpret_jaxpr_with_nan_check-->>DebugNan: outvars
DebugNan->>DebugNan: _raise_if_nan_detected(store_snapshot)
DebugNan->>NanStore: records = _nan_store_get()
alt new NaN records
NanStore-->>DebugNan: msg, raw_src
DebugNan->>SourceInfoUtil: user_context(tb, name_stack)
SourceInfoUtil-->>User: raise RuntimeError(msg)
else no new records
DebugNan-->>User: return None (no NaN)
end
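The per-equation rule in the diagram (`nan_introduced = output_has_nan & ~input_has_nan`) can be illustrated with a small host-side analogy. This is a minimal sketch in plain Python, not the PR's on-device JAX interpreter: the equation format `(name, fn, in_names, out_name)` and the helpers `has_bad`, `safe_log`, and `interpret_with_nan_check` are hypothetical names introduced only for illustration.

```python
import math


def has_bad(vals):
    """Return True if any float in vals is NaN or +/-Inf."""
    return any(isinstance(v, float) and not math.isfinite(v) for v in vals)


def safe_log(x):
    """Logarithm that yields NaN for non-positive input instead of raising."""
    return math.log(x) if x > 0 else float("nan")


def interpret_with_nan_check(eqns, env):
    """Run equations in order and record only the ops that introduce NaN/Inf.

    An op is reported when its output is non-finite but all of its inputs
    were finite, so downstream ops that merely propagate NaN stay silent.
    """
    records = []
    for idx, (name, fn, in_names, out_name) in enumerate(eqns):
        invals = [env[n] for n in in_names]
        input_bad = has_bad(invals)
        out = fn(*invals)
        env[out_name] = out
        output_bad = has_bad([out])
        if output_bad and not input_bad:
            records.append((idx, name))
    return records


# Hypothetical three-op program: d = x - y; l = log(d); out = l * x
eqns = [
    ("sub", lambda a, b: a - b, ["x", "y"], "d"),
    ("log", safe_log, ["d"], "l"),
    ("mul", lambda a, b: a * b, ["l", "x"], "out"),
]
records = interpret_with_nan_check(eqns, {"x": 1.0, "y": 2.0})
```

Here only `log` is recorded: `mul` also produces NaN, but its input already carried one, so it is treated as propagation rather than introduction.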
Class diagram for DebugNan and NaN debugging helpers

classDiagram
class DebugNan {
- fn : Callable
- phase : str
- _stateful_fn : StatefulFunction
- _jaxpr
- _consts
- _flat_user_args : list
- _states : list
+ DebugNan(fn, *args, phase)
+ check() None
+ check_if(has_nan) None
- _build_flat_args() list
}
class _nan_store {
<<thread_local>>
+ records : list
}
class CacheKey {
<<NamedTuple>>
+ static_args : tuple
+ dyn_args : tuple
+ static_kwargs : tuple
+ dyn_kwargs : tuple
}
class _CachedCompilation {
<<internal>>
- jaxpr
- out_shapes
- out_treedef
- state_trace : StateTraceStack
- state_avals : tuple
+ _CachedCompilation(jaxpr, out_shapes, out_treedef, state_trace, state_avals)
}
class StatefulFunction {
- fun : Callable
- static_argnums : tuple
- axis_env
- name : str
- return_only_write : bool
- ir_optimizations : tuple
- _compilation_cache : BoundedCache
- _cache_lock : RLock
+ StatefulFunction(fun, static_argnums, axis_env, name, return_only_write, ir_optimizations)
+ get_arg_cache_key(*args, **kwargs) CacheKey
+ get_jaxpr_by_cache(cache_key) ClosedJaxpr
+ get_out_shapes_by_cache(cache_key) PyTree
+ get_out_treedef_by_cache(cache_key) PyTree
+ get_state_trace_by_cache(cache_key) StateTraceStack
+ get_jaxpr(*args, **kwargs) ClosedJaxpr
+ get_out_shapes(*args, **kwargs) PyTree
+ get_out_treedef(*args, **kwargs) PyTree
+ get_state_trace(*args, **kwargs) StateTraceStack
+ make_jaxpr(*args, **kwargs) StatefulFunction
+ jaxpr_call(state_vals, *args, **kwargs) Any
+ jaxpr_call_auto(*args, **kwargs) Any
+ debug_call(state_vals, *args, **kwargs) Any
+ validate_states(cache_key) bool
+ validate_all_states() Dict
+ clear_cache() None
+ get_cache_stats() Dict
+ __call__(*args, **kwargs) Any
- _make_new_arg() Callable
- _wrapped_fun_to_eval(_result_holder, static_kwargs, *args, **dyn_kwargs)
- _get_compilation(cache_key) _CachedCompilation
- _validate_state_shapes(cache_key) None
}
class BoundedCache {
+ maxsize : int
+ get(key, raise_on_miss, error_context)
+ set(key, value)
+ clear() None
+ keys() Iterable
+ get_stats() Dict
}
class StateTraceStack {
+ name : str
+ states : list
+ original_state_values : list
+ set_new_arg(fn) None
+ get_write_state_values(strict) list
+ get_state_values() list
}
class brainstate_transform_debug_module {
<<module-level functions>>
+ debug_nan(fn, *args, phase) None
+ debug_nan_if(has_nan, fn, *args, phase) None
+ breakpoint_if(pred, **breakpoint_kwargs)
+ _interpret_jaxpr_with_nan_check(jaxpr, consts, *flat_args, phase, raise_in_callback) list
+ _execute_eqn(eqn, invals, phase, raise_in_callback) list
+ _execute_jit_eqn(eqn, invals, phase, raise_in_callback) list
+ _execute_cond_eqn(eqn, invals, phase, raise_in_callback) list
+ _execute_while_eqn(eqn, invals, phase, raise_in_callback) list
+ _execute_scan_eqn(eqn, invals, phase, raise_in_callback) list
+ _make_nan_callback(eqn_idx, total_eqns, prim_name, eqn_str, source_loc, raw_source_info, phase, raise_in_callback)
+ _has_nan_flag(vals) jax.Array
+ _is_float_array(x) bool
+ _raise_if_nan_detected(store_snapshot) None
+ _extract_user_source(source_info) str
}
DebugNan --> StatefulFunction : uses for jaxpr compilation
DebugNan --> _nan_store : writes NaN records via callbacks
DebugNan --> brainstate_transform_debug_module : calls _interpret_jaxpr_with_nan_check
StatefulFunction --> CacheKey : returns from get_arg_cache_key
StatefulFunction --> _CachedCompilation : stores in _compilation_cache
StatefulFunction --> BoundedCache : uses for compilation cache
StatefulFunction --> StateTraceStack : tracks State accesses
_CachedCompilation --> StateTraceStack : holds
brainstate_transform_debug_module --> _nan_store : appends records
brainstate_transform_debug_module --> StatefulFunction : via DebugNan
brainstate_transform_debug_module --> CacheKey : indirectly via StatefulFunction
Hey - I've left some high level feedback:
- The thread-local `_nan_store` list is never cleared and only grows over time; consider truncating it or resetting it after `_raise_if_nan_detected` (or exposing an explicit reset) to avoid unbounded memory growth in long-running processes.
- In `_extract_user_source`, filtering traceback lines with the hard-coded `'/site-packages/'` substring may fail on Windows paths or nonstandard install layouts; it would be more robust to detect third-party frames via `site`/`sysconfig` paths or `pathlib`-based checks instead of a literal string.
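One way the second suggestion could look is sketched below, using only the standard library. It is an assumption about how such a check might be written, not code from the PR: `third_party_roots` and `is_third_party` are hypothetical helper names, and only the `purelib` and `platlib` install paths reported by `sysconfig` are treated as third-party.

```python
import sysconfig
from pathlib import Path


def third_party_roots():
    """Collect the interpreter's package install directories as resolved paths."""
    roots = set()
    paths = sysconfig.get_paths()
    for key in ("purelib", "platlib"):
        location = paths.get(key)
        if location:
            roots.add(Path(location).resolve())
    return roots


def is_third_party(filename, roots=None):
    """Return True if filename lives under one of the install roots.

    Comparing path components avoids substring matching on '/site-packages/',
    so the check does not depend on the path separator.
    """
    roots = third_party_roots() if roots is None else roots
    try:
        path = Path(filename).resolve()
    except (OSError, ValueError):
        return False
    return any(root in path.parents for root in roots)
```

A traceback filter could then keep only frames for which `is_third_party(frame_filename)` is false; virtual environments and editable installs may need extra roots.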
Pull request overview
This PR refactors two key subsystems in brainstate: the NaN debugging pipeline and the StatefulFunction compilation cache. The NaN debugging is rewritten to perform on-device detection using an instrumented jaxpr interpreter (eliminating CPU fallback), while the StatefulFunction cache is unified from four separate caches into a single CacheKey → _CachedCompilation cache with added state shape validation.
Changes:
- Replaces the CPU-based NaN debugging with an on-device instrumented jaxpr interpreter that detects NaN/Inf introduction per-primitive, reports IDE-clickable source locations, and integrates with JAX's traceback filtering.
- Unifies `StatefulFunction`'s four separate bounded caches (`jaxpr_cache`, `out_shapes_cache`, `jaxpr_out_tree_cache`, `state_trace_cache`) into a single `_compilation_cache` keyed by an immutable `CacheKey` NamedTuple, with double-check locking and state shape/dtype validation.
- Simplifies the gradient transform's `debug_nan` integration by removing unused `depth`/`context` parameters and updating imports.
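The unified-cache idea in the second bullet can be sketched as follows. The `CacheKey` field names come from the class diagram in this PR; everything else (the `CompilationCache` class, `get_or_compile`, the `compile_count` counter, and the use of a plain dict instead of `BoundedCache`) is a simplifying assumption for illustration and not the actual `StatefulFunction` implementation.

```python
import threading
from typing import Any, Callable, Dict, NamedTuple


class CacheKey(NamedTuple):
    """Immutable, hashable key; field names follow the PR's class diagram."""
    static_args: tuple
    dyn_args: tuple
    static_kwargs: tuple
    dyn_kwargs: tuple


class CompilationCache:
    """Hypothetical single cache holding one compiled entry per CacheKey."""

    def __init__(self) -> None:
        self._entries: Dict[CacheKey, Any] = {}
        self._lock = threading.RLock()
        self.compile_count = 0

    def get_or_compile(self, key: CacheKey, compile_fn: Callable[[], Any]) -> Any:
        # First check without the lock: the common cache-hit path stays cheap.
        entry = self._entries.get(key)
        if entry is not None:
            return entry
        with self._lock:
            # Second check under the lock: another thread may have compiled
            # the same key while this one was waiting.
            entry = self._entries.get(key)
            if entry is None:
                entry = compile_fn()
                self.compile_count += 1
                self._entries[key] = entry
            return entry


cache = CompilationCache()
key = CacheKey(static_args=(), dyn_args=(((3,), "float32"),),
               static_kwargs=(), dyn_kwargs=())
first = cache.get_or_compile(key, lambda: {"jaxpr": "compiled"})
second = cache.get_or_compile(key, lambda: {"jaxpr": "compiled again"})
```

Because everything derived from one trace lives in a single entry, a hit or miss applies to the jaxpr, output shapes, tree definition, and state trace together, which is the consistency benefit of merging the four caches.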
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `brainstate/transform/_debug.py` | Complete rewrite of NaN debugging: on-device detection, thread-local store, source info extraction, instrumented jaxpr interpreter with nested primitive support |
| `brainstate/transform/_make_jaxpr.py` | Unified compilation cache with `CacheKey`/`_CachedCompilation`, state shape validation, IR optimization integration, improved `_make_hashable` |
| `brainstate/transform/_debug_test.py` | Replaced and expanded tests for the new NaN detection pipeline |
| `brainstate/transform/_make_jaxpr_test.py` | Updated tests for unified cache, `CacheKey` immutability, state shape validation, IR optimizations |
| `brainstate/transform/_grad_transform.py` | Removed unused `debug_depth`/`debug_context` parameters |
| `brainstate/transform/_mapping2.py` | Fixed `BoundedCache` import to use `brainstate.util._cache` |
| `brainstate/__init__.py` | Registered brainstate with JAX's traceback filter |
| `examples/011_debug_nan_gradient.py` | Removed `exp_exprel` example, uncommented all standard examples |
    for eqn, meta in zip(jaxpr.eqns, eqn_meta):
        invals = _get_invals(eqn, env)

        # On-device: does the input already carry NaN?
        input_has_nan = _has_nan_flag(invals)

        # Execute the primitive (with recursive instrumentation for nested ones)
        outvals = _execute_eqn(eqn, invals, phase, raise_in_callback)

        # On-device: did this op introduce new NaN?
        output_has_nan = _has_nan_flag(outvals)
        nan_introduced = output_has_nan & ~input_has_nan

For nested high-level primitives (`jit`, `cond`, `while`, `scan`), NaN is detected both inside the recursively-instrumented inner jaxpr and at the outer level in `_interpret_jaxpr_with_nan_check`. This happens because `_execute_eqn` dispatches to `_execute_jit_eqn` etc., which instruments the inner jaxpr (and fires a callback if NaN is introduced inside), but the outer loop also checks `nan_introduced = output_has_nan & ~input_has_nan` for the containing equation, which will be True because the inner computation introduced NaN from clean inputs.
This leads to double-reporting of the same NaN: once for the inner primitive (e.g., `log`) and once for the outer primitive (e.g., `pjit`). Consider skipping the outer NaN check for expandable/nested primitives, since the inner instrumentation already handles them.
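The suggested fix could be as small as a guard on the primitive name. The sketch below is hypothetical: `NESTED_PRIMITIVES` and `should_report_at_outer_level` are illustrative names, the primitive name strings are assumptions that may differ between JAX versions, and plain booleans stand in for the on-device flags.

```python
# Primitives whose bodies are re-interpreted recursively. The exact name
# strings are assumptions and may vary across JAX versions.
NESTED_PRIMITIVES = frozenset({"pjit", "cond", "while", "scan"})


def should_report_at_outer_level(prim_name, input_has_nan, output_has_nan):
    """Decide whether the outer loop should report this equation.

    Nested primitives are skipped because the recursive instrumentation of
    their inner jaxpr already reports the op that introduced the NaN.
    """
    if prim_name in NESTED_PRIMITIVES:
        return False
    return output_has_nan and not input_has_nan
```

With this guard, a NaN produced by `log` inside a jitted sub-function is reported once, at `log`, instead of a second time at the enclosing `pjit`.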
    cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=True)
    self._validate_state_shapes(cache_key)
    return self.jaxpr_call_auto(*args, **kwargs)

The `__call__` method computes `cache_key` via `get_arg_cache_key(*args, **kwargs, compile_if_miss=True)` and then immediately calls `jaxpr_call_auto(*args, **kwargs)`, which internally computes the exact same cache key again (line 1132). This means every user-facing call redundantly computes the cache key twice. Consider passing the already-computed `cache_key` to `jaxpr_call_auto` (or extracting a shared internal method) to avoid the redundant work.
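A shared internal method is one way to remove the duplicate key computation. The class below is a toy stand-in, not `StatefulFunction`: `KeyedFunction`, `_cache_key`, `_call_with_key`, and the `key_computations` counter are hypothetical names used to show the shape of the refactor.

```python
class KeyedFunction:
    """Toy wrapper showing how both entry points can share one key lookup."""

    def __init__(self, fun):
        self.fun = fun
        self.key_computations = 0  # instrumentation for this example only

    def _cache_key(self, *args, **kwargs):
        # Stand-in for the (potentially expensive) cache key computation.
        self.key_computations += 1
        arg_part = tuple(type(a).__name__ for a in args)
        kwarg_part = tuple(sorted((k, type(v).__name__) for k, v in kwargs.items()))
        return (arg_part, kwarg_part)

    def _call_with_key(self, key, *args, **kwargs):
        # A real implementation would look up the compiled entry by key here.
        return self.fun(*args, **kwargs)

    def jaxpr_call_auto(self, *args, **kwargs):
        # Public entry point: computes the key itself, exactly once.
        return self._call_with_key(self._cache_key(*args, **kwargs), *args, **kwargs)

    def __call__(self, *args, **kwargs):
        key = self._cache_key(*args, **kwargs)
        # State-shape validation would use `key` here.
        return self._call_with_key(key, *args, **kwargs)


f = KeyedFunction(lambda x, y=1: x + y)
result = f(2, y=3)
```

Both public entry points keep their signatures, while `__call__` reuses the key it already needs for validation.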
    def _raise_if_nan_detected(store_snapshot: int) -> None:
        """
        After an instrumented run, check the thread-local store for new NaN records.

        Uses ``source_info_util.user_context`` — the same mechanism as
        ``State.raise_error_with_source_info`` — so that JAX's traceback filtering
        shows the *user code* that introduced the NaN rather than library internals.
        """
        records = _nan_store_get()
        new = records[store_snapshot:]
        if not new:
            return
        msg, raw_src = new[0]
        if raw_src is not None:
            tb = getattr(raw_src, 'traceback', None)
            name_stack = (
                source_info_util.current_name_stack()
                + getattr(raw_src, 'name_stack', source_info_util.NameStack())
            )
            with source_info_util.user_context(tb, name_stack=name_stack):
                raise RuntimeError(msg)
        raise RuntimeError(msg)

The thread-local `_nan_store.records` list is only ever appended to (line 214) but never cleared or trimmed after records are consumed by `_raise_if_nan_detected`. In long-running applications or test suites, this list will grow without bound, leaking memory. After raising (or deciding not to raise), the consumed records should be removed, e.g. `del records[store_snapshot:]` at the end of `_raise_if_nan_detected`, or after each `check()` / `check_if()` call.
…line gradient examples
…itive that introduced NaN
Summary by Sourcery
Introduce a new on-device NaN/Inf debugging pipeline integrated with stateful JAX execution and simplify caching, while updating tests and examples accordingly.