@@ -182,24 +182,27 @@ struct EK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
182
182
initialization:: IT
183
183
pn_observation_noise:: RT
184
184
covariance_factorization:: CF
185
+ autodiff:: AD
185
186
EK1 (;
186
187
order= 3 ,
187
188
prior:: PT = IWP (order),
188
189
diffusionmodel:: DT = DynamicDiffusion (),
189
190
smooth= true ,
190
191
initialization:: IT = TaylorModeInit (num_derivatives (prior)),
191
192
chunk_size= Val {0} (),
192
- autodiff= Val {true} (),
193
- diff_type= Val{:forward },
193
+ autodiff= AutoForwardDiff (),
194
+ diff_type= Val {:forward} () ,
194
195
standardtag= Val {true} (),
195
196
concrete_jac= nothing ,
196
197
pn_observation_noise:: RT = nothing ,
197
198
covariance_factorization:: CF = covariance_structure (EK1, prior, diffusionmodel),
198
199
) where {PT,DT,IT,RT,CF} = begin
199
200
ekargcheck (EK1; diffusionmodel, pn_observation_noise, covariance_factorization)
201
+ AD_choice, chunk_size, diff_type =
202
+ OrdinaryDiffEqCore. _process_AD_choice (autodiff, chunk_size, diff_type)
200
203
new{
201
204
_unwrap_val (chunk_size),
202
- _unwrap_val (autodiff ),
205
+ typeof (AD_choice ),
203
206
diff_type,
204
207
_unwrap_val (standardtag),
205
208
_unwrap_val (concrete_jac),
@@ -215,6 +218,7 @@ struct EK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
215
218
initialization,
216
219
pn_observation_noise,
217
220
covariance_factorization,
221
+ AD_choice
218
222
)
219
223
end
220
224
end
@@ -226,15 +230,16 @@ struct DiagonalEK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
226
230
initialization:: IT
227
231
pn_observation_noise:: RT
228
232
covariance_factorization:: CF
233
+ autodiff:: AD
229
234
DiagonalEK1 (;
230
235
order= 3 ,
231
236
prior:: PT = IWP (order),
232
237
diffusionmodel:: DT = DynamicDiffusion (),
233
238
smooth= true ,
234
239
initialization:: IT = TaylorModeInit (num_derivatives (prior)),
235
240
chunk_size= Val {0} (),
236
- autodiff= Val {true} (),
237
- diff_type= Val{:forward },
241
+ autodiff= AutoForwardDiff (),
242
+ diff_type= Val {:forward} () ,
238
243
standardtag= Val {true} (),
239
244
concrete_jac= nothing ,
240
245
pn_observation_noise:: RT = nothing ,
@@ -245,9 +250,11 @@ struct DiagonalEK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
245
250
),
246
251
) where {PT,DT,IT,RT,CF} = begin
247
252
ekargcheck (DiagonalEK1; diffusionmodel, pn_observation_noise, covariance_factorization)
253
+ AD_choice, chunk_size, diff_type =
254
+ OrdinaryDiffEqCore. _process_AD_choice (autodiff, chunk_size, diff_type)
248
255
new{
249
256
_unwrap_val (chunk_size),
250
- _unwrap_val (autodiff ),
257
+ typeof (AD_choice ),
251
258
diff_type,
252
259
_unwrap_val (standardtag),
253
260
_unwrap_val (concrete_jac),
@@ -263,6 +270,7 @@ struct DiagonalEK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
263
270
initialization,
264
271
pn_observation_noise,
265
272
covariance_factorization,
273
+ AD_choice
266
274
)
267
275
end
268
276
end
@@ -334,16 +342,17 @@ RosenbrockExpEK(; order=3, kwargs...) =
334
342
EK1 (; prior= IOUP (order, update_rate_parameter= true ), kwargs... )
335
343
336
344
function DiffEqBase. remake (thing:: EK1{CS,AD,DT,ST,CJ} ; kwargs... ) where {CS,AD,DT,ST,CJ}
345
+ if haskey (kwargs, :autodiff ) && kwargs[:autodiff ] isa AutoForwardDiff
346
+ chunk_size = OrdinaryDiffEqCore. _get_fwd_chunksize (kwargs[:autodiff ])
347
+ else
348
+ chunk_size = Val {CS} ()
349
+ end
350
+
337
351
T = SciMLBase. remaker_of (thing)
338
- T (;
339
- SciMLBase. struct_as_namedtuple (thing)... ,
340
- chunk_size= Val {CS} (),
341
- autodiff= Val {AD} (),
342
- standardtag= Val {ST} (),
352
+ T (; SciMLBase. struct_as_namedtuple (thing)... ,
353
+ chunk_size= chunk_size, autodiff= thing. autodiff, standardtag= Val {ST} (),
343
354
concrete_jac= CJ === nothing ? CJ : Val {CJ} (),
344
- diff_type= DT,
345
- kwargs... ,
346
- )
355
+ kwargs... )
347
356
end
348
357
349
358
function DiffEqBase. prepare_alg (
@@ -357,21 +366,25 @@ function DiffEqBase.prepare_alg(
357
366
# use the prepare_alg from OrdinaryDiffEqCore; but right now, we do not use `linsolve` which
358
367
# is a requirement.
359
368
360
- if (isbitstype (T) && sizeof (T) > 24 ) || (
361
- prob. f isa ODEFunction &&
362
- prob. f. f isa FunctionWrappersWrappers. FunctionWrappersWrapper
363
- )
364
- return remake (alg, chunk_size= Val {1} ())
365
- end
369
+ prepped_AD = OrdinaryDiffEqDifferentiation. prepare_ADType (OrdinaryDiffEqDifferentiation. alg_autodiff (alg), prob, u0, p, OrdinaryDiffEqDifferentiation. standardtag (alg))
370
+
371
+ sparse_prepped_AD = OrdinaryDiffEqDifferentiation. prepare_user_sparsity (prepped_AD, prob)
366
372
367
373
L = StaticArrayInterface. known_length (typeof (u0))
368
374
@assert L === nothing " ProbNumDiffEq.jl does not support StaticArrays yet."
369
375
370
- x = if prob. f. colorvec === nothing
371
- length (u0)
376
+ if (
377
+ (
378
+ (eltype (u0) <: Complex ) ||
379
+ (! (prob. f isa DAEFunction) && prob. f. mass_matrix isa MatrixOperator)
380
+ ) && sparse_prepped_AD isa AutoSparse
381
+ )
382
+ @warn " Input type or problem definition is incompatible with sparse automatic differentiation. Switching to using dense automatic differentiation."
383
+ autodiff = ADTypes. dense_ad (sparse_prepped_AD)
372
384
else
373
- maximum (prob . f . colorvec)
385
+ autodiff = sparse_prepped_AD
374
386
end
375
- cs = ForwardDiff. pickchunksize (x)
376
- return remake (alg, chunk_size= Val {cs} ())
387
+
388
+
389
+ return remake (alg, autodiff = autodiff)
377
390
end
0 commit comments