Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions lib/ImplicitDiscreteSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,38 @@ authors = ["vyudu <[email protected]>"]
version = "1.2.0"

[deps]
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
NonlinearSolveFirstOrder = "5959db7a-ea39-4486-b5fe-2dd0bf03d60d"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[extras]
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
OrdinaryDiffEqSDIRK = "2d112036-d095-4a1e-ab9a-08536f3ecdbf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
[sources]
OrdinaryDiffEqCore = {path = "../OrdinaryDiffEqCore"}

[compat]
Test = "1.10.0"
AllocCheck = "0.2"
Aqua = "0.8.11"
DiffEqBase = "6.176"
JET = "0.9.18, 0.10.4"
NonlinearSolveFirstOrder = "1.9.0"
OrdinaryDiffEqCore = "1.29.0"
OrdinaryDiffEqSDIRK = "1.6.0"
Reexport = "1.2"
SciMLBase = "2.99"
SimpleNonlinearSolve = "2.7"
OrdinaryDiffEqCore = "1.29.0"
Aqua = "0.8.11"
SymbolicIndexingInterface = "0.3.38"
julia = "1.10"
JET = "0.9.18, 0.10.4"
Test = "1.10.0"
UnPack = "1.0.2"
AllocCheck = "0.2"
DiffEqBase = "6.176"
Reexport = "1.2"
julia = "1.10"

[extras]
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
OrdinaryDiffEqSDIRK = "2d112036-d095-4a1e-ab9a-08536f3ecdbf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["OrdinaryDiffEqSDIRK", "Test", "JET", "Aqua", "AllocCheck"]

[sources.OrdinaryDiffEqCore]
path = "../OrdinaryDiffEqCore"
14 changes: 5 additions & 9 deletions lib/ImplicitDiscreteSolve/src/ImplicitDiscreteSolve.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,21 @@
module ImplicitDiscreteSolve

using SciMLBase
using SimpleNonlinearSolve
using NonlinearSolveFirstOrder
using UnPack
using SymbolicIndexingInterface: parameter_symbols
import OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, alg_cache, OrdinaryDiffEqMutableCache,
OrdinaryDiffEqConstantCache, get_fsalfirstlast, isfsal,
initialize!, perform_step!, isdiscretecache, isdiscretealg,
alg_order, beta2_default, beta1_default, dt_required,
_initialize_dae!, DefaultInit, BrownFullBasicInit, OverrideInit
_initialize_dae!, DefaultInit, BrownFullBasicInit, OverrideInit,
OrdinaryDiffEqNewtonAdaptiveAlgorithm, @muladd, @..,
AutoForwardDiff, _process_AD_choice, _unwrap_val

using Reexport
@reexport using SciMLBase

"""
IDSolve()

Simple solver for `ImplicitDiscreteSystems`. Uses `SimpleNewtonRaphson` to solve for the next state at every timestep.
"""
struct IDSolve <: OrdinaryDiffEqAlgorithm end

include("algorithms.jl")
include("cache.jl")
include("solve.jl")
include("alg_utils.jl")
Expand Down
21 changes: 21 additions & 0 deletions lib/ImplicitDiscreteSolve/src/algorithms.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
IDSolve()

