Skip to content

Commit 5619b34

Browse files
committed
added: ForwardEuler ODE solve
alias for `RungeKutta(1; supersample)`
1 parent 48a7c31 commit 5619b34

File tree

3 files changed

+27
-3
lines changed

3 files changed

+27
-3
lines changed

docs/src/public/sim_model.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,5 @@ ModelPredictiveControl.DiffSolver
6666

6767
```@docs
6868
RungeKutta
69+
ForwardEuler
6970
```

src/ModelPredictiveControl.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import PreallocationTools: DiffCache, get_tmp
2323
import OSQP, Ipopt
2424

2525
export SimModel, LinModel, NonLinModel
26-
export DiffSolver, RungeKutta
26+
export DiffSolver, RungeKutta, ForwardEuler
2727
export setop!, setname!
2828
export setstate!, setmodel!, preparestate!, updatestate!, evaloutput, linearize, linearize!
2929
export savetime!, periodsleep

src/model/solver.jl

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ allocation-free if the `f!` and `h!` functions do not allocate.
3737
"""
3838
RungeKutta(order::Int=4; supersample::Int=1) = RungeKutta(order, supersample)
3939

40-
"Get the `f!` and `h!` functions for Runge-Kutta solver."
40+
"Get the `f!` and `h!` functions for the explicit Runge-Kutta solvers."
4141
function get_solver_functions(NT::DataType, solver::RungeKutta, fc!, hc!, Ts, _ , nx, _ , _ )
4242
order = solver.order
4343
Ts_inner = Ts/solver.supersample
@@ -47,7 +47,21 @@ function get_solver_functions(NT::DataType, solver::RungeKutta, fc!, hc!, Ts, _
4747
k2_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx), Nc)
4848
k3_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx), Nc)
4949
k4_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx), Nc)
50-
if order==4
50+
if order==1
51+
f! = function euler_solver!(xnext, x, u, d, p)
52+
CT = promote_type(eltype(x), eltype(u), eltype(d))
53+
xcur = get_tmp(xcur_cache, CT)
54+
k1 = get_tmp(k1_cache, CT)
55+
xterm = xnext
56+
@. xcur = x
57+
for i=1:solver.supersample
58+
fc!(k1, xcur, u, d, p)
59+
@. xcur = xcur + k1 * Ts_inner
60+
end
61+
@. xnext = xcur
62+
return nothing
63+
end
64+
elseif order==4
5165
f! = function rk4_solver!(xnext, x, u, d, p)
5266
CT = promote_type(eltype(x), eltype(u), eltype(d))
5367
xcur = get_tmp(xcur_cache, CT)
@@ -75,6 +89,15 @@ function get_solver_functions(NT::DataType, solver::RungeKutta, fc!, hc!, Ts, _
7589
return f!, h!
7690
end
7791

92+
"""
93+
ForwardEuler(; supersample=1)
94+
95+
Create a Forward Euler solver with optional super-sampling.
96+
97+
This is an alias for `RungeKutta(1; supersample)`.
98+
"""
99+
const ForwardEuler(;supersample=1) = RungeKutta(1; supersample)
100+
78101
function Base.show(io::IO, solver::RungeKutta)
79102
N, n = solver.order, solver.supersample
80103
print(io, "$(N)th order Runge-Kutta differential equation solver with $n supersamples.")

0 commit comments

Comments
 (0)