@@ -182,24 +182,27 @@ struct EK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
182182 initialization:: IT
183183 pn_observation_noise:: RT
184184 covariance_factorization:: CF
185+ autodiff:: AD
185186 EK1 (;
186187 order= 3 ,
187188 prior:: PT = IWP (order),
188189 diffusionmodel:: DT = DynamicDiffusion (),
189190 smooth= true ,
190191 initialization:: IT = TaylorModeInit (num_derivatives (prior)),
191192 chunk_size= Val {0} (),
192- autodiff= Val {true} (),
193- diff_type= Val{:forward },
193+ autodiff= AutoForwardDiff (),
194+ diff_type= Val {:forward} () ,
194195 standardtag= Val {true} (),
195196 concrete_jac= nothing ,
196197 pn_observation_noise:: RT = nothing ,
197198 covariance_factorization:: CF = covariance_structure (EK1, prior, diffusionmodel),
198199 ) where {PT,DT,IT,RT,CF} = begin
199200 ekargcheck (EK1; diffusionmodel, pn_observation_noise, covariance_factorization)
201+ AD_choice, chunk_size, diff_type =
202+ OrdinaryDiffEqCore. _process_AD_choice (autodiff, chunk_size, diff_type)
200203 new{
201204 _unwrap_val (chunk_size),
202- _unwrap_val (autodiff ),
205+ typeof (AD_choice ),
203206 diff_type,
204207 _unwrap_val (standardtag),
205208 _unwrap_val (concrete_jac),
@@ -215,6 +218,7 @@ struct EK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
215218 initialization,
216219 pn_observation_noise,
217220 covariance_factorization,
221+ AD_choice
218222 )
219223 end
220224end
@@ -226,15 +230,16 @@ struct DiagonalEK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
226230 initialization:: IT
227231 pn_observation_noise:: RT
228232 covariance_factorization:: CF
233+ autodiff:: AD
229234 DiagonalEK1 (;
230235 order= 3 ,
231236 prior:: PT = IWP (order),
232237 diffusionmodel:: DT = DynamicDiffusion (),
233238 smooth= true ,
234239 initialization:: IT = TaylorModeInit (num_derivatives (prior)),
235240 chunk_size= Val {0} (),
236- autodiff= Val {true} (),
237- diff_type= Val{:forward },
241+ autodiff= AutoForwardDiff (),
242+ diff_type= Val {:forward} () ,
238243 standardtag= Val {true} (),
239244 concrete_jac= nothing ,
240245 pn_observation_noise:: RT = nothing ,
@@ -245,9 +250,11 @@ struct DiagonalEK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
245250 ),
246251 ) where {PT,DT,IT,RT,CF} = begin
247252 ekargcheck (DiagonalEK1; diffusionmodel, pn_observation_noise, covariance_factorization)
253+ AD_choice, chunk_size, diff_type =
254+ OrdinaryDiffEqCore. _process_AD_choice (autodiff, chunk_size, diff_type)
248255 new{
249256 _unwrap_val (chunk_size),
250- _unwrap_val (autodiff ),
257+ typeof (AD_choice ),
251258 diff_type,
252259 _unwrap_val (standardtag),
253260 _unwrap_val (concrete_jac),
@@ -263,6 +270,7 @@ struct DiagonalEK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
263270 initialization,
264271 pn_observation_noise,
265272 covariance_factorization,
273+ AD_choice
266274 )
267275 end
268276end
@@ -334,16 +342,17 @@ RosenbrockExpEK(; order=3, kwargs...) =
334342 EK1 (; prior= IOUP (order, update_rate_parameter= true ), kwargs... )
335343
336344function DiffEqBase. remake (thing:: EK1{CS,AD,DT,ST,CJ} ; kwargs... ) where {CS,AD,DT,ST,CJ}
345+ if haskey (kwargs, :autodiff ) && kwargs[:autodiff ] isa AutoForwardDiff
346+ chunk_size = OrdinaryDiffEqCore. _get_fwd_chunksize (kwargs[:autodiff ])
347+ else
348+ chunk_size = Val {CS} ()
349+ end
350+
337351 T = SciMLBase. remaker_of (thing)
338- T (;
339- SciMLBase. struct_as_namedtuple (thing)... ,
340- chunk_size= Val {CS} (),
341- autodiff= Val {AD} (),
342- standardtag= Val {ST} (),
352+ T (; SciMLBase. struct_as_namedtuple (thing)... ,
353+ chunk_size= chunk_size, autodiff= thing. autodiff, standardtag= Val {ST} (),
343354 concrete_jac= CJ === nothing ? CJ : Val {CJ} (),
344- diff_type= DT,
345- kwargs... ,
346- )
355+ kwargs... )
347356end
348357
349358function DiffEqBase. prepare_alg (
@@ -357,21 +366,25 @@ function DiffEqBase.prepare_alg(
357366 # use the prepare_alg from OrdinaryDiffEqCore; but right now, we do not use `linsolve` which
358367 # is a requirement.
359368
360- if (isbitstype (T) && sizeof (T) > 24 ) || (
361- prob. f isa ODEFunction &&
362- prob. f. f isa FunctionWrappersWrappers. FunctionWrappersWrapper
363- )
364- return remake (alg, chunk_size= Val {1} ())
365- end
369+ prepped_AD = OrdinaryDiffEqDifferentiation. prepare_ADType (OrdinaryDiffEqDifferentiation. alg_autodiff (alg), prob, u0, p, OrdinaryDiffEqDifferentiation. standardtag (alg))
370+
371+ sparse_prepped_AD = OrdinaryDiffEqDifferentiation. prepare_user_sparsity (prepped_AD, prob)
366372
367373 L = StaticArrayInterface. known_length (typeof (u0))
368374 @assert L === nothing " ProbNumDiffEq.jl does not support StaticArrays yet."
369375
370- x = if prob. f. colorvec === nothing
371- length (u0)
376+ if (
377+ (
378+ (eltype (u0) <: Complex ) ||
379+ (! (prob. f isa DAEFunction) && prob. f. mass_matrix isa MatrixOperator)
380+ ) && sparse_prepped_AD isa AutoSparse
381+ )
382+ @warn " Input type or problem definition is incompatible with sparse automatic differentiation. Switching to using dense automatic differentiation."
383+ autodiff = ADTypes. dense_ad (sparse_prepped_AD)
372384 else
373- maximum (prob . f . colorvec)
385+ autodiff = sparse_prepped_AD
374386 end
375- cs = ForwardDiff. pickchunksize (x)
376- return remake (alg, chunk_size= Val {cs} ())
387+
388+
389+ return remake (alg, autodiff = autodiff)
377390end
0 commit comments