Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/estimator/internal_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,20 +170,19 @@ end

State function ``\mathbf{f̂}`` of [`InternalModel`](@ref) for [`NonLinModel`](@ref).

It calls `model.solver_f!(x̂0next, k0, x̂0, u0 ,d0, model.p)` directly since this estimator
does not augment the states.
It calls [`f!`](@ref) directly since this estimator does not augment the states.
"""
function f̂!(x̂0next, _ , k0, ::InternalModel, model::NonLinModel, x̂0, u0, d0)
return model.solver_f!(x̂0next, k0, x̂0, u0, d0, model.p)
return f!(x̂0next, k0, model, x̂0, u0, d0, model.p)
end

@doc raw"""
ĥ!(ŷ0, estim::InternalModel, model::NonLinModel, x̂0, d0)

Output function ``\mathbf{ĥ}`` of [`InternalModel`](@ref), it calls `model.solver_h!`.
Output function ``\mathbf{ĥ}`` of [`InternalModel`](@ref), it calls [`h!`](@ref).
"""
function ĥ!(ŷ0, ::InternalModel, model::NonLinModel, x̂0, d0)
return model.solver_h!(ŷ0, x̂0, d0, model.p)
return h!(ŷ0, model, x̂0, d0, model.p)
end


Expand Down
18 changes: 10 additions & 8 deletions src/model/linearization.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
get_linearization_func(
NT, solver_f!, solver_h!, nu, nx, ny, nd, ns, p, solver, backend
NT, f!, h!, Ts, nu, nx, ny, nd, ns, p, solver, backend
) -> linfunc!

Return `linfunc!` function that computes Jacobians of `solver_f!` and `solver_h!` functions.
Return `linfunc!` function that computes Jacobians of `f!` and `h!` functions.

The function has the following signature:
```
Expand All @@ -13,12 +13,14 @@ and it should modifies in-place all the arguments before `backend`. The `backend
is an `AbstractADType` object from `DifferentiationInterface`. The `cst_x`, `cst_u` and
`cst_d` are `DifferentiationInterface.Constant` objects with the linearization points.
"""
function get_linearization_func(NT, solver_f!, solver_h!, nu, nx, ny, nd, p, solver, backend)
f_x!(xnext, x, k, u, d) = solver_f!(xnext, k, x, u, d, p)
f_u!(xnext, u, k, x, d) = solver_f!(xnext, k, x, u, d, p)
f_d!(xnext, d, k, x, u) = solver_f!(xnext, k, x, u, d, p)
h_x!(y, x, d) = solver_h!(y, x, d, p)
h_d!(y, d, x) = solver_h!(y, x, d, p)
function get_linearization_func(
NT, f!::F, h!::H, Ts, nu, nx, ny, nd, p, solver, backend
) where {F<:Function, H<:Function}
f_x!(xnext, x, k, u, d) = solver_f!(xnext, k, f!, Ts, solver, x, u, d, p)
f_u!(xnext, u, k, x, d) = solver_f!(xnext, k, f!, Ts, solver, x, u, d, p)
f_d!(xnext, d, k, x, u) = solver_f!(xnext, k, f!, Ts, solver, x, u, d, p)
h_x!(y, x, d) = h!(y, x, d, p)
h_d!(y, d, x) = h!(y, x, d, p)
strict = Val(true)
xnext = zeros(NT, nx)
y = zeros(NT, ny)
Expand Down
71 changes: 47 additions & 24 deletions src/model/nonlinmodel.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,33 @@
"Abstract supertype of all differential equation solvers."
abstract type DiffSolver end

"Empty solver for nonlinear discrete-time models."
struct EmptySolver <: DiffSolver
ni::Int # number of intermediate stages
EmptySolver() = new(-1)
end

"Call `f!` function directly for discrete-time models (no solver)."
function solver_f!(xnext, _ , f!::F, _ , ::EmptySolver, x, u, d, p) where F
return f!(xnext, x, u, d, p)
end

