@@ -3,23 +3,38 @@ abstract type DiffSolver end
33
44" Empty solver for nonlinear discrete-time models."
55struct EmptySolver <: DiffSolver
6- ns:: Int
6+ ns:: Int # number of stages
77 EmptySolver () = new (0 )
88end
99
10- function get_solver_functions (NT:: DataType , :: EmptySolver , f!, h!, _ ... )
11- f_solver! (xnext, _ , x, u, d, p) = f! (xnext, x, u, d, p)
12- return f!, h!
10+ """
11+ get_solver_functions(NT::DataType, solver::EmptySolver, f!, h!, Ts, nu, nx, ny, nd)
12+
13+ Get `solver_f!` and `solver_h!` functions for the `EmptySolver` (discrete models).
14+
15+ The functions should have the following signature:
16+ ```
17+ solver_f!(xnext, K, x, u, d, p) -> nothing
18+ solver_h!(y, x, d, p) -> nothing
19+ ```
20+ in which `xnext`, `K` and `y` arguments should be mutated in-place. The `K` argument is
21+ a vector of `nx*(solver.ns+1)` elements to store the solver intermediary stage values,
22+ and also the current state value when `supersample ≠ 1`.
23+ """
24+ function get_solver_functions (:: DataType , :: EmptySolver , f!, h!, _ , _ , _ , _ )
25+ solver_f! (xnext, _ , x, u, d, p) = f! (xnext, x, u, d, p)
26+ solver_h! = h!
27+ return solver_f!, solver_h!
1328end
1429
1530function Base. show (io:: IO , solver:: EmptySolver )
1631 print (io, " Empty differential equation solver." )
1732end
1833
1934struct RungeKutta <: DiffSolver
20- ns:: Int
21- order:: Int
22- supersample:: Int
35+ ns:: Int # number of stages
36+ order:: Int # order of the method
37+ supersample:: Int # number of internal steps
2338 function RungeKutta (order:: Int , supersample:: Int )
2439 if order ≠ 4 && order ≠ 1
2540 throw (ArgumentError (" only 1st and 4th order Runge-Kutta is supported." ))
@@ -30,7 +45,7 @@ struct RungeKutta <: DiffSolver
3045 if supersample < 1
3146 throw (ArgumentError (" supersample must be greater than 0" ))
3247 end
33- ns = order # only true for order ≤ 4
48+ ns = order # only true for order ≤ 4 with RungeKutta
3449 return new (ns, order, supersample)
3550 end
3651end
@@ -46,59 +61,50 @@ This solver is allocation-free if the `f!` and `h!` functions do not allocate.
4661"""
4762RungeKutta (order:: Int = 4 ; supersample:: Int = 1 ) = RungeKutta (order, supersample)
4863
49- " Get the `f !` and `h !` functions for the explicit Runge-Kutta solvers."
50- function get_solver_functions (NT:: DataType , solver:: RungeKutta , fc !, hc !, Ts, _ , nx, _ , _ )
64+ " Get `solve_f !` and `solver_h !` functions for the explicit Runge-Kutta solvers."
65+ function get_solver_functions (NT:: DataType , solver:: RungeKutta , f !, h !, Ts, _ , nx, _ , _ )
5166 Nc = nx + 2
52- f ! = if solver. order== 4
53- get_rk4_function (NT, solver, fc !, Ts, nx, Nc)
67+ solver_f ! = if solver. order== 4
68+ get_rk4_function (NT, solver, f !, Ts, nx, Nc)
5469 elseif solver. order== 1
55- get_euler_function (NT, solver, fc !, Ts, nx, Nc)
70+ get_euler_function (NT, solver, f !, Ts, nx, Nc)
5671 else
5772 throw (ArgumentError (" only 1st and 4th order Runge-Kutta is supported." ))
5873 end
59- h ! = hc !
60- return f !, h !
74+ solver_h ! = h !
75+ return solver_f !, solver_h !
6176end
6277
6378" Get the f! function for the 4th order explicit Runge-Kutta solver."
64- function get_rk4_function (NT, solver, fc !, Ts, nx, Nc)
79+ function get_rk4_function (NT, solver, f !, Ts, nx, Nc)
6580 Ts_inner = Ts/ solver. supersample
66- xcur = zeros (NT, nx)
67- k1 = zeros (NT, nx)
68- k2 = zeros (NT, nx)
69- k3 = zeros (NT, nx)
70- k4 = zeros (NT, nx)
71- f! = function rk4_solver! (xnext, x, u, d, p)
72- CT = promote_type (eltype (x), eltype (u), eltype (d))
73- #= xcur = get_tmp(xcur_cache, CT)
74- k1 = get_tmp(k1_cache, CT)
75- k2 = get_tmp(k2_cache, CT)
76- k3 = get_tmp(k3_cache, CT)
77- k4 = get_tmp(k4_cache, CT)=#
78- xterm = xnext
79- @. xcur = x
81+ function rk4_solver_f! (xnext, K, x, u, d, p)
82+ xcurr = @views K[1 : nx]
83+ k1 = @views K[(1 nx + 1 ): (2 nx)]
84+ k2 = @views K[(2 nx + 1 ): (3 nx)]
85+ k3 = @views K[(3 nx + 1 ): (4 nx)]
86+ k4 = @views K[(4 nx + 1 ): (5 nx)]
87+ @. xcurr = x
8088 for i= 1 : solver. supersample
81- fc ! (k1, xcur , u, d, p)
82- @. xterm = xcur + k1 * Ts_inner/ 2
83- fc ! (k2, xterm , u, d, p)
84- @. xterm = xcur + k2 * Ts_inner/ 2
85- fc ! (k3, xterm , u, d, p)
86- @. xterm = xcur + k3 * Ts_inner
87- fc ! (k4, xterm , u, d, p)
88- @. xcur = xcur + (k1 + 2 k2 + 2 k3 + k4)* Ts_inner/ 6
89+ f ! (k1, xcurr , u, d, p)
90+ @. xnext = xcurr + k1 * Ts_inner/ 2
91+ f ! (k2, xnext , u, d, p)
92+ @. xnext = xcurr + k2 * Ts_inner/ 2
93+ f ! (k3, xnext , u, d, p)
94+ @. xnext = xcurr + k3 * Ts_inner
95+ f ! (k4, xnext , u, d, p)
96+ @. xcurr = xcurr + (k1 + 2 k2 + 2 k3 + k4)* Ts_inner/ 6
8997 end
90- @. xnext = xcur
98+ @. xnext = xcurr
9199 return nothing
92100 end
93- return f !
101+ return rk4_solver_f !
94102end
95103
96104" Get the f! function for the explicit Euler solver."
97105function get_euler_function (NT, solver, fc!, Ts, nx, Nc)
98106 Ts_inner = Ts/ solver. supersample
99- xcur_cache:: DiffCache{Vector{NT}, Vector{NT}} = DiffCache (zeros (NT, nx), Nc)
100- k_cache:: DiffCache{Vector{NT}, Vector{NT}} = DiffCache (zeros (NT, nx), Nc)
101- f! = function euler_solver! (xnext, x, u, d, p)
107+ function euler_solver_f! (xnext, x, u, d, p)
102108 CT = promote_type (eltype (x), eltype (u), eltype (d))
103109 xcur = get_tmp (xcur_cache, CT)
104110 k = get_tmp (k_cache, CT)
@@ -111,7 +117,7 @@ function get_euler_function(NT, solver, fc!, Ts, nx, Nc)
111117 @. xnext = xcur
112118 return nothing
113119 end
114- return f !
120+ return euler_solver_f !
115121end
116122
117123"""
0 commit comments