Skip to content

Commit eb9757a

Browse files
Update to use ADTypes for specifying AD backend (#338)
* add ADTypes * use ADTypes instead of Bool to specify AD backend * Use ADTypes for autodiff kwarg Co-authored-by: Christopher Rackauckas <[email protected]> * add constructorof for ProbODESolution so accessors and setfield works * make autodiff tests run * import SciMLStructures * add to seperate test project * change test tolerance --------- Co-authored-by: Christopher Rackauckas <[email protected]>
1 parent f19d9e7 commit eb9757a

File tree

7 files changed

+76
-37
lines changed

7 files changed

+76
-37
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Nathanael Bosch"]
44
version = "0.16.2"
55

66
[deps]
7+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
78
ArrayAllocators = "c9d4266f-a5cb-439d-837c-c97b191379f5"
89
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
910
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
@@ -48,6 +49,7 @@ DiffEqDevToolsExt = "DiffEqDevTools"
4849
RecipesBaseExt = "RecipesBase"
4950

5051
[compat]
52+
ADTypes = "1.14.0"
5153
ArrayAllocators = "0.3"
5254
BlockArrays = "1"
5355
DiffEqBase = "6.122"

src/ProbNumDiffEq.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ using FiniteHorizonGramians
3737
using FillArrays
3838
using MatrixEquations
3939
using DiffEqCallbacks
40+
using ADTypes
4041

4142
# @reexport using GaussianDistributions
4243

src/alg_utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ OrdinaryDiffEqDifferentiation.concrete_jac(::AbstractEK) = nothing
1212
OrdinaryDiffEqCore.isfsal(::AbstractEK) = false
1313

1414
for ALG in [:EK1, :DiagonalEK1]
15-
@eval OrdinaryDiffEqDifferentiation._alg_autodiff(::$ALG{CS,AD}) where {CS,AD} =
16-
Val{AD}()
15+
@eval OrdinaryDiffEqDifferentiation._alg_autodiff(alg::$ALG{CS,AD}) where {CS,AD} =
16+
alg.autodiff
1717
@eval OrdinaryDiffEqDifferentiation.alg_difftype(
1818
::$ALG{CS,AD,DiffType},
1919
) where {CS,AD,DiffType} =

src/algorithms.jl

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -182,24 +182,27 @@ struct EK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
182182
initialization::IT
183183
pn_observation_noise::RT
184184
covariance_factorization::CF
185+
autodiff::AD
185186
EK1(;
186187
order=3,
187188
prior::PT=IWP(order),
188189
diffusionmodel::DT=DynamicDiffusion(),
189190
smooth=true,
190191
initialization::IT=TaylorModeInit(num_derivatives(prior)),
191192
chunk_size=Val{0}(),
192-
autodiff=Val{true}(),
193-
diff_type=Val{:forward},
193+
autodiff=AutoForwardDiff(),
194+
diff_type=Val{:forward}(),
194195
standardtag=Val{true}(),
195196
concrete_jac=nothing,
196197
pn_observation_noise::RT=nothing,
197198
covariance_factorization::CF=covariance_structure(EK1, prior, diffusionmodel),
198199
) where {PT,DT,IT,RT,CF} = begin
199200
ekargcheck(EK1; diffusionmodel, pn_observation_noise, covariance_factorization)
201+
AD_choice, chunk_size, diff_type =
202+
OrdinaryDiffEqCore._process_AD_choice(autodiff, chunk_size, diff_type)
200203
new{
201204
_unwrap_val(chunk_size),
202-
_unwrap_val(autodiff),
205+
typeof(AD_choice),
203206
diff_type,
204207
_unwrap_val(standardtag),
205208
_unwrap_val(concrete_jac),
@@ -215,6 +218,7 @@ struct EK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
215218
initialization,
216219
pn_observation_noise,
217220
covariance_factorization,
221+
AD_choice
218222
)
219223
end
220224
end
@@ -226,15 +230,16 @@ struct DiagonalEK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
226230
initialization::IT
227231
pn_observation_noise::RT
228232
covariance_factorization::CF
233+
autodiff::AD
229234
DiagonalEK1(;
230235
order=3,
231236
prior::PT=IWP(order),
232237
diffusionmodel::DT=DynamicDiffusion(),
233238
smooth=true,
234239
initialization::IT=TaylorModeInit(num_derivatives(prior)),
235240
chunk_size=Val{0}(),
236-
autodiff=Val{true}(),
237-
diff_type=Val{:forward},
241+
autodiff=AutoForwardDiff(),
242+
diff_type=Val{:forward}(),
238243
standardtag=Val{true}(),
239244
concrete_jac=nothing,
240245
pn_observation_noise::RT=nothing,
@@ -245,9 +250,11 @@ struct DiagonalEK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
245250
),
246251
) where {PT,DT,IT,RT,CF} = begin
247252
ekargcheck(DiagonalEK1; diffusionmodel, pn_observation_noise, covariance_factorization)
253+
AD_choice, chunk_size, diff_type =
254+
OrdinaryDiffEqCore._process_AD_choice(autodiff, chunk_size, diff_type)
248255
new{
249256
_unwrap_val(chunk_size),
250-
_unwrap_val(autodiff),
257+
typeof(AD_choice),
251258
diff_type,
252259
_unwrap_val(standardtag),
253260
_unwrap_val(concrete_jac),
@@ -263,6 +270,7 @@ struct DiagonalEK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
263270
initialization,
264271
pn_observation_noise,
265272
covariance_factorization,
273+
AD_choice
266274
)
267275
end
268276
end
@@ -334,16 +342,17 @@ RosenbrockExpEK(; order=3, kwargs...) =
334342
EK1(; prior=IOUP(order, update_rate_parameter=true), kwargs...)
335343

