Skip to content

Commit 8048dbe

Browse files
committed
wip: RungeKutta(4) now works without DiffCache
1 parent 2ea5887 commit 8048dbe

File tree

3 files changed

+54
-48
lines changed

3 files changed

+54
-48
lines changed

src/model/nonlinmodel.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ struct NonLinModel{
5353
xname = ["\$x_{$i}\$" for i in 1:nx]
5454
x0 = zeros(NT, nx)
5555
t = zeros(NT, 1)
56-
buffer = SimModelBuffer{NT}(nu, nx, ny, nd)
56+
buffer = SimModelBuffer{NT}(nu, nx, ny, nd, solver.ns)
5757
return new{NT, F, H, PT, DS, JB, LF}(
5858
x0,
5959
f!, h!,
@@ -263,7 +263,7 @@ Call [`linearize(model; x, u, d)`](@ref) and return the resulting linear model.
263263
LinModel(model::NonLinModel; kwargs...) = linearize(model; kwargs...)
264264

265265
"Call `model.f!(xnext0, x0, u0, d0, p)` for [`NonLinModel`](@ref)."
266-
f!(xnext0, model::NonLinModel, x0, u0, d0, p) = model.f!(xnext0, x0, u0, d0, p)
266+
f!(xnext0, model::NonLinModel, x0, u0, d0, p) = model.f!(xnext0, model.buffer.K, x0, u0, d0, p)
267267

268268
"Call `model.h!(y0, x0, d0, p)` for [`NonLinModel`](@ref)."
269269
h!(y0, model::NonLinModel, x0, d0, p) = model.h!(y0, x0, d0, p)

src/model/solver.jl

Lines changed: 50 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,38 @@ abstract type DiffSolver end
33

44
"Empty solver for nonlinear discrete-time models."
55
struct EmptySolver <: DiffSolver
6-
ns::Int
6+
ns::Int # number of stages
77
EmptySolver() = new(0)
88
end
99

10-
function get_solver_functions(NT::DataType, ::EmptySolver, f!, h!, _ ... )
11-
f_solver!(xnext, _ , x, u, d, p) = f!(xnext, x, u, d, p)
12-
return f!, h!
10+
"""
11+
get_solver_functions(NT::DataType, solver::EmptySolver, f!, h!, Ts, nu, nx, ny, nd)
12+
13+
Get `solver_f!` and `solver_h!` functions for the `EmptySolver` (discrete models).
14+
15+
The functions should have the following signature:
16+
```
17+
solver_f!(xnext, K, x, u, d, p) -> nothing
18+
solver_h!(y, x, d, p) -> nothing
19+
```
20+
in which `xnext`, `K` and `y` arguments should be mutated in-place. The `K` argument is
21+
a vector of `nx*(solver.ns+1)` elements to store the solver intermediary stage values,
22+
and also the current state value when `supersample ≠ 1`.
23+
"""
24+
function get_solver_functions(::DataType, ::EmptySolver, f!, h!, _ , _ , _ , _ )
25+
solver_f!(xnext, _ , x, u, d, p) = f!(xnext, x, u, d, p)
26+
solver_h! = h!
27+
return solver_f!, solver_h!
1328
end
1429

1530
function Base.show(io::IO, solver::EmptySolver)
1631
print(io, "Empty differential equation solver.")
1732
end
1833

1934
struct RungeKutta <: DiffSolver
20-
ns::Int
21-
order::Int
22-
supersample::Int
35+
ns::Int # number of stages
36+
order::Int # order of the method
37+
supersample::Int # number of internal steps
2338
function RungeKutta(order::Int, supersample::Int)
2439
if order 4 && order 1
2540
throw(ArgumentError("only 1st and 4th order Runge-Kutta is supported."))
@@ -30,7 +45,7 @@ struct RungeKutta <: DiffSolver
3045
if supersample < 1
3146
throw(ArgumentError("supersample must be greater than 0"))
3247
end
33-
ns = order # only true for order ≤ 4
48+
ns = order # only true for order ≤ 4 with RungeKutta
3449
return new(ns, order, supersample)
3550
end
3651
end
@@ -46,59 +61,50 @@ This solver is allocation-free if the `f!` and `h!` functions do not allocate.
4661
"""
4762
RungeKutta(order::Int=4; supersample::Int=1) = RungeKutta(order, supersample)
4863

49-
"Get the `f!` and `h!` functions for the explicit Runge-Kutta solvers."
50-
function get_solver_functions(NT::DataType, solver::RungeKutta, fc!, hc!, Ts, _ , nx, _ , _ )
64+
"Get `solve_f!` and `solver_h!` functions for the explicit Runge-Kutta solvers."
65+
function get_solver_functions(NT::DataType, solver::RungeKutta, f!, h!, Ts, _ , nx, _ , _ )
5166
Nc = nx + 2
52-
f! = if solver.order==4
53-
get_rk4_function(NT, solver, fc!, Ts, nx, Nc)
67+
solver_f! = if solver.order==4
68+
get_rk4_function(NT, solver, f!, Ts, nx, Nc)
5469
elseif solver.order==1
55-
get_euler_function(NT, solver, fc!, Ts, nx, Nc)
70+
get_euler_function(NT, solver, f!, Ts, nx, Nc)
5671
else
5772
throw(ArgumentError("only 1st and 4th order Runge-Kutta is supported."))
5873
end
59-
h! = hc!
60-
return f!, h!
74+
solver_h! = h!
75+
return solver_f!, solver_h!
6176
end
6277

6378
"Get the f! function for the 4th order explicit Runge-Kutta solver."
64-
function get_rk4_function(NT, solver, fc!, Ts, nx, Nc)
79+
function get_rk4_function(NT, solver, f!, Ts, nx, Nc)
6580
Ts_inner = Ts/solver.supersample
66-
xcur = zeros(NT, nx)
67-
k1 = zeros(NT, nx)
68-
k2 = zeros(NT, nx)
69-
k3 = zeros(NT, nx)
70-
k4 = zeros(NT, nx)
71-
f! = function rk4_solver!(xnext, x, u, d, p)
72-
CT = promote_type(eltype(x), eltype(u), eltype(d))
73-
#=xcur = get_tmp(xcur_cache, CT)
74-
k1 = get_tmp(k1_cache, CT)
75-
k2 = get_tmp(k2_cache, CT)
76-
k3 = get_tmp(k3_cache, CT)
77-
k4 = get_tmp(k4_cache, CT)=#
78-
xterm = xnext
79-
@. xcur = x
81+
function rk4_solver_f!(xnext, K, x, u, d, p)
82+
xcurr = @views K[1:nx]
83+
k1 = @views K[(1nx + 1):(2nx)]
84+
k2 = @views K[(2nx + 1):(3nx)]
85+
k3 = @views K[(3nx + 1):(4nx)]
86+
k4 = @views K[(4nx + 1):(5nx)]
87+
@. xcurr = x
8088
for i=1:solver.supersample
81-
fc!(k1, xcur, u, d, p)
82-
@. xterm = xcur + k1 * Ts_inner/2
83-
fc!(k2, xterm, u, d, p)
84-
@. xterm = xcur + k2 * Ts_inner/2
85-
fc!(k3, xterm, u, d, p)
86-
@. xterm = xcur + k3 * Ts_inner
87-
fc!(k4, xterm, u, d, p)
88-
@. xcur = xcur + (k1 + 2k2 + 2k3 + k4)*Ts_inner/6
89+
f!(k1, xcurr, u, d, p)
90+
@. xnext = xcurr + k1 * Ts_inner/2
91+
f!(k2, xnext, u, d, p)
92+
@. xnext = xcurr + k2 * Ts_inner/2
93+
f!(k3, xnext, u, d, p)
94+
@. xnext = xcurr + k3 * Ts_inner
95+
f!(k4, xnext, u, d, p)
96+
@. xcurr = xcurr + (k1 + 2k2 + 2k3 + k4)*Ts_inner/6
8997
end
90-
@. xnext = xcur
98+
@. xnext = xcurr
9199
return nothing
92100
end
93-
return f!
101+
return rk4_solver_f!
94102
end
95103

96104
"Get the f! function for the explicit Euler solver."
97105
function get_euler_function(NT, solver, fc!, Ts, nx, Nc)
98106
Ts_inner = Ts/solver.supersample
99-
xcur_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx), Nc)
100-
k_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx), Nc)
101-
f! = function euler_solver!(xnext, x, u, d, p)
107+
function euler_solver_f!(xnext, x, u, d, p)
102108
CT = promote_type(eltype(x), eltype(u), eltype(d))
103109
xcur = get_tmp(xcur_cache, CT)
104110
k = get_tmp(k_cache, CT)
@@ -111,7 +117,7 @@ function get_euler_function(NT, solver, fc!, Ts, nx, Nc)
111117
@. xnext = xcur
112118
return nothing
113119
end
114-
return f!
120+
return euler_solver_f!
115121
end
116122

117123
"""

src/sim_model.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ struct SimModelBuffer{NT<:Real}
2525
x::Vector{NT}
2626
y::Vector{NT}
2727
d::Vector{NT}
28-
K::Matrix{NT}
28+
K::Vector{NT}
2929
empty::Vector{NT}
3030
end
3131

@@ -42,7 +42,7 @@ function SimModelBuffer{NT}(nu::Int, nx::Int, ny::Int, nd::Int, ns::Int=0) where
4242
x = Vector{NT}(undef, nx)
4343
y = Vector{NT}(undef, ny)
4444
d = Vector{NT}(undef, nd)
45-
K = Matrix{NT}(undef, nx, ns)
45+
K = Vector{NT}(undef, nx*(ns+1)) # the "+1" is necessary because of super-sampling
4646
empty = Vector{NT}(undef, 0)
4747
return SimModelBuffer{NT}(u, x, y, d, K, empty)
4848
end

0 commit comments

Comments
 (0)