Skip to content

Commit 02b9deb

Browse files
Merge pull request #1380 from ChrisRackauckas-Claude/cache-reactant-kernel
Cache compiled Reactant kernels across gradient calls
2 parents 809b943 + 5fe3191 commit 02b9deb

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

ext/SciMLSensitivityReactantExt.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ import SciMLSensitivity: get_paramjac_config, reactant_run_ad!, reactant_run_dua
88
ReactantVJP, ReactantLoaded, ReactantVJPConfig, ReactantDualTag,
99
get_cb_paramjac_config, reactant_run_cb_ad!
1010

11+
# Global cache for compiled Reactant kernels. Keyed by
12+
# (typeof(raw_f), iip, allow_scalar, length(y), typeof(p), length(p))
13+
# so that repeated gradient calls with the same ODE function skip recompilation.
14+
const _REACTANT_KERNEL_CACHE = Ref(Dict{Any, Any}())
15+
1116
# Helper: conditionally wrap Reactant.compile in @allowscalar
1217
function _reactant_compile(kernel, args, allow_scalar::Bool)
1318
if allow_scalar
@@ -133,22 +138,32 @@ end
133138
# =============================================================================
134139

135140
function _compile_float_kernel(raw_f, iip, vjp, y, p, t_val)
141+
key = (typeof(raw_f), iip, vjp.allow_scalar, length(y), typeof(p), length(p))
142+
cached = get(_REACTANT_KERNEL_CACHE[], key, nothing)
143+
cached !== nothing && return cached
136144
vjp_kernel = _make_vjp_kernel(raw_f, iip)
137145
dy_buf = Reactant.to_rarray(zero(y))
138146
u_ra = Reactant.to_rarray(zero(y))
139147
p_ra = Reactant.to_rarray(zero(p))
140148
t_ra = Reactant.to_rarray(t_val; track_numbers = true)
141149
λ_ra = Reactant.to_rarray(zero(y))
142-
return _reactant_compile(vjp_kernel, (dy_buf, u_ra, p_ra, t_ra, λ_ra), vjp.allow_scalar)
150+
compiled = _reactant_compile(vjp_kernel, (dy_buf, u_ra, p_ra, t_ra, λ_ra), vjp.allow_scalar)
151+
_REACTANT_KERNEL_CACHE[][key] = compiled
152+
return compiled
143153
end
144154

145155
function _compile_float_kernel_nullparams(raw_f, iip, vjp, y, t_val)
156+
key = (typeof(raw_f), iip, vjp.allow_scalar, length(y), :nullparams)
157+
cached = get(_REACTANT_KERNEL_CACHE[], key, nothing)
158+
cached !== nothing && return cached
146159
vjp_kernel = _make_vjp_kernel_nullparams(raw_f, iip)
147160
dy_buf = Reactant.to_rarray(zero(y))
148161
u_ra = Reactant.to_rarray(zero(y))
149162
t_ra = Reactant.to_rarray(t_val; track_numbers = true)
150163
λ_ra = Reactant.to_rarray(zero(y))
151-
return _reactant_compile(vjp_kernel, (dy_buf, u_ra, t_ra, λ_ra), vjp.allow_scalar)
164+
compiled = _reactant_compile(vjp_kernel, (dy_buf, u_ra, t_ra, λ_ra), vjp.allow_scalar)
165+
_REACTANT_KERNEL_CACHE[][key] = compiled
166+
return compiled
152167
end
153168

154169
# =============================================================================

0 commit comments

Comments
 (0)