Skip to content

Commit 5fe3191

Browse files
Cache compiled Reactant kernels to avoid recompilation on repeated gradient calls
ReactantVJP recompiled the XLA kernel on every Zygote.gradient call because get_paramjac_config rebuilds the adjoint cache each time. This commit adds a global Dict cache keyed by (typeof(raw_f), iip, allow_scalar, array sizes/types) so that repeated calls with the same ODE function return the cached kernel almost instantly (~0.1 ms) instead of recompiling (~1.5 s).

Benchmark on a chromatography ODE model:
- 2nd gradient call: 21.8 s → 17.4 s (~20% faster)
- Reactant.compile: 1.55 s → 0.000076 s (Dict lookup)

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent 809b943 commit 5fe3191

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

ext/SciMLSensitivityReactantExt.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ import SciMLSensitivity: get_paramjac_config, reactant_run_ad!, reactant_run_dua
88
ReactantVJP, ReactantLoaded, ReactantVJPConfig, ReactantDualTag,
99
get_cb_paramjac_config, reactant_run_cb_ad!
1010

11+
# Global cache for compiled Reactant kernels. Keyed by
12+
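# a tuple built in `_compile_float_kernel` / `_compile_float_kernel_nullparams`
# (it identifies the ODE function, the iip/allow_scalar flags and the argument shapes/types)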
13+
# so that repeated gradient calls with the same ODE function skip recompilation.
14+
# Cache of compiled Reactant kernels shared by the `_compile_float_kernel*`
# functions. The `Dict{Any, Any}` is wrapped in a `Ref`, so callers reach the
# dictionary with `_REACTANT_KERNEL_CACHE[]`. Entries are only ever added by the
# code visible here; nothing in view evicts them.
# NOTE(review): no lock is taken around the lookups/inserts in view — confirm
# whether kernels can be compiled from multiple threads concurrently.
const _REACTANT_KERNEL_CACHE = Ref(Dict{Any, Any}())
15+
1116
# Helper: conditionally wrap Reactant.compile in @allowscalar
1217
function _reactant_compile(kernel, args, allow_scalar::Bool)
1318
if allow_scalar
@@ -133,22 +138,32 @@ end
133138
# =============================================================================
134139

135140
"""
    _compile_float_kernel(raw_f, iip, vjp, y, p, t_val)

Return the compiled Reactant VJP kernel for `raw_f`, compiling it only the first
time a given cache key is seen and returning the stored kernel afterwards.

# Arguments
- `raw_f`: the ODE right-hand side passed to `_make_vjp_kernel`.
- `iip`: in-place flag forwarded to `_make_vjp_kernel`.
- `vjp`: VJP configuration; only `vjp.allow_scalar` is read here.
- `y`: state prototype; `zero(y)` seeds the `dy`, `u` and `λ` trace buffers.
- `p`: parameter prototype; `zero(p)` seeds the parameter trace buffer.
- `t_val`: time value converted with `track_numbers = true`.

# Returns
Whatever `_reactant_compile` returns for the traced kernel.
"""
function _compile_float_kernel(raw_f, iip, vjp, y, p, t_val)
    allow_scalar = vjp.allow_scalar
    # The key must pin down everything the compiled kernel is specialized on:
    # - `raw_f` itself rather than `typeof(raw_f)`: closures of the same type can
    #   capture different values, and a kernel traced for one must not be reused
    #   for another;
    # - the container type and full `size` (not just `length`) of `y` and `p`, so
    #   arrays with equal length but different shape or element type do not collide;
    # - the type of `t_val`, since it is traced as a number as well.
    key = (
        raw_f, iip, allow_scalar,
        typeof(y), size(y), typeof(p), size(p), typeof(t_val),
    )
    return get!(_REACTANT_KERNEL_CACHE[], key) do
        vjp_kernel = _make_vjp_kernel(raw_f, iip)
        dy_buf = Reactant.to_rarray(zero(y))
        u_ra = Reactant.to_rarray(zero(y))
        p_ra = Reactant.to_rarray(zero(p))
        t_ra = Reactant.to_rarray(t_val; track_numbers = true)
        λ_ra = Reactant.to_rarray(zero(y))
        _reactant_compile(
            vjp_kernel, (dy_buf, u_ra, p_ra, t_ra, λ_ra), allow_scalar
        )
    end
end
144154

145155
"""
    _compile_float_kernel_nullparams(raw_f, iip, vjp, y, t_val)

Return the compiled Reactant VJP kernel for `raw_f` in the no-parameters case,
compiling it only the first time a given cache key is seen and returning the
stored kernel afterwards.

# Arguments
- `raw_f`: the ODE right-hand side passed to `_make_vjp_kernel_nullparams`.
- `iip`: in-place flag forwarded to `_make_vjp_kernel_nullparams`.
- `vjp`: VJP configuration; only `vjp.allow_scalar` is read here.
- `y`: state prototype; `zero(y)` seeds the `dy`, `u` and `λ` trace buffers.
- `t_val`: time value converted with `track_numbers = true`.

# Returns
Whatever `_reactant_compile` returns for the traced kernel.
"""
function _compile_float_kernel_nullparams(raw_f, iip, vjp, y, t_val)
    allow_scalar = vjp.allow_scalar
    # The key must pin down everything the compiled kernel is specialized on:
    # `raw_f` itself (closures of one type can capture different values), the
    # container type and full `size` of `y` (equal `length` does not imply equal
    # shape or element type), and the type of `t_val`. The `:nullparams` tag keeps
    # these entries apart from kernels compiled with parameters.
    key = (
        raw_f, iip, allow_scalar,
        typeof(y), size(y), typeof(t_val), :nullparams,
    )
    return get!(_REACTANT_KERNEL_CACHE[], key) do
        vjp_kernel = _make_vjp_kernel_nullparams(raw_f, iip)
        dy_buf = Reactant.to_rarray(zero(y))
        u_ra = Reactant.to_rarray(zero(y))
        t_ra = Reactant.to_rarray(t_val; track_numbers = true)
        λ_ra = Reactant.to_rarray(zero(y))
        _reactant_compile(vjp_kernel, (dy_buf, u_ra, t_ra, λ_ra), allow_scalar)
    end
end
153168

154169
# =============================================================================

0 commit comments

Comments
 (0)