Merged
2 changes: 2 additions & 0 deletions Project.toml
@@ -4,6 +4,7 @@ authors = ["Nathanael Bosch"]
version = "0.16.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayAllocators = "c9d4266f-a5cb-439d-837c-c97b191379f5"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
@@ -48,6 +49,7 @@ DiffEqDevToolsExt = "DiffEqDevTools"
RecipesBaseExt = "RecipesBase"

[compat]
ADTypes = "1.14.0"
ArrayAllocators = "0.3"
BlockArrays = "1"
DiffEqBase = "6.122"
1 change: 1 addition & 0 deletions src/ProbNumDiffEq.jl
@@ -37,6 +37,7 @@ using FiniteHorizonGramians
using FillArrays
using MatrixEquations
using DiffEqCallbacks
using ADTypes

# @reexport using GaussianDistributions

4 changes: 2 additions & 2 deletions src/alg_utils.jl
@@ -12,8 +12,8 @@ OrdinaryDiffEqDifferentiation.concrete_jac(::AbstractEK) = nothing
OrdinaryDiffEqCore.isfsal(::AbstractEK) = false

for ALG in [:EK1, :DiagonalEK1]
@eval OrdinaryDiffEqDifferentiation._alg_autodiff(::$ALG{CS,AD}) where {CS,AD} =
Val{AD}()
@eval OrdinaryDiffEqDifferentiation._alg_autodiff(alg::$ALG{CS,AD}) where {CS,AD} =
alg.autodiff
@eval OrdinaryDiffEqDifferentiation.alg_difftype(
::$ALG{CS,AD,DiffType},
) where {CS,AD,DiffType} =
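Note on the change above: `_alg_autodiff` now returns the ADTypes.jl backend stored on the solver instead of reconstructing a `Val` from the `AD` type parameter. A minimal sketch of the new behavior, assuming the default backend and calling the (internal) OrdinaryDiffEqDifferentiation function directly:

using ProbNumDiffEq, ADTypes
import OrdinaryDiffEqDifferentiation

alg = EK1()  # stores an AutoForwardDiff backend in alg.autodiff by default

# Previously this returned Val{true}(); now it returns the ADTypes struct itself,
# so downstream Jacobian code can dispatch on the concrete backend.
OrdinaryDiffEqDifferentiation._alg_autodiff(alg)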
63 changes: 38 additions & 25 deletions src/algorithms.jl
@@ -182,24 +182,27 @@
initialization::IT
pn_observation_noise::RT
covariance_factorization::CF
autodiff::AD
EK1(;
order=3,
prior::PT=IWP(order),
diffusionmodel::DT=DynamicDiffusion(),
smooth=true,
initialization::IT=TaylorModeInit(num_derivatives(prior)),
chunk_size=Val{0}(),
autodiff=Val{true}(),
diff_type=Val{:forward},
autodiff=AutoForwardDiff(),
diff_type=Val{:forward}(),
standardtag=Val{true}(),
concrete_jac=nothing,
pn_observation_noise::RT=nothing,
covariance_factorization::CF=covariance_structure(EK1, prior, diffusionmodel),
) where {PT,DT,IT,RT,CF} = begin
ekargcheck(EK1; diffusionmodel, pn_observation_noise, covariance_factorization)
AD_choice, chunk_size, diff_type =
OrdinaryDiffEqCore._process_AD_choice(autodiff, chunk_size, diff_type)
new{
_unwrap_val(chunk_size),
_unwrap_val(autodiff),
typeof(AD_choice),
diff_type,
_unwrap_val(standardtag),
_unwrap_val(concrete_jac),
@@ -215,6 +218,7 @@
initialization,
pn_observation_noise,
covariance_factorization,
AD_choice
)
end
end
@@ -226,15 +230,16 @@
initialization::IT
pn_observation_noise::RT
covariance_factorization::CF
autodiff::AD
DiagonalEK1(;
order=3,
prior::PT=IWP(order),
diffusionmodel::DT=DynamicDiffusion(),
smooth=true,
initialization::IT=TaylorModeInit(num_derivatives(prior)),
chunk_size=Val{0}(),
autodiff=Val{true}(),
diff_type=Val{:forward},
autodiff=AutoForwardDiff(),
diff_type=Val{:forward}(),
standardtag=Val{true}(),
concrete_jac=nothing,
pn_observation_noise::RT=nothing,
@@ -245,9 +250,11 @@
),
) where {PT,DT,IT,RT,CF} = begin
ekargcheck(DiagonalEK1; diffusionmodel, pn_observation_noise, covariance_factorization)
AD_choice, chunk_size, diff_type =
OrdinaryDiffEqCore._process_AD_choice(autodiff, chunk_size, diff_type)
new{
_unwrap_val(chunk_size),
_unwrap_val(autodiff),
typeof(AD_choice),
diff_type,
_unwrap_val(standardtag),
_unwrap_val(concrete_jac),
@@ -263,6 +270,7 @@
initialization,
pn_observation_noise,
covariance_factorization,
AD_choice
)
end
end
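Net effect of the constructor changes for users: `autodiff` is now an ADTypes.jl backend object rather than `Val{true}()`/`Val{false}()`. A hedged usage sketch (backend names are from ADTypes.jl; whether a particular backend is accepted is decided by OrdinaryDiffEqCore's `_process_AD_choice`):

using ProbNumDiffEq, ADTypes

alg_default = EK1()                                    # defaults to AutoForwardDiff()
alg_finite  = EK1(autodiff=AutoFiniteDiff())           # finite differences instead of ForwardDiff
alg_diag    = DiagonalEK1(autodiff=AutoForwardDiff(chunksize=4))  # explicit chunk size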
@@ -334,16 +342,17 @@
EK1(; prior=IOUP(order, update_rate_parameter=true), kwargs...)

function DiffEqBase.remake(thing::EK1{CS,AD,DT,ST,CJ}; kwargs...) where {CS,AD,DT,ST,CJ}
if haskey(kwargs, :autodiff) && kwargs[:autodiff] isa AutoForwardDiff
chunk_size = OrdinaryDiffEqCore._get_fwd_chunksize(kwargs[:autodiff])
else
chunk_size = Val{CS}()
end

T = SciMLBase.remaker_of(thing)
T(;
SciMLBase.struct_as_namedtuple(thing)...,
chunk_size=Val{CS}(),
autodiff=Val{AD}(),
standardtag=Val{ST}(),
T(; SciMLBase.struct_as_namedtuple(thing)...,
chunk_size=chunk_size, autodiff=thing.autodiff, standardtag=Val{ST}(),
concrete_jac=CJ === nothing ? CJ : Val{CJ}(),
diff_type=DT,
kwargs...,
)
kwargs...)
end
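Sketch of what the updated `remake` enables (illustrative only; the chunk-size handling is exactly the `_get_fwd_chunksize` branch above):

using ProbNumDiffEq, ADTypes
import DiffEqBase

alg = EK1()
# Remaking with an AutoForwardDiff backend now also refreshes the stored chunk size,
# i.e. both alg.autodiff and the CS type parameter should be updated.
alg2 = DiffEqBase.remake(alg, autodiff=AutoForwardDiff(chunksize=2))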

function DiffEqBase.prepare_alg(
@@ -357,21 +366,25 @@
# use the prepare_alg from OrdinaryDiffEqCore; but right now, we do not use `linsolve` which
# is a requirement.

if (isbitstype(T) && sizeof(T) > 24) || (
prob.f isa ODEFunction &&
prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper
)
return remake(alg, chunk_size=Val{1}())
end
prepped_AD = OrdinaryDiffEqDifferentiation.prepare_ADType(OrdinaryDiffEqDifferentiation.alg_autodiff(alg), prob, u0, p, OrdinaryDiffEqDifferentiation.standardtag(alg))

sparse_prepped_AD = OrdinaryDiffEqDifferentiation.prepare_user_sparsity(prepped_AD, prob)

L = StaticArrayInterface.known_length(typeof(u0))
@assert L === nothing "ProbNumDiffEq.jl does not support StaticArrays yet."

x = if prob.f.colorvec === nothing
length(u0)
if (
(
(eltype(u0) <: Complex) ||
(!(prob.f isa DAEFunction) && prob.f.mass_matrix isa MatrixOperator)
) && sparse_prepped_AD isa AutoSparse
)
@warn "Input type or problem definition is incompatible with sparse automatic differentiation. Switching to using dense automatic differentiation."
autodiff = ADTypes.dense_ad(sparse_prepped_AD)
else
maximum(prob.f.colorvec)
autodiff = sparse_prepped_AD
end
cs = ForwardDiff.pickchunksize(x)
return remake(alg, chunk_size=Val{cs}())


return remake(alg, autodiff = autodiff)
end
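The dense-AD fallback above relies on two ADTypes.jl pieces: sparse backends are `AutoSparse` wrappers, and `ADTypes.dense_ad` unwraps them to their dense counterpart. A standalone sketch of that mechanism (the backend here is illustrative; `AutoSparse`'s sparsity detector and coloring algorithm default to ADTypes' no-op placeholders):

using ADTypes

sparse_backend = AutoSparse(AutoForwardDiff())  # sparse wrapper around a dense backend

# Mirrors the fallback in prepare_alg: if the problem cannot support sparse AD,
# drop back to the wrapped dense backend.
backend = sparse_backend isa AutoSparse ? ADTypes.dense_ad(sparse_backend) : sparse_backend
# backend is now AutoForwardDiff()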
13 changes: 13 additions & 0 deletions src/solution.jl
@@ -66,6 +66,19 @@ ProbODESolution{T,N}(
pnstats, prob, alg, interp, cache, dense, tslocation, stats, retcode,
)


function SciMLBase.constructorof(
::Type{
ProbNumDiffEq.ProbODESolution{T,N,uType,puType,uType2,DType,tType,rateType,xType,
diffType,bkType,PN,P,A,IType,
CType,DE}}
) where {T,N,uType,puType,uType2,DType,tType,rateType,xType,
diffType,bkType,PN,P,A,IType,
CType,DE}
ProbODESolution{T,N}
end


function DiffEqBase.solution_new_retcode(sol::ProbODESolution{T,N}, retcode) where {T,N}
return ProbODESolution{T,N}(
sol.u, sol.pu, sol.u_analytic, sol.errors, sol.t, sol.k, sol.x_filt, sol.x_smooth,
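Why the `constructorof` method above is needed, in brief: generic rebuild utilities in the SciML ecosystem reconstruct a struct through `constructorof`, and `ProbODESolution`'s long parameter list is easiest to re-infer by handing back the bare `{T,N}` constructor. A hedged sketch, assuming `sol` is any solution returned by `solve`:

using SciMLBase, ProbNumDiffEq

# sol = solve(prob, EK1())                  # hypothetical solution object
C = SciMLBase.constructorof(typeof(sol))    # returns ProbODESolution{T,N}
# Calling C with the solution's field values rebuilds an equivalent solution,
# with the remaining type parameters re-inferred from the arguments.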
2 changes: 2 additions & 0 deletions test/Project.toml
@@ -19,11 +19,13 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProbNumDiffEq = "bf3e78b0-7d74-48a5-b855-9609533b56a5"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"


[compat]
Aqua = "0.8.2"
28 changes: 18 additions & 10 deletions test/autodiff.jl
@@ -4,6 +4,7 @@ using Test
using LinearAlgebra
using FiniteDifferences
using ForwardDiff
import SciMLStructures
# using ReverseDiff
# using Zygote

@@ -17,15 +18,20 @@ import ODEProblemLibrary: prob_ode_fitzhughnagumo
_prob.tspan,
jac=true,
)
prob = remake(prob, p=collect(_prob.p))
#prob = remake(prob, p=collect(_prob.p))
ps = ModelingToolkit.parameter_values(prob)
ps = SciMLStructures.replace(SciMLStructures.Tunable(), ps, [1.0, 2.0, 3.0, 4.0])
prob = remake(prob, p=ps)

function param_to_loss(p)
ps = ModelingToolkit.parameter_values(prob)
ps = SciMLStructures.replace(SciMLStructures.Tunable(), ps, p)
sol = solve(
remake(prob, p=p),
remake(prob, p=ps),
ALG(order=3, smooth=false),
sensealg=SensitivityADPassThrough(),
abstol=1e-3,
reltol=1e-2,
abstol=1e-6,
reltol=1e-5,
save_everystep=false,
dense=false,
)
@@ -36,22 +42,24 @@
remake(prob, u0=u0),
ALG(order=3, smooth=false),
sensealg=SensitivityADPassThrough(),
abstol=1e-3,
reltol=1e-2,
abstol=1e-6,
reltol=1e-5,
save_everystep=false,
dense=false,
)
return norm(sol.u[end]) # Dummy loss
end

# dldp = FiniteDiff.finite_difference_gradient(param_to_loss, prob.p)
# dldu0 = FiniteDiff.finite_difference_gradient(startval_to_loss, prob.u0)
p, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), prob.p)

#dldp = FiniteDiff.finite_difference_gradient(param_to_loss, p)
#dldu0 = FiniteDiff.finite_difference_gradient(startval_to_loss, prob.u0)
# For some reason FiniteDiff.jl is not working anymore so we use FiniteDifferences.jl:
dldp = grad(central_fdm(5, 1), param_to_loss, prob.p)[1]
dldp = grad(central_fdm(5, 1), param_to_loss, p)[1]
dldu0 = grad(central_fdm(5, 1), startval_to_loss, prob.u0)[1]

@testset "ForwardDiff.jl" begin
@test ForwardDiff.gradient(param_to_loss, prob.p) ≈ dldp rtol = 1e-2
@test ForwardDiff.gradient(param_to_loss, p) ≈ dldp rtol = 1e-2
@test ForwardDiff.gradient(startval_to_loss, prob.u0) ≈ dldu0 rtol = 5e-2
end

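The test now follows a single SciMLStructures.jl pattern: canonicalize the tunable portion of the ModelingToolkit-generated parameter object into a flat vector that can be differentiated, and replace it to rebuild the structured object for the perturbed problem. A hedged sketch of that pattern in isolation (`prob` stands in for the MTK-generated problem used in the test):

import SciMLStructures
using ModelingToolkit, SciMLBase

ps = ModelingToolkit.parameter_values(prob)   # structured parameter object

# Flat vector of tunables, plus a repack function and an aliasing flag
p, repack, aliases = SciMLStructures.canonicalize(SciMLStructures.Tunable(), ps)

# Build a new parameter object with modified tunables and remake the problem
ps_new = SciMLStructures.replace(SciMLStructures.Tunable(), ps, p .* 2)
prob_new = remake(prob, p=ps_new)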