You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Cache compiled Reactant kernels to avoid recompilation on repeated gradient calls
ReactantVJP recompiles the XLA kernel on every Zygote.gradient call because
get_paramjac_config rebuilds the adjoint cache each time. This adds a global
Dict cache keyed by (typeof(raw_f), iip, allow_scalar, array sizes/types)
so that repeated calls with the same ODE function return the cached kernel
instantly (~0.1ms) instead of recompiling (~1.5s).
Benchmark on a chromatography ODE model:
- 2nd gradient call: 21.8s → 17.4s (~20% faster)
- Reactant.compile: 1.55s → 0.000076s (Dict lookup)
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
0 commit comments