Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Nathanael Bosch"]
version = "0.16.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayAllocators = "c9d4266f-a5cb-439d-837c-c97b191379f5"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Expand Down Expand Up @@ -48,6 +49,7 @@ DiffEqDevToolsExt = "DiffEqDevTools"
RecipesBaseExt = "RecipesBase"

[compat]
ADTypes = "1.14.0"
ArrayAllocators = "0.3"
BlockArrays = "1"
DiffEqBase = "6.122"
Expand Down
1 change: 1 addition & 0 deletions src/ProbNumDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ using FiniteHorizonGramians
using FillArrays
using MatrixEquations
using DiffEqCallbacks
using ADTypes

# @reexport using GaussianDistributions

Expand Down
4 changes: 2 additions & 2 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ OrdinaryDiffEqDifferentiation.concrete_jac(::AbstractEK) = nothing
OrdinaryDiffEqCore.isfsal(::AbstractEK) = false

for ALG in [:EK1, :DiagonalEK1]
@eval OrdinaryDiffEqDifferentiation._alg_autodiff(::$ALG{CS,AD}) where {CS,AD} =
Val{AD}()
@eval OrdinaryDiffEqDifferentiation._alg_autodiff(alg::$ALG{CS,AD}) where {CS,AD} =
alg.autodiff
@eval OrdinaryDiffEqDifferentiation.alg_difftype(
::$ALG{CS,AD,DiffType},
) where {CS,AD,DiffType} =
Expand Down
59 changes: 36 additions & 23 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,24 +182,27 @@ struct EK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
initialization::IT
pn_observation_noise::RT
covariance_factorization::CF
autodiff::AD
EK1(;
order=3,
prior::PT=IWP(order),
diffusionmodel::DT=DynamicDiffusion(),
smooth=true,
initialization::IT=TaylorModeInit(num_derivatives(prior)),
chunk_size=Val{0}(),
autodiff=Val{true}(),
autodiff=AutoForwardDiff(),
diff_type=Val{:forward},
standardtag=Val{true}(),
concrete_jac=nothing,
pn_observation_noise::RT=nothing,
covariance_factorization::CF=covariance_structure(EK1, prior, diffusionmodel),
) where {PT,DT,IT,RT,CF} = begin
ekargcheck(EK1; diffusionmodel, pn_observation_noise, covariance_factorization)
AD_choice, chunk_size, diff_type =
OrdinaryDiffEqCore._process_AD_choice(autodiff, chunk_size, diff_type)
new{
_unwrap_val(chunk_size),
_unwrap_val(autodiff),
typeof(AD_choice),
diff_type,
_unwrap_val(standardtag),
_unwrap_val(concrete_jac),
Expand All @@ -215,6 +218,7 @@ struct EK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
initialization,
pn_observation_noise,
covariance_factorization,
AD_choice
)
end
end
Expand All @@ -226,14 +230,15 @@ struct DiagonalEK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
initialization::IT
pn_observation_noise::RT
covariance_factorization::CF
autodiff::AD
DiagonalEK1(;
order=3,
prior::PT=IWP(order),
diffusionmodel::DT=DynamicDiffusion(),
smooth=true,
initialization::IT=TaylorModeInit(num_derivatives(prior)),
chunk_size=Val{0}(),
autodiff=Val{true}(),
autodiff=AutoForwardDiff(),
diff_type=Val{:forward},
standardtag=Val{true}(),
concrete_jac=nothing,
Expand All @@ -245,9 +250,11 @@ struct DiagonalEK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
),
) where {PT,DT,IT,RT,CF} = begin
ekargcheck(DiagonalEK1; diffusionmodel, pn_observation_noise, covariance_factorization)
AD_choice, chunk_size, diff_type =
OrdinaryDiffEqCore._process_AD_choice(autodiff, chunk_size, diff_type)
new{
_unwrap_val(chunk_size),
_unwrap_val(autodiff),
typeof(AD_choice),
diff_type,
_unwrap_val(standardtag),
_unwrap_val(concrete_jac),
Expand All @@ -263,6 +270,7 @@ struct DiagonalEK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
initialization,
pn_observation_noise,
covariance_factorization,
AD_choice
)
end
end
Expand Down Expand Up @@ -334,16 +342,17 @@ RosenbrockExpEK(; order=3, kwargs...) =
EK1(; prior=IOUP(order, update_rate_parameter=true), kwargs...)

function DiffEqBase.remake(thing::EK1{CS,AD,DT,ST,CJ}; kwargs...) where {CS,AD,DT,ST,CJ}
if haskey(kwargs, :autodiff) && kwargs[:autodiff] isa AutoForwardDiff
chunk_size = OrdinaryDiffEqCore._get_fwd_chunksize(kwargs[:autodiff])
else
chunk_size = Val{CS}()
end

T = SciMLBase.remaker_of(thing)
T(;
SciMLBase.struct_as_namedtuple(thing)...,
chunk_size=Val{CS}(),
autodiff=Val{AD}(),
standardtag=Val{ST}(),
T(; SciMLBase.struct_as_namedtuple(thing)...,
chunk_size=chunk_size, autodiff=thing.autodiff, standardtag=Val{ST}(),
concrete_jac=CJ === nothing ? CJ : Val{CJ}(),
diff_type=DT,
kwargs...,
)
kwargs...)
end

function DiffEqBase.prepare_alg(
Expand All @@ -357,21 +366,25 @@ function DiffEqBase.prepare_alg(
# use the prepare_alg from OrdinaryDiffEqCore; but right now, we do not use `linsolve` which
# is a requirement.

if (isbitstype(T) && sizeof(T) > 24) || (
prob.f isa ODEFunction &&
prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper
)
return remake(alg, chunk_size=Val{1}())
end
prepped_AD = OrdinaryDiffEqDifferentiation.prepare_ADType(OrdinaryDiffEqDifferentiation.alg_autodiff(alg), prob, u0, p, OrdinaryDiffEqDifferentiation.standardtag(alg))

sparse_prepped_AD = OrdinaryDiffEqDifferentiation.prepare_user_sparsity(prepped_AD, prob)

L = StaticArrayInterface.known_length(typeof(u0))
@assert L === nothing "ProbNumDiffEq.jl does not support StaticArrays yet."

x = if prob.f.colorvec === nothing
length(u0)
if (
(
(eltype(u0) <: Complex) ||
(!(prob.f isa DAEFunction) && prob.f.mass_matrix isa MatrixOperator)
) && sparse_prepped_AD isa AutoSparse
)
@warn "Input type or problem definition is incompatible with sparse automatic differentiation. Switching to using dense automatic differentiation."
autodiff = ADTypes.dense_ad(sparse_prepped_AD)
else
maximum(prob.f.colorvec)
autodiff = sparse_prepped_AD
end
cs = ForwardDiff.pickchunksize(x)
return remake(alg, chunk_size=Val{cs}())


return remake(alg, autodiff = autodiff)
end
Loading