Base.show(io::IO, ::EmptySolver) = print(io, "Empty differential equation solver.")

struct NonLinModel{
NT<:Real,
F<:Function,
H<:Function,
PT<:Any,
DS<:DiffSolver,
F <:Function,
H <:Function,
PT<:Any,
JB<:AbstractADType,
LF<:Function
} <: SimModel{NT}
x0::Vector{NT}
solver_f!::F
solver_h!::H
p::PT
solver::DS
f!::F
h!::H
p::PT
Ts::NT
t::Vector{NT}
nu::Int
Expand All @@ -32,14 +48,14 @@ struct NonLinModel{
linfunc!::LF
buffer::SimModelBuffer{NT}
function NonLinModel{NT}(
solver_f!::F, solver_h!::H, Ts, nu, nx, ny, nd,
p::PT, solver::DS, jacobian::JB, linfunc!::LF
solver::DS, f!::F, h!::H, Ts, nu, nx, ny, nd,
p::PT, jacobian::JB, linfunc!::LF
) where {
NT<:Real,
DS<:DiffSolver,
F<:Function,
H<:Function,
PT<:Any,
DS<:DiffSolver,
JB<:AbstractADType,
LF<:Function
}
Expand All @@ -58,11 +74,11 @@ struct NonLinModel{
ni = solver.ni
nk = nx*(ni+1)
buffer = SimModelBuffer{NT}(nu, nx, ny, nd, ni)
return new{NT, F, H, PT, DS, JB, LF}(
return new{NT, DS, F, H, PT, JB, LF}(
x0,
solver_f!, solver_h!,
p,
solver,
solver,
f!, h!,
p,
Ts, t,
nu, nx, ny, nd, nk,
uop, yop, dop, xop, fop,
Expand Down Expand Up @@ -176,12 +192,11 @@ function NonLinModel{NT}(
) where {NT<:Real}
isnothing(solver) && (solver=EmptySolver())
f!, h! = get_mutating_functions(NT, f, h)
solver_f!, solver_h! = get_solver_functions(NT, solver, f!, h!, Ts, nu, nx, ny, nd)
linfunc! = get_linearization_func(
NT, solver_f!, solver_h!, nu, nx, ny, nd, p, solver, jacobian
NT, f!, h!, Ts, nu, nx, ny, nd, p, solver, jacobian
)
return NonLinModel{NT}(
solver_f!, solver_h!, Ts, nu, nx, ny, nd, p, solver, jacobian, linfunc!
solver, f!, h!, Ts, nu, nx, ny, nd, p, jacobian, linfunc!
)
end

Expand Down Expand Up @@ -270,22 +285,30 @@ Call [`linearize(model; x, u, d)`](@ref) and return the resulting linear model.
"""
LinModel(model::NonLinModel; kwargs...) = linearize(model; kwargs...)


"""
f!(x0next, k0, model::NonLinModel, x0, u0, d0, p)

Call `model.solver_f!(x0next, k0, x0, u0, d0, p)` for [`NonLinModel`](@ref).
Compute `x0next` using the [`DiffSolver`](@ref) in `model.solver` and `model.f!`.

The method mutate `x0next` and `k0` arguments in-place. The latter is used to store the
intermediate stage values of `model.solver` [`DiffSolver`](@ref).
The method mutates `x0next` and `k0` arguments in-place. The latter is used to store the
intermediate stage values of the solver.
"""
f!(x0next, k0, model::NonLinModel, x0, u0, d0, p) = model.solver_f!(x0next, k0, x0, u0, d0, p)
function f!(x0next, k0, model::NonLinModel, x0, u0, d0, p)
return solver_f!(x0next, k0, model.f!, model.Ts, model.solver, x0, u0, d0, p)
end

"""
h!(y0, model::NonLinModel, x0, d0, p)

Call `model.solver_h!(y0, x0, d0, p)` for [`NonLinModel`](@ref).
Compute `y0` by calling `model.h!` directly for [`NonLinModel`](@ref).
"""
h!(y0, model::NonLinModel, x0, d0, p) = model.solver_h!(y0, x0, d0, p)
h!(y0, model::NonLinModel, x0, d0, p) = model.h!(y0, x0, d0, p)

include("solver.jl")

detailstr(model::NonLinModel) = ", $(typeof(model.solver).name.name)($(model.solver.order)) solver"
detailstr(::NonLinModel{<:Real, <:Function, <:Function, <:Any, <:EmptySolver}) = ", empty solver"
function detailstr(model::NonLinModel{<:Real, <:RungeKutta{N}}) where N
return ", $(nameof(typeof(model.solver)))($N) solver"
end
detailstr(::NonLinModel{<:Real, <:EmptySolver}) = ", empty solver"
detailstr(::NonLinModel) = ""
134 changes: 43 additions & 91 deletions src/model/solver.jl
Original file line number Diff line number Diff line change
@@ -1,39 +1,5 @@
"Abstract supertype of all differential equation solvers."
abstract type DiffSolver end

"Empty solver for nonlinear discrete-time models."
struct EmptySolver <: DiffSolver
ni::Int # number of intermediate stages
EmptySolver() = new(-1)
end

"""
get_solver_functions(NT::DataType, solver::EmptySolver, f!, h!, Ts, nu, nx, ny, nd)

Get `solver_f!` and `solver_h!` functions for the `EmptySolver` (discrete models).

The functions should have the following signature:
```
solver_f!(xnext, k, x, u, d, p) -> nothing
solver_h!(y, x, d, p) -> nothing
```
in which `xnext`, `k` and `y` arguments are mutated in-place. The `k` argument is
a vector of `nx*(solver.ni+1)` elements to store the solver intermediate stage values (and
also the current state value for when `supersample ≠ 1`).
"""
function get_solver_functions(::DataType, ::EmptySolver, f!, h!, _ , _ , _ , _ , _ )
solver_f!(xnext, _ , x, u, d, p) = f!(xnext, x, u, d, p)
solver_h! = h!
return solver_f!, solver_h!
end

function Base.show(io::IO, solver::EmptySolver)
print(io, "Empty differential equation solver.")
end

struct RungeKutta <: DiffSolver
struct RungeKutta{N} <: DiffSolver
ni::Int # number of intermediate stages
order::Int # order of the method
supersample::Int # number of internal steps
function RungeKutta(order::Int, supersample::Int)
if order ≠ 4 && order ≠ 1
Expand All @@ -46,7 +12,7 @@ struct RungeKutta <: DiffSolver
throw(ArgumentError("supersample must be greater than 0"))
end
ni = order # only true for order ≤ 4 with RungeKutta
return new(ni, order, supersample)
return new{order}(ni, supersample)
end
end

Expand All @@ -61,60 +27,29 @@ This solver is allocation-free if the `f!` and `h!` functions do not allocate.
"""
RungeKutta(order::Int=4; supersample::Int=1) = RungeKutta(order, supersample)

"Get `solve_f!` and `solver_h!` functions for the explicit Runge-Kutta solvers."
function get_solver_functions(NT::DataType, solver::RungeKutta, f!, h!, Ts, _ , nx, _ , _ )
solver_f! = if solver.order == 4
get_rk4_function(NT, solver, f!, Ts, nx)
elseif solver.order == 1
get_euler_function(NT, solver, f!, Ts, nx)
else
throw(ArgumentError("only 1st and 4th order Runge-Kutta is supported."))
"Solve the differential equation with the 4th order Runge-Kutta method."
function solver_f!(xnext, k, f!::F, Ts, solver::RungeKutta{4}, x, u, d, p) where F
supersample = solver.supersample
Ts_inner = Ts/supersample
nx = length(x)
xcurr = @views k[1:nx]
k1 = @views k[(1nx + 1):(2nx)]
k2 = @views k[(2nx + 1):(3nx)]
k3 = @views k[(3nx + 1):(4nx)]
k4 = @views k[(4nx + 1):(5nx)]
@. xcurr = x
for i=1:supersample
f!(k1, xcurr, u, d, p)
@. xnext = xcurr + k1 * Ts_inner/2
f!(k2, xnext, u, d, p)
@. xnext = xcurr + k2 * Ts_inner/2
f!(k3, xnext, u, d, p)
@. xnext = xcurr + k3 * Ts_inner
f!(k4, xnext, u, d, p)
@. xcurr = xcurr + (k1 + 2k2 + 2k3 + k4)*Ts_inner/6
end
solver_h! = h!
return solver_f!, solver_h!
end

"Get the f! function for the 4th order explicit Runge-Kutta solver."
function get_rk4_function(NT, solver, f!, Ts, nx)
Ts_inner = Ts/solver.supersample
function rk4_solver_f!(xnext, k, x, u, d, p)
xcurr = @views k[1:nx]
k1 = @views k[(1nx + 1):(2nx)]
k2 = @views k[(2nx + 1):(3nx)]
k3 = @views k[(3nx + 1):(4nx)]
k4 = @views k[(4nx + 1):(5nx)]
@. xcurr = x
for i=1:solver.supersample
f!(k1, xcurr, u, d, p)
@. xnext = xcurr + k1 * Ts_inner/2
f!(k2, xnext, u, d, p)
@. xnext = xcurr + k2 * Ts_inner/2
f!(k3, xnext, u, d, p)
@. xnext = xcurr + k3 * Ts_inner
f!(k4, xnext, u, d, p)
@. xcurr = xcurr + (k1 + 2k2 + 2k3 + k4)*Ts_inner/6
end
@. xnext = xcurr
return nothing
end
return rk4_solver_f!
end

"Get the f! function for the explicit Euler solver."
function get_euler_function(NT, solver, fc!, Ts, nx)
Ts_inner = Ts/solver.supersample
function euler_solver_f!(xnext, k, x, u, d, p)
xcurr = @views k[1:nx]
k1 = @views k[(1nx + 1):(2nx)]
@. xcurr = x
for i=1:solver.supersample
fc!(k1, xcurr, u, d, p)
@. xcurr = xcurr + k1 * Ts_inner
end
@. xnext = xcurr
return nothing
end
return euler_solver_f!
@. xnext = xcurr
return nothing
end

"""
Expand All @@ -126,7 +61,24 @@ This is an alias for `RungeKutta(1; supersample)`, see [`RungeKutta`](@ref).
"""
const ForwardEuler(;supersample=1) = RungeKutta(1; supersample)

function Base.show(io::IO, solver::RungeKutta)
N, n = solver.order, solver.supersample

"Solve the differential equation with the forward Euler method."
function solver_f!(xnext, k, f!::F, Ts, solver::RungeKutta{1}, x, u, d, p) where F
supersample = solver.supersample
Ts_inner = Ts/supersample
nx = length(x)
xcurr = @views k[1:nx]
k1 = @views k[(1nx + 1):(2nx)]
@. xcurr = x
for i=1:supersample
f!(k1, xcurr, u, d, p)
@. xcurr = xcurr + k1 * Ts_inner
end
@. xnext = xcurr
return nothing
end

function Base.show(io::IO, solver::RungeKutta{N}) where N
n = solver.supersample
print(io, "$(N)th order Runge-Kutta differential equation solver with $n supersamples.")
end
1 change: 0 additions & 1 deletion src/sim_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,6 @@ to_mat(A::AbstractMatrix, _ ...) = A
to_mat(A::Real, dims...) = fill(A, dims)

include("model/linmodel.jl")
include("model/solver.jl")
include("model/linearization.jl")
include("model/nonlinmodel.jl")

Expand Down
Loading