336344
function DiffEqBase.remake(thing::EK1{CS,AD,DT,ST,CJ}; kwargs...) where {CS,AD,DT,ST,CJ}
345+
if haskey(kwargs, :autodiff) && kwargs[:autodiff] isa AutoForwardDiff
346+
chunk_size = OrdinaryDiffEqCore._get_fwd_chunksize(kwargs[:autodiff])
347+
else
348+
chunk_size = Val{CS}()
349+
end
350+
337351
T = SciMLBase.remaker_of(thing)
338-
T(;
339-
SciMLBase.struct_as_namedtuple(thing)...,
340-
chunk_size=Val{CS}(),
341-
autodiff=Val{AD}(),
342-
standardtag=Val{ST}(),
352+
T(; SciMLBase.struct_as_namedtuple(thing)...,
353+
chunk_size=chunk_size, autodiff=thing.autodiff, standardtag=Val{ST}(),
343354
concrete_jac=CJ === nothing ? CJ : Val{CJ}(),
344-
diff_type=DT,
345-
kwargs...,
346-
)
355+
kwargs...)
347356
end
348357

349358
function DiffEqBase.prepare_alg(
@@ -357,21 +366,25 @@ function DiffEqBase.prepare_alg(
357366
# use the prepare_alg from OrdinaryDiffEqCore; but right now, we do not use `linsolve` which
358367
# is a requirement.
359368

360-
if (isbitstype(T) && sizeof(T) > 24) || (
361-
prob.f isa ODEFunction &&
362-
prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper
363-
)
364-
return remake(alg, chunk_size=Val{1}())
365-
end
369+
prepped_AD = OrdinaryDiffEqDifferentiation.prepare_ADType(OrdinaryDiffEqDifferentiation.alg_autodiff(alg), prob, u0, p, OrdinaryDiffEqDifferentiation.standardtag(alg))
370+
371+
sparse_prepped_AD = OrdinaryDiffEqDifferentiation.prepare_user_sparsity(prepped_AD, prob)
366372

367373
L = StaticArrayInterface.known_length(typeof(u0))
368374
@assert L === nothing "ProbNumDiffEq.jl does not support StaticArrays yet."
369375

370-
x = if prob.f.colorvec === nothing
371-
length(u0)
376+
if (
377+
(
378+
(eltype(u0) <: Complex) ||
379+
(!(prob.f isa DAEFunction) && prob.f.mass_matrix isa MatrixOperator)
380+
) && sparse_prepped_AD isa AutoSparse
381+
)
382+
@warn "Input type or problem definition is incompatible with sparse automatic differentiation. Switching to using dense automatic differentiation."
383+
autodiff = ADTypes.dense_ad(sparse_prepped_AD)
372384
else
373-
maximum(prob.f.colorvec)
385+
autodiff = sparse_prepped_AD
374386
end
375-
cs = ForwardDiff.pickchunksize(x)
376-
return remake(alg, chunk_size=Val{cs}())
387+
388+
389+
return remake(alg, autodiff = autodiff)
377390
end

src/solution.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,19 @@ ProbODESolution{T,N}(
6666
pnstats, prob, alg, interp, cache, dense, tslocation, stats, retcode,
6767
)
6868

69+
70+
function SciMLBase.constructorof(
71+
::Type{
72+
ProbNumDiffEq.ProbODESolution{T,N,uType,puType,uType2,DType,tType,rateType,xType,
73+
diffType,bkType,PN,P,A,IType,
74+
CType,DE}}
75+
) where {T,N,uType,puType,uType2,DType,tType,rateType,xType,
76+
diffType,bkType,PN,P,A,IType,
77+
CType,DE}
78+
ProbODESolution{T,N}
79+
end
80+
81+
6982
function DiffEqBase.solution_new_retcode(sol::ProbODESolution{T,N}, retcode) where {T,N}
7083
return ProbODESolution{T,N}(
7184
sol.u, sol.pu, sol.u_analytic, sol.errors, sol.t, sol.k, sol.x_filt, sol.x_smooth,

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
1919
ProbNumDiffEq = "bf3e78b0-7d74-48a5-b855-9609533b56a5"
2020
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2121
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
22+
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
2223
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
2324
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2425
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2526
TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04"
2627
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
2728

29+
2830
[compat]
2931
Aqua = "0.8.2"

test/autodiff.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using Test
44
using LinearAlgebra
55
using FiniteDifferences
66
using ForwardDiff
7+
import SciMLStructures
78
# using ReverseDiff
89
# using Zygote
910

@@ -17,15 +18,20 @@ import ODEProblemLibrary: prob_ode_fitzhughnagumo
1718
_prob.tspan,
1819
jac=true,
1920
)
20-
prob = remake(prob, p=collect(_prob.p))
21+
#prob = remake(prob, p=collect(_prob.p))
22+
ps = ModelingToolkit.parameter_values(prob)
23+
ps = SciMLStructures.replace(SciMLStructures.Tunable(), ps, [1.0, 2.0, 3.0, 4.0])
24+
prob = remake(prob, p=ps)
2125

2226
function param_to_loss(p)
27+
ps = ModelingToolkit.parameter_values(prob)
28+
ps = SciMLStructures.replace(SciMLStructures.Tunable(), ps, p)
2329
sol = solve(
24-
remake(prob, p=p),
30+
remake(prob, p=ps),
2531
ALG(order=3, smooth=false),
2632
sensealg=SensitivityADPassThrough(),
27-
abstol=1e-3,
28-
reltol=1e-2,
33+
abstol=1e-6,
34+
reltol=1e-5,
2935
save_everystep=false,
3036
dense=false,
3137
)
@@ -36,22 +42,24 @@ import ODEProblemLibrary: prob_ode_fitzhughnagumo
3642
remake(prob, u0=u0),
3743
ALG(order=3, smooth=false),
3844
sensealg=SensitivityADPassThrough(),
39-
abstol=1e-3,
40-
reltol=1e-2,
45+
abstol=1e-6,
46+
reltol=1e-5,
4147
save_everystep=false,
4248
dense=false,
4349
)
4450
return norm(sol.u[end]) # Dummy loss
4551
end
4652

47-
# dldp = FiniteDiff.finite_difference_gradient(param_to_loss, prob.p)
48-
# dldu0 = FiniteDiff.finite_difference_gradient(startval_to_loss, prob.u0)
53+
p, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), prob.p)
54+
55+
#dldp = FiniteDiff.finite_difference_gradient(param_to_loss, p)
56+
#dldu0 = FiniteDiff.finite_difference_gradient(startval_to_loss, prob.u0)
4957
# For some reason FiniteDiff.jl is not working anymore so we use FiniteDifferences.jl:
50-
dldp = grad(central_fdm(5, 1), param_to_loss, prob.p)[1]
58+
dldp = grad(central_fdm(5, 1), param_to_loss, p)[1]
5159
dldu0 = grad(central_fdm(5, 1), startval_to_loss, prob.u0)[1]
5260

5361
@testset "ForwardDiff.jl" begin
54-
@test ForwardDiff.gradient(param_to_loss, prob.p) dldp rtol = 1e-2
62+
@test ForwardDiff.gradient(param_to_loss, p) dldp rtol = 1e-2
5563
@test ForwardDiff.gradient(startval_to_loss, prob.u0) dldu0 rtol = 5e-2
5664
end
5765

0 commit comments

Comments
 (0)