Skip to content

Commit 9b87de8

Browse files
authored
Merge pull request #233 from JuliaControl/continuous_ss_field
changed: keep `f!` and `h!` functions available
2 parents 42d29e5 + 0cd7d18 commit 9b87de8

File tree

7 files changed

+133
-158
lines changed

7 files changed

+133
-158
lines changed

src/estimator/internal_model.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,20 +170,19 @@ end
170170
171171
State function ``\mathbf{f̂}`` of [`InternalModel`](@ref) for [`NonLinModel`](@ref).
172172
173-
It calls `model.solver_f!(x̂0next, k0, x̂0, u0 ,d0, model.p)` directly since this estimator
174-
does not augment the states.
173+
It calls [`f!`](@ref) directly since this estimator does not augment the states.
175174
"""
176175
function f̂!(x̂0next, _ , k0, ::InternalModel, model::NonLinModel, x̂0, u0, d0)
177-
return model.solver_f!(x̂0next, k0, x̂0, u0, d0, model.p)
176+
return f!(x̂0next, k0, model, x̂0, u0, d0, model.p)
178177
end
179178

180179
@doc raw"""
181180
ĥ!(ŷ0, estim::InternalModel, model::NonLinModel, x̂0, d0)
182181
183-
Output function ``\mathbf{ĥ}`` of [`InternalModel`](@ref), it calls `model.solver_h!`.
182+
Output function ``\mathbf{ĥ}`` of [`InternalModel`](@ref), it calls [`h!`](@ref).
184183
"""
185184
function ĥ!(ŷ0, ::InternalModel, model::NonLinModel, x̂0, d0)
186-
return model.solver_h!(ŷ0, x̂0, d0, model.p)
185+
return h!(ŷ0, model, x̂0, d0, model.p)
187186
end
188187

189188

src/model/linearization.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""
22
get_linearization_func(
3-
NT, solver_f!, solver_h!, nu, nx, ny, nd, ns, p, solver, backend
3+
NT, f!, h!, Ts, nu, nx, ny, nd, ns, p, solver, backend
44
) -> linfunc!
55
6-
Return `linfunc!` function that computes Jacobians of `solver_f!` and `solver_h!` functions.
6+
Return `linfunc!` function that computes Jacobians of `f!` and `h!` functions.
77
88
The function has the following signature:
99
```
@@ -13,12 +13,14 @@ and it should modifies in-place all the arguments before `backend`. The `backend
1313
is an `AbstractADType` object from `DifferentiationInterface`. The `cst_x`, `cst_u` and
1414
`cst_d` are `DifferentiationInterface.Constant` objects with the linearization points.
1515
"""
16-
function get_linearization_func(NT, solver_f!, solver_h!, nu, nx, ny, nd, p, solver, backend)
17-
f_x!(xnext, x, k, u, d) = solver_f!(xnext, k, x, u, d, p)
18-
f_u!(xnext, u, k, x, d) = solver_f!(xnext, k, x, u, d, p)
19-
f_d!(xnext, d, k, x, u) = solver_f!(xnext, k, x, u, d, p)
20-
h_x!(y, x, d) = solver_h!(y, x, d, p)
21-
h_d!(y, d, x) = solver_h!(y, x, d, p)
16+
function get_linearization_func(
17+
NT, f!::F, h!::H, Ts, nu, nx, ny, nd, p, solver, backend
18+
) where {F<:Function, H<:Function}
19+
f_x!(xnext, x, k, u, d) = solver_f!(xnext, k, f!, Ts, solver, x, u, d, p)
20+
f_u!(xnext, u, k, x, d) = solver_f!(xnext, k, f!, Ts, solver, x, u, d, p)
21+
f_d!(xnext, d, k, x, u) = solver_f!(xnext, k, f!, Ts, solver, x, u, d, p)
22+
h_x!(y, x, d) = h!(y, x, d, p)
23+
h_d!(y, d, x) = h!(y, x, d, p)
2224
strict = Val(true)
2325
xnext = zeros(NT, nx)
2426
y = zeros(NT, ny)

