@@ -8,6 +8,11 @@ import SciMLSensitivity: get_paramjac_config, reactant_run_ad!, reactant_run_dua
88 ReactantVJP, ReactantLoaded, ReactantVJPConfig, ReactantDualTag,
99 get_cb_paramjac_config, reactant_run_cb_ad!
1010
11+ # Global cache for compiled Reactant kernels. Keyed by
12+ # (typeof(raw_f), iip, allow_scalar, length(y), typeof(p), length(p))
13+ # so that repeated gradient calls with the same ODE function skip recompilation.
14+ const _REACTANT_KERNEL_CACHE = Ref (Dict {Any, Any} ())
15+
1116# Helper: conditionally wrap Reactant.compile in @allowscalar
1217function _reactant_compile (kernel, args, allow_scalar:: Bool )
1318 if allow_scalar
@@ -133,22 +138,32 @@ end
133138# =============================================================================
134139
135140function _compile_float_kernel (raw_f, iip, vjp, y, p, t_val)
141+ key = (typeof (raw_f), iip, vjp. allow_scalar, length (y), typeof (p), length (p))
142+ cached = get (_REACTANT_KERNEL_CACHE[], key, nothing )
143+ cached != = nothing && return cached
136144 vjp_kernel = _make_vjp_kernel (raw_f, iip)
137145 dy_buf = Reactant. to_rarray (zero (y))
138146 u_ra = Reactant. to_rarray (zero (y))
139147 p_ra = Reactant. to_rarray (zero (p))
140148 t_ra = Reactant. to_rarray (t_val; track_numbers = true )
141149 λ_ra = Reactant. to_rarray (zero (y))
142- return _reactant_compile (vjp_kernel, (dy_buf, u_ra, p_ra, t_ra, λ_ra), vjp. allow_scalar)
150+ compiled = _reactant_compile (vjp_kernel, (dy_buf, u_ra, p_ra, t_ra, λ_ra), vjp. allow_scalar)
151+ _REACTANT_KERNEL_CACHE[][key] = compiled
152+ return compiled
143153end
144154
145155function _compile_float_kernel_nullparams (raw_f, iip, vjp, y, t_val)
156+ key = (typeof (raw_f), iip, vjp. allow_scalar, length (y), :nullparams )
157+ cached = get (_REACTANT_KERNEL_CACHE[], key, nothing )
158+ cached != = nothing && return cached
146159 vjp_kernel = _make_vjp_kernel_nullparams (raw_f, iip)
147160 dy_buf = Reactant. to_rarray (zero (y))
148161 u_ra = Reactant. to_rarray (zero (y))
149162 t_ra = Reactant. to_rarray (t_val; track_numbers = true )
150163 λ_ra = Reactant. to_rarray (zero (y))
151- return _reactant_compile (vjp_kernel, (dy_buf, u_ra, t_ra, λ_ra), vjp. allow_scalar)
164+ compiled = _reactant_compile (vjp_kernel, (dy_buf, u_ra, t_ra, λ_ra), vjp. allow_scalar)
165+ _REACTANT_KERNEL_CACHE[][key] = compiled
166+ return compiled
152167end
153168
154169# =============================================================================
0 commit comments