First order solver for `ImplicitDiscreteSystems`.
"""
# struct IDSolve{CS, AD, NLS, FDT, ST, CJ} <:
struct IDSolve{NLS} <:
OrdinaryDiffEqAlgorithm
nlsolve::NLS
extrapolant::Symbol
controller::Symbol
end

function IDSolve(;
nlsolve = NewtonRaphson(), #NLNewton(),
extrapolant = :constant,
controller = :PI,
)

IDSolve{typeof(nlsolve)}(nlsolve, extrapolant, controller)
end
53 changes: 41 additions & 12 deletions lib/ImplicitDiscreteSolve/src/cache.jl
Original file line number Diff line number Diff line change
@@ -1,36 +1,65 @@
mutable struct ImplicitDiscreteState{uType, pType, tType}
struct ImplicitDiscreteState{uType, pType, tType}
u::uType
p::pType
t_next::tType
t::tType
end

mutable struct IDSolveCache{uType} <: OrdinaryDiffEqMutableCache
mutable struct IDSolveCache{uType, cType} <: OrdinaryDiffEqMutableCache
u::uType
uprev::uType
state::ImplicitDiscreteState
prob::Union{Nothing, SciMLBase.AbstractNonlinearProblem}
z::uType
nlcache::cType
end

function alg_cache(alg::IDSolve, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
state = ImplicitDiscreteState(isnothing(u) ? nothing : zero(u), p, t)
IDSolveCache(u, uprev, state, nothing)
state = ImplicitDiscreteState(zero(u), p, t)
f_nl = (resid, u_next, p) -> f(resid, u_next, p.u, p.p, p.t)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the reasoning here to include the current u in the signature of this function, but no information on dt?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dt is just a parameter?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess my question is simply, why is $dt$ (or tprev) not a parameter, but $uprev$ is part of the function signature?


u_len = length(u)
nlls = !isnothing(f.resid_prototype) && (length(f.resid_prototype) != u_len)
prob = if nlls
NonlinearLeastSquaresProblem{isinplace(f)}(
NonlinearFunction(f_nl; resid_prototype = f.resid_prototype),
u, state)
else
NonlinearProblem{isinplace(f)}(f_nl, u, state)
end

nlcache = init(prob, alg.nlsolve)

IDSolveCache(u, uprev, state.u, nlcache)
end

isdiscretecache(cache::IDSolveCache) = true

struct IDSolveConstantCache <: OrdinaryDiffEqConstantCache
prob::Union{Nothing, SciMLBase.AbstractNonlinearProblem}
end
# struct IDSolveConstantCache <: OrdinaryDiffEqConstantCache
# prob::Union{Nothing, SciMLBase.AbstractNonlinearProblem}
# end

function alg_cache(alg::IDSolve, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
state = ImplicitDiscreteState(isnothing(u) ? nothing : zero(u), p, t)
IDSolveCache(u, uprev, state, nothing)
state = ImplicitDiscreteState(zero(u), p, t)
f_nl = (u_next, p) -> f(u_next, p.u, p.p, p.t)

u_len = length(u)
nlls = !isnothing(f.resid_prototype) && (length(f.resid_prototype) != u_len)
prob = if nlls
NonlinearLeastSquaresProblem{isinplace(f)}(
NonlinearFunction(f_nl; resid_prototype = f.resid_prototype),
u, state)
else
NonlinearProblem{isinplace(f)}(f_nl, u, state)
end

nlcache = init(prob, alg.nlsolve)

# FIXME Use IDSolveConstantCache
IDSolveCache(u, uprev, state.u, nlcache)
end

get_fsalfirstlast(cache::IDSolveCache, rate_prototype) = (nothing, nothing)
64 changes: 30 additions & 34 deletions lib/ImplicitDiscreteSolve/src/solve.jl
Original file line number Diff line number Diff line change
@@ -1,38 +1,33 @@
# Remake the nonlinear problem, then update
function perform_step!(integrator, cache::IDSolveCache, repeat_step = false)
@unpack alg, u, uprev, dt, t, f, p = integrator
@unpack state, prob = cache
state.u .= uprev
state.t_next = t
prob = remake(prob, p = state)
(; alg, u, uprev, dt, t, f, p) = integrator

u = solve(prob, SimpleNewtonRaphson())
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol, u.retcode)
integrator.u = u
end

function initialize!(integrator, cache::IDSolveCache)
integrator.u isa AbstractVector && (cache.state.u .= integrator.u)
cache.state.p = integrator.p
cache.state.t_next = integrator.t
f = integrator.f

_f = if isinplace(f)
(resid, u_next, p) -> f(resid, u_next, p.u, p.p, p.t_next)
else
(u_next, p) -> f(u_next, p.u, p.p, p.t_next)
# initial guess
if alg.extrapolant == :linear
@.. broadcast=false cache.z=integrator.uprev + dt * (integrator.uprev - integrator.uprev2)
else # :constant
cache.z .= integrator.u
end
u_len = isnothing(integrator.u) ? 0 : length(integrator.u)
nlls = !isnothing(f.resid_prototype) && (length(f.resid_prototype) != u_len)
state = ImplicitDiscreteState(cache.z, p, t)

prob = if nlls
NonlinearLeastSquaresProblem{isinplace(f)}(
NonlinearFunction(_f; resid_prototype = f.resid_prototype),
cache.state.u, cache.state)
else
NonlinearProblem{isinplace(f)}(_f, cache.state.u, cache.state)
# nonlinear solve step
SciMLBase.reinit!(cache.nlcache, p=state)
# TODO compute convergence rate estimate
# for i in 1:10
# step!(cache.nlcache)
# # ...
# end
solve!(cache.nlcache)
if cache.nlcache.retcode != ReturnCode.Success
integrator.force_stepfail = true
return
end
cache.prob = prob

# Accept step
u .= cache.nlcache.u
end

function initialize!(integrator, cache::IDSolveCache)
integrator.u isa AbstractVector && (cache.z .= integrator.u)
end

function _initialize_dae!(integrator, prob::ImplicitDiscreteProblem,
Expand All @@ -43,13 +38,14 @@ function _initialize_dae!(integrator, prob::ImplicitDiscreteProblem,
_initialize_dae!(integrator, prob,
OverrideInit(atol), x)
else
@unpack u, p, t, f = integrator
(; u, p, t, f) = integrator

initstate = ImplicitDiscreteState(u, p, t)

_f = if isinplace(f)
(resid, u_next, p) -> f(resid, u_next, p.u, p.p, p.t_next)
(resid, u_next, p) -> f(resid, u_next, p.u, p.p, p.t)
else
(u_next, p) -> f(u_next, p.u, p.p, p.t_next)
(u_next, p) -> f(u_next, p.u, p.p, p.t)
end

nlls = !isnothing(f.resid_prototype) &&
Expand All @@ -60,7 +56,7 @@ function _initialize_dae!(integrator, prob::ImplicitDiscreteProblem,
else
NonlinearProblem{isinplace(f)}(_f, u, initstate)
end
sol = solve(prob, SimpleNewtonRaphson())
sol = solve(prob, integrator.alg.nlsolve)
if sol.retcode == ReturnCode.Success
integrator.u = sol
else
Expand Down
24 changes: 12 additions & 12 deletions lib/ImplicitDiscreteSolve/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,16 @@ end
end
end

@testset "Handle nothing in u0" begin
function empty(u_next, u, p, t)
nothing
end
# @testset "Handle nothing in u0" begin
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have trouble understanding the purpose of this test. What exactly is the practical scenario here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we use u=nothing to represent length 0 u in a type stable manner (as op0osed to a Vector which is length 0 at runtime)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I think I do not understand the point - do you imply here that a zero-length solution vector not type stable and this is why we resort to "nothing"?

When I try to get this test running, I get errors from an alias analysis function not being dispatched for nothing in NonlinearSolve (https://github.com/SciML/NonlinearSolve.jl/blob/ac9344f9359833282e443c4479427ad9ce3311dd/lib/NonlinearSolveFirstOrder/src/solve.jl#L157).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inplace functions with empty return types are also not correctly handled.

using NonlinearSolveFirstOrder
function f(u, p)
    nothing
end
prob = NonlinearProblem{false}(f, Float64[], nothing)
iter = init(prob, NewtonRaphson())

errors when the termination cache is built, because the increment type cannot be derived.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the point is that zero length u tends to cause problems (e.g. solvers will try to look at the first element of the state to find a type), so this way you can skip the solve process since nothing interesting can happen if you don't have any state. See https://github.com/SciML/DiffEqBase.jl/blob/3667bdbdc85489f7b296316df7f4c440519e82f6/src/solve.jl#L31 for how this gets handled for ODEs/DAEs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it make more sense from a software engineering stand point to replace such isa statements with dispatchable functions to make the code more extensible?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

quite possibly. DiffEqBase is not my favorite code organization.

# function empty(u_next, u, p, t)
# nothing
# end

tsteps = 5
u0 = nothing
idprob = ImplicitDiscreteProblem(empty, u0, (0, tsteps), [])
@test_nowarn integ = init(idprob, IDSolve())
end
# tsteps = 5
# u0 = nothing
# idprob = ImplicitDiscreteProblem(empty, u0, (0, tsteps), [])
# @test_nowarn integ = init(idprob, IDSolve())
# end

@testset "Create NonlinearLeastSquaresProblem" begin
function over(u_next, u, p, t)
Expand All @@ -92,23 +92,23 @@ end
idprob = ImplicitDiscreteProblem(
ImplicitDiscreteFunction(over, resid_prototype = zeros(3)), u0, (0, tsteps), [])
integ = init(idprob, IDSolve())
@test integ.cache.prob isa NonlinearLeastSquaresProblem
@test integ.cache.nlcache.prob isa NonlinearLeastSquaresProblem

function under(u_next, u, p, t)
[u_next[1] - u_next[2] - 1]
end
idprob = ImplicitDiscreteProblem(
ImplicitDiscreteFunction(under; resid_prototype = zeros(1)), u0, (0, tsteps), [])
integ = init(idprob, IDSolve())
@test integ.cache.prob isa NonlinearLeastSquaresProblem
@test integ.cache.nlcache.prob isa NonlinearLeastSquaresProblem

function full(u_next, u, p, t)
[u_next[1]^2 - 3, u_next[2] - u[1]]
end
idprob = ImplicitDiscreteProblem(
ImplicitDiscreteFunction(full; resid_prototype = zeros(2)), u0, (0, tsteps), [])
integ = init(idprob, IDSolve())
@test integ.cache.prob isa NonlinearProblem
@test integ.cache.nlcache.prob isa NonlinearProblem
end

@testset "InitialFailure thrown" begin
Expand Down
Loading