Skip to content

Commit 2ea5887

Browse files
committed
wip: simplifying allocations in DiffSolver
1 parent 7d8ccde commit 2ea5887

File tree

2 files changed

+26
-14
lines changed

2 files changed

+26
-14
lines changed

src/model/solver.jl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,22 @@
22
abstract type DiffSolver end
33

44
"Empty solver for nonlinear discrete-time models."
5-
struct EmptySolver <: DiffSolver end
6-
get_solver_functions(::DataType, ::EmptySolver, f!, h!, _ ... ) = f!, h!
5+
struct EmptySolver <: DiffSolver
6+
ns::Int
7+
EmptySolver() = new(0)
8+
end
9+
10+
function get_solver_functions(NT::DataType, ::EmptySolver, f!, h!, _ ... )
11+
f_solver!(xnext, _ , x, u, d, p) = f!(xnext, x, u, d, p)
12+
return f!, h!
13+
end
714

815
function Base.show(io::IO, solver::EmptySolver)
916
print(io, "Empty differential equation solver.")
1017
end
1118

1219
struct RungeKutta <: DiffSolver
20+
ns::Int
1321
order::Int
1422
supersample::Int
1523
function RungeKutta(order::Int, supersample::Int)
@@ -22,7 +30,8 @@ struct RungeKutta <: DiffSolver
2230
if supersample < 1
2331
throw(ArgumentError("supersample must be greater than 0"))
2432
end
25-
return new(order, supersample)
33+
ns = order # only true for order ≤ 4
34+
return new(ns, order, supersample)
2635
end
2736
end
2837

@@ -54,18 +63,18 @@ end
5463
"Get the f! function for the 4th order explicit Runge-Kutta solver."
5564
function get_rk4_function(NT, solver, fc!, Ts, nx, Nc)
5665
Ts_inner = Ts/solver.supersample
57-
xcur_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx), Nc)
58-
k1_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx), Nc)
59-
k2_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx), Nc)
60-
k3_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx), Nc)
61-
k4_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx), Nc)
66+
xcur = zeros(NT, nx)
67+
k1 = zeros(NT, nx)
68+
k2 = zeros(NT, nx)
69+
k3 = zeros(NT, nx)
70+
k4 = zeros(NT, nx)
6271
f! = function rk4_solver!(xnext, x, u, d, p)
6372
CT = promote_type(eltype(x), eltype(u), eltype(d))
64-
xcur = get_tmp(xcur_cache, CT)
73+
#=xcur = get_tmp(xcur_cache, CT)
6574
k1 = get_tmp(k1_cache, CT)
6675
k2 = get_tmp(k2_cache, CT)
6776
k3 = get_tmp(k3_cache, CT)
68-
k4 = get_tmp(k4_cache, CT)
77+
k4 = get_tmp(k4_cache, CT)=#
6978
xterm = xnext
7079
@. xcur = x
7180
for i=1:solver.supersample

src/sim_model.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,26 @@ struct SimModelBuffer{NT<:Real}
2525
x::Vector{NT}
2626
y::Vector{NT}
2727
d::Vector{NT}
28+
K::Matrix{NT}
2829
empty::Vector{NT}
2930
end
3031

3132
@doc raw"""
32-
SimModelBuffer{NT}(nu::Int, nx::Int, ny::Int, nd::Int, linearization=nothing)
33+
SimModelBuffer{NT}(nu::Int, nx::Int, ny::Int, nd::Int, ns::Int=0)
3334
3435
Create a buffer for `SimModel` objects for inputs, states, outputs, and disturbances.
3536
36-
The buffer is used to store intermediate results during simulation without allocating.
37+
The buffer is used to store intermediate results during simulation without allocating. The
38+
argument `ns` is the number of stage of the [`DiffSolver`](@ref), when applicable.
3739
"""
38-
function SimModelBuffer{NT}(nu::Int, nx::Int, ny::Int, nd::Int, ) where {NT<:Real}
40+
function SimModelBuffer{NT}(nu::Int, nx::Int, ny::Int, nd::Int, ns::Int=0) where {NT<:Real}
3941
u = Vector{NT}(undef, nu)
4042
x = Vector{NT}(undef, nx)
4143
y = Vector{NT}(undef, ny)
4244
d = Vector{NT}(undef, nd)
45+
K = Matrix{NT}(undef, nx, ns)
4346
empty = Vector{NT}(undef, 0)
44-
return SimModelBuffer{NT}(u, x, y, d, empty)
47+
return SimModelBuffer{NT}(u, x, y, d, K, empty)
4548
end
4649

4750

0 commit comments

Comments
 (0)