Skip to content

Commit b1bc160

Browse files
committed
refactor: start systems test file
1 parent 4ac62b4 commit b1bc160

File tree

2 files changed

+134
-0
lines changed

2 files changed

+134
-0
lines changed

ext/MTKCasADiDynamicOptExt.jl

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
module MTKCasADiDynamicOptExt
2+
using ModelingToolkit
3+
using CasADi
4+
using DiffEqDevTools, DiffEqBase
5+
using DataInterpolations
6+
const MTK = MOdelingToolkit
7+
8+
struct CasADiDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
9+
AbstractDynamicOptProblem{uType, tType, isinplace}
10+
f::F
11+
u0::uType
12+
tspan::tType
13+
p::P
14+
model::Opti
15+
kwargs::K
16+
17+
function CasADiDynamicOptProblem(f, u0, tspan, p, model, kwargs...)
18+
new{typeof(u0), typeof(tspan), SciMLBase.isinplace(f, 5),
19+
typeof(p), typeof(f), typeof(kwargs)}(f, u0, tspan, p, model, kwargs)
20+
end
21+
end
22+
23+
struct CasADiModel
24+
opti::Opti
25+
U::MX
26+
V::MX
27+
end
28+
29+
struct TimedMX
30+
end
31+
32+
function MTK.CasADiDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
33+
dt = nothing,
34+
steps = nothing,
35+
guesses = Dict(), kwargs...)
36+
MTK.warn_overdetermined(sys, u0map)
37+
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
38+
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
39+
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
40+
41+
pmap = Dict{Any, Any}(pmap)
42+
steps, is_free_t = MTK.process_tspan(tspan, dt, steps)
43+
model = init_model()
44+
end
45+
46+
function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t)
47+
ctrls = MTK.unbound_inputs(sys)
48+
states = unknowns(sys)
49+
model = CasADi.Opti()
50+
51+
U = CasADi.variable!(model, length(states), steps)
52+
V = CasADi.variable!(model, length(ctrls), steps)
53+
end
54+
55+
function add_initial_constraints!()
56+
57+
end
58+
59+
function add_user_constraints!(model::CasADiModel, sys, pmap; is_free_t = false)
60+
61+
end
62+
63+
function add_cost_function!(model)
64+
65+
end
66+
67+
function add_solve_constraints!(prob, tableau; is_free_t)
68+
A = tableau.A
69+
α = tableau.α
70+
c = tableau.c
71+
model = prob.model
72+
f = prob.f
73+
p = prob.p
74+
75+
opti = model.opti
76+
t = model[:t]
77+
tsteps = supports(t)
78+
tmax = tsteps[end]
79+
pop!(tsteps)
80+
tₛ = is_free_t ? model[:tf] : 1
81+
dt = tsteps[2] - tsteps[1]
82+
83+
U = model.U
84+
V = model.V
85+
nᵤ = length(U)
86+
nᵥ = length(V)
87+
88+
if is_explicit(tableau)
89+
K = Any[]
90+
for τ in tsteps
91+
for (i, h) in enumerate(c)
92+
ΔU = sum([A[i, j] * K[j] for j in 1:(i - 1)], init = zeros(nᵤ))
93+
Uₙ = [U[i](τ) + ΔU[i] * dt for i in 1:nᵤ]
94+
Vₙ = [V[i](τ) for i in 1:nᵥ]
95+
Kₙ = tₛ * f(Uₙ, Vₙ, p, τ + h * dt) # scale the time
96+
push!(K, Kₙ)
97+
end
98+
ΔU = dt * sum([α[i] * K[i] for i in 1:length(α)])
99+
subject_to!(model, U[n](τ) + ΔU[n]==U[n](τ + dt))
100+
empty!(K)
101+
end
102+
else
103+
end
104+
end
105+
106+
function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, Symbol}, ode_solver::Union{String, Symbol}; silent = false)
107+
model = prob.model
108+
opti = model.opti
109+
110+
solver!(opti, solver)
111+
sol = solve(opti)
112+
DynamicOptSolution(model, sol, input_sol)
113+
end
114+
115+
end
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
function lotkavolterra()
2+
3+
end
4+
5+
function ()
6+
7+
end
8+
9+
function rocket_fft()
10+
11+
end
12+
13+
function rocket()
14+
15+
end
16+
17+
function cartpole()
18+
19+
end

0 commit comments

Comments
 (0)