src/model/nonlinmodel.jl

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,33 @@
1+
"Abstract supertype of all differential equation solvers."
2+
abstract type DiffSolver end
3+
4+
"Empty solver for nonlinear discrete-time models."
5+
struct EmptySolver <: DiffSolver
6+
ni::Int # number of intermediate stages
7+
EmptySolver() = new(-1)
8+
end
9+
10+
"Call `f!` function directly for discrete-time models (no solver)."
11+
function solver_f!(xnext, _ , f!::F, _ , ::EmptySolver, x, u, d, p) where F
12+
return f!(xnext, x, u, d, p)
13+
end
14+
15+
Base.show(io::IO, ::EmptySolver) = print(io, "Empty differential equation solver.")
16+
117
struct NonLinModel{
218
NT<:Real,
3-
F<:Function,
4-
H<:Function,
5-
PT<:Any,
619
DS<:DiffSolver,
20+
F <:Function,
21+
H <:Function,
22+
PT<:Any,
723
JB<:AbstractADType,
824
LF<:Function
925
} <: SimModel{NT}
1026
x0::Vector{NT}
11-
solver_f!::F
12-
solver_h!::H
13-
p::PT
1427
solver::DS
28+
f!::F
29+
h!::H
30+
p::PT
1531
Ts::NT
1632
t::Vector{NT}
1733
nu::Int
@@ -32,14 +48,14 @@ struct NonLinModel{
3248
linfunc!::LF
3349
buffer::SimModelBuffer{NT}
3450
function NonLinModel{NT}(
35-
solver_f!::F, solver_h!::H, Ts, nu, nx, ny, nd,
36-
p::PT, solver::DS, jacobian::JB, linfunc!::LF
51+
solver::DS, f!::F, h!::H, Ts, nu, nx, ny, nd,
52+
p::PT, jacobian::JB, linfunc!::LF
3753
) where {
3854
NT<:Real,
55+
DS<:DiffSolver,
3956
F<:Function,
4057
H<:Function,
4158
PT<:Any,
42-
DS<:DiffSolver,
4359
JB<:AbstractADType,
4460
LF<:Function
4561
}
@@ -58,11 +74,11 @@ struct NonLinModel{
5874
ni = solver.ni
5975
nk = nx*(ni+1)
6076
buffer = SimModelBuffer{NT}(nu, nx, ny, nd, ni)
61-
return new{NT, F, H, PT, DS, JB, LF}(
77+
return new{NT, DS, F, H, PT, JB, LF}(
6278
x0,
63-
solver_f!, solver_h!,
64-
p,
65-
solver,
79+
solver,
80+
f!, h!,
81+
p,
6682
Ts, t,
6783
nu, nx, ny, nd, nk,
6884
uop, yop, dop, xop, fop,
@@ -176,12 +192,11 @@ function NonLinModel{NT}(
176192
) where {NT<:Real}
177193
isnothing(solver) && (solver=EmptySolver())
178194
f!, h! = get_mutating_functions(NT, f, h)
179-
solver_f!, solver_h! = get_solver_functions(NT, solver, f!, h!, Ts, nu, nx, ny, nd)
180195
linfunc! = get_linearization_func(
181-
NT, solver_f!, solver_h!, nu, nx, ny, nd, p, solver, jacobian
196+
NT, f!, h!, Ts, nu, nx, ny, nd, p, solver, jacobian
182197
)
183198
return NonLinModel{NT}(
184-
solver_f!, solver_h!, Ts, nu, nx, ny, nd, p, solver, jacobian, linfunc!
199+
solver, f!, h!, Ts, nu, nx, ny, nd, p, jacobian, linfunc!
185200
)
186201
end
187202

@@ -270,22 +285,30 @@ Call [`linearize(model; x, u, d)`](@ref) and return the resulting linear model.
270285
"""
271286
LinModel(model::NonLinModel; kwargs...) = linearize(model; kwargs...)
272287

288+
273289
"""
274290
f!(x0next, k0, model::NonLinModel, x0, u0, d0, p)
275291
276-
Call `model.solver_f!(x0next, k0, x0, u0, d0, p)` for [`NonLinModel`](@ref).
292+
Compute `x0next` using the [`DiffSolver`](@ref) in `model.solver` and `model.f!`.
277293
278-
The method mutate `x0next` and `k0` arguments in-place. The latter is used to store the
279-
intermediate stage values of `model.solver` [`DiffSolver`](@ref).
294+
The method mutates `x0next` and `k0` arguments in-place. The latter is used to store the
295+
intermediate stage values of the solver.
280296
"""
281-
f!(x0next, k0, model::NonLinModel, x0, u0, d0, p) = model.solver_f!(x0next, k0, x0, u0, d0, p)
297+
function f!(x0next, k0, model::NonLinModel, x0, u0, d0, p)
298+
return solver_f!(x0next, k0, model.f!, model.Ts, model.solver, x0, u0, d0, p)
299+
end
282300

283301
"""
284302
h!(y0, model::NonLinModel, x0, d0, p)
285303
286-
Call `model.solver_h!(y0, x0, d0, p)` for [`NonLinModel`](@ref).
304+
Compute `y0` by calling `model.h!` directly for [`NonLinModel`](@ref).
287305
"""
288-
h!(y0, model::NonLinModel, x0, d0, p) = model.solver_h!(y0, x0, d0, p)
306+
h!(y0, model::NonLinModel, x0, d0, p) = model.h!(y0, x0, d0, p)
307+
308+
include("solver.jl")
289309

290-
detailstr(model::NonLinModel) = ", $(typeof(model.solver).name.name)($(model.solver.order)) solver"
291-
detailstr(::NonLinModel{<:Real, <:Function, <:Function, <:Any, <:EmptySolver}) = ", empty solver"
310+
function detailstr(model::NonLinModel{<:Real, <:RungeKutta{N}}) where N
311+
return ", $(nameof(typeof(model.solver)))($N) solver"
312+
end
313+
detailstr(::NonLinModel{<:Real, <:EmptySolver}) = ", empty solver"
314+
detailstr(::NonLinModel) = ""

src/model/solver.jl

Lines changed: 43 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,5 @@
1-
"Abstract supertype of all differential equation solvers."
2-
abstract type DiffSolver end
3-
4-
"Empty solver for nonlinear discrete-time models."
5-
struct EmptySolver <: DiffSolver
6-
ni::Int # number of intermediate stages
7-
EmptySolver() = new(-1)
8-
end
9-
10-
"""
11-
get_solver_functions(NT::DataType, solver::EmptySolver, f!, h!, Ts, nu, nx, ny, nd)
12-
13-
Get `solver_f!` and `solver_h!` functions for the `EmptySolver` (discrete models).
14-
15-
The functions should have the following signature:
16-
```
17-
solver_f!(xnext, k, x, u, d, p) -> nothing
18-
solver_h!(y, x, d, p) -> nothing
19-
```
20-
in which `xnext`, `k` and `y` arguments are mutated in-place. The `k` argument is
21-
a vector of `nx*(solver.ni+1)` elements to store the solver intermediate stage values (and
22-
also the current state value for when `supersample ≠ 1`).
23-
"""
24-
function get_solver_functions(::DataType, ::EmptySolver, f!, h!, _ , _ , _ , _ , _ )
25-
solver_f!(xnext, _ , x, u, d, p) = f!(xnext, x, u, d, p)
26-
solver_h! = h!
27-
return solver_f!, solver_h!
28-
end
29-
30-
function Base.show(io::IO, solver::EmptySolver)
31-
print(io, "Empty differential equation solver.")
32-
end
33-
34-
struct RungeKutta <: DiffSolver
1+
struct RungeKutta{N} <: DiffSolver
352
ni::Int # number of intermediate stages
36-
order::Int # order of the method
373
supersample::Int # number of internal steps
384
function RungeKutta(order::Int, supersample::Int)
395
if order 4 && order 1
@@ -46,7 +12,7 @@ struct RungeKutta <: DiffSolver
4612
throw(ArgumentError("supersample must be greater than 0"))
4713
end
4814
ni = order # only true for order ≤ 4 with RungeKutta
49-
return new(ni, order, supersample)
15+
return new{order}(ni, supersample)
5016
end
5117
end
5218

@@ -61,60 +27,29 @@ This solver is allocation-free if the `f!` and `h!` functions do not allocate.
6127
"""
6228
RungeKutta(order::Int=4; supersample::Int=1) = RungeKutta(order, supersample)
6329

64-
"Get `solve_f!` and `solver_h!` functions for the explicit Runge-Kutta solvers."
65-
function get_solver_functions(NT::DataType, solver::RungeKutta, f!, h!, Ts, _ , nx, _ , _ )
66-
solver_f! = if solver.order == 4
67-
get_rk4_function(NT, solver, f!, Ts, nx)
68-
elseif solver.order == 1
69-
get_euler_function(NT, solver, f!, Ts, nx)
70-
else
71-
throw(ArgumentError("only 1st and 4th order Runge-Kutta is supported."))
30+
"Solve the differential equation with the 4th order Runge-Kutta method."
31+
function solver_f!(xnext, k, f!::F, Ts, solver::RungeKutta{4}, x, u, d, p) where F
32+
supersample = solver.supersample
33+
Ts_inner = Ts/supersample
34+
nx = length(x)
35+
xcurr = @views k[1:nx]
36+
k1 = @views k[(1nx + 1):(2nx)]
37+
k2 = @views k[(2nx + 1):(3nx)]
38+
k3 = @views k[(3nx + 1):(4nx)]
39+
k4 = @views k[(4nx + 1):(5nx)]
40+
@. xcurr = x
41+
for i=1:supersample
42+
f!(k1, xcurr, u, d, p)
43+
@. xnext = xcurr + k1 * Ts_inner/2
44+
f!(k2, xnext, u, d, p)
45+
@. xnext = xcurr + k2 * Ts_inner/2
46+
f!(k3, xnext, u, d, p)
47+
@. xnext = xcurr + k3 * Ts_inner
48+
f!(k4, xnext, u, d, p)
49+
@. xcurr = xcurr + (k1 + 2k2 + 2k3 + k4)*Ts_inner/6
7250
end
73-
solver_h! = h!
74-
return solver_f!, solver_h!
75-
end
76-
77-
"Get the f! function for the 4th order explicit Runge-Kutta solver."
78-
function get_rk4_function(NT, solver, f!, Ts, nx)
79-
Ts_inner = Ts/solver.supersample
80-
function rk4_solver_f!(xnext, k, x, u, d, p)
81-
xcurr = @views k[1:nx]
82-
k1 = @views k[(1nx + 1):(2nx)]
83-
k2 = @views k[(2nx + 1):(3nx)]
84-
k3 = @views k[(3nx + 1):(4nx)]
85-
k4 = @views k[(4nx + 1):(5nx)]
86-
@. xcurr = x
87-
for i=1:solver.supersample
88-
f!(k1, xcurr, u, d, p)
89-
@. xnext = xcurr + k1 * Ts_inner/2
90-
f!(k2, xnext, u, d, p)
91-
@. xnext = xcurr + k2 * Ts_inner/2
92-
f!(k3, xnext, u, d, p)
93-
@. xnext = xcurr + k3 * Ts_inner
94-
f!(k4, xnext, u, d, p)
95-
@. xcurr = xcurr + (k1 + 2k2 + 2k3 + k4)*Ts_inner/6
96-
end
97-
@. xnext = xcurr
98-
return nothing
99-
end
100-
return rk4_solver_f!
101-
end
102-
103-
"Get the f! function for the explicit Euler solver."
104-
function get_euler_function(NT, solver, fc!, Ts, nx)
105-
Ts_inner = Ts/solver.supersample
106-
function euler_solver_f!(xnext, k, x, u, d, p)
107-
xcurr = @views k[1:nx]
108-
k1 = @views k[(1nx + 1):(2nx)]
109-
@. xcurr = x
110-
for i=1:solver.supersample
111-
fc!(k1, xcurr, u, d, p)
112-
@. xcurr = xcurr + k1 * Ts_inner
113-
end
114-
@. xnext = xcurr
115-
return nothing
116-
end
117-
return euler_solver_f!
51+
@. xnext = xcurr
52+
return nothing
11853
end
11954

12055
"""
@@ -126,7 +61,24 @@ This is an alias for `RungeKutta(1; supersample)`, see [`RungeKutta`](@ref).
12661
"""
12762
const ForwardEuler(;supersample=1) = RungeKutta(1; supersample)
12863

129-
function Base.show(io::IO, solver::RungeKutta)
130-
N, n = solver.order, solver.supersample
64+
65+
"Solve the differential equation with the forward Euler method."
66+
function solver_f!(xnext, k, f!::F, Ts, solver::RungeKutta{1}, x, u, d, p) where F
67+
supersample = solver.supersample
68+
Ts_inner = Ts/supersample
69+
nx = length(x)
70+
xcurr = @views k[1:nx]
71+
k1 = @views k[(1nx + 1):(2nx)]
72+
@. xcurr = x
73+
for i=1:supersample
74+
f!(k1, xcurr, u, d, p)
75+
@. xcurr = xcurr + k1 * Ts_inner
76+
end
77+
@. xnext = xcurr
78+
return nothing
79+
end
80+
81+
function Base.show(io::IO, solver::RungeKutta{N}) where N
82+
n = solver.supersample
13183
print(io, "$(N)th order Runge-Kutta differential equation solver with $n supersamples.")
13284
end

src/sim_model.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,6 @@ to_mat(A::AbstractMatrix, _ ...) = A
365365
to_mat(A::Real, dims...) = fill(A, dims)
366366

367367
include("model/linmodel.jl")
368-
include("model/solver.jl")
369368
include("model/linearization.jl")
370369
include("model/nonlinmodel.jl")
371370

0 commit comments

Comments
 (0)