|
1 | 1 | abstract type AbstractDynamicOptProblem{uType, tType, isinplace} <:
|
2 | 2 | SciMLBase.AbstractODEProblem{uType, tType, isinplace} end
|
3 | 3 |
|
| 4 | +abstract type AbstractCollocation end |
| 5 | + |
4 | 6 | struct DynamicOptSolution
|
5 | 7 | model::Any
|
6 | 8 | sol::ODESolution
|
@@ -148,3 +150,148 @@ function process_tspan(tspan, dt, steps)
|
148 | 150 | return length(tspan[1]:dt:tspan[2]), false
|
149 | 151 | end
|
150 | 152 | end
|
| 153 | + |
| 154 | +function process_DynamicOptProblem(prob_type::AbstractDynamicOptProblem, model_type, sys::ODESystem, u0map, tspan, pmap; |
| 155 | + dt = nothing, |
| 156 | + steps = nothing, |
| 157 | + guesses = Dict(), kwargs...) |
| 158 | + |
| 159 | + MTK.warn_overdetermined(sys, u0map) |
| 160 | + _u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses)) |
| 161 | + f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap; |
| 162 | + t = tspan !== nothing ? tspan[1] : tspan, kwargs...) |
| 163 | + |
| 164 | + stidxmap = Dict([v => i for (i, v) in enumerate(states)]) |
| 165 | + u0map = Dict([MTK.default_toterm(MTK.value(k)) => v for (k, v) in u0map]) |
| 166 | + u0_idxs = has_alg_eqs(sys) ? collect(1:length(states)) : |
| 167 | + [stidxmap[MTK.default_toterm(k)] for (k, v) in u0map] |
| 168 | + pmap = Dict{Any, Any}(pmap) |
| 169 | + steps, is_free_t = MTK.process_tspan(tspan, dt, steps) |
| 170 | + |
| 171 | + ctrls = MTK.unbound_inputs(sys) |
| 172 | + states = unknowns(sys) |
| 173 | + |
| 174 | + model = generate_internal_model(model_type) |
| 175 | + U = generate_U(model, u0) |
| 176 | + V = generate_V() |
| 177 | + tₛ = generate_timescale() |
| 178 | + fullmodel = model_type(model, U, V, tₛ) |
| 179 | + |
| 180 | + set_variable_bounds!(fullmodel, sys, pmap) |
| 181 | + add_cost_function!(fullmodel, sys, tspan, pmap; is_free_t) |
| 182 | + add_user_constraints!(fullmodel, sys, tspan, pmap; is_free_t) |
| 183 | + add_initial_constraints!(fullmodel, u0, u0_idxs) |
| 184 | + |
| 185 | + prob_type(f, u0, tspan, p, fullmodel, kwargs...) |
| 186 | +end |
| 187 | + |
| 188 | +function add_cost_function!() |
| 189 | + jcosts = copy(MTK.get_costs(sys)) |
| 190 | + consolidate = MTK.get_consolidate(sys) |
| 191 | + if isnothing(jcosts) || isempty(jcosts) |
| 192 | + minimize!(opti, MX(0)) |
| 193 | + return |
| 194 | + end |
| 195 | + |
| 196 | + jcosts = substitute_model_vars(model, sys, pmap, jcosts; is_free_t) |
| 197 | + jcosts = substitute_free_final_vars(model, sys, pmap, jcosts; is_free_t) |
| 198 | + jcosts = substitute_fixed_t_vars(model, sys, pmap, jcosts; is_free_t) |
| 199 | + jcosts = substitute_integral() |
| 200 | +end |
| 201 | + |
| 202 | +function add_user_constraints!() |
| 203 | + conssys = MTK.get_constraintsystem(sys) |
| 204 | + jconstraints = isnothing(conssys) ? nothing : MTK.get_constraints(conssys) |
| 205 | + (isnothing(jconstraints) || isempty(jconstraints)) && return nothing |
| 206 | + |
| 207 | + auxmap = Dict([u => MTK.default_toterm(MTK.value(u)) for u in unknowns(conssys)]) |
| 208 | + jconstraints = substitute_model_vars(model, sys, pmap, jconstraints; auxmap, is_free_t) |
| 209 | + |
| 210 | + for c in jconstraints |
| 211 | + if cons isa Equation |
| 212 | + add_constraint!() |
| 213 | + elseif cons.relational_op === Symbolics.geq |
| 214 | + add_constraint!() |
| 215 | + else |
| 216 | + add_constraint!() |
| 217 | + end |
| 218 | + end |
| 219 | +end |
| 220 | + |
| 221 | +function generate_U end |
| 222 | +function generate_V end |
| 223 | +function generate_timescale end |
| 224 | + |
| 225 | +function add_initial_constraints! end |
| 226 | +function add_constraint! end |
| 227 | + |
| 228 | +function add_collocation_solve_constraints!(prob, tableau) |
| 229 | + nᵤ = size(U.u, 1) |
| 230 | + nᵥ = size(V.u, 1) |
| 231 | + |
| 232 | + if is_explicit(tableau) |
| 233 | + K = MX[] |
| 234 | + for k in 1:(length(tsteps) - 1) |
| 235 | + τ = tsteps[k] |
| 236 | + for (i, h) in enumerate(c) |
| 237 | + ΔU = sum([A[i, j] * K[j] for j in 1:(i - 1)], init = MX(zeros(nᵤ))) |
| 238 | + Uₙ = U.u[:, k] + ΔU * dt |
| 239 | + Vₙ = V.u[:, k] |
| 240 | + Kₙ = tₛ * f(Uₙ, Vₙ, p, τ + h * dt) # scale the time |
| 241 | + push!(K, Kₙ) |
| 242 | + end |
| 243 | + ΔU = dt * sum([α[i] * K[i] for i in 1:length(α)]) |
| 244 | + subject_to!(solver_opti, U.u[:, k] + ΔU == U.u[:, k + 1]) |
| 245 | + empty!(K) |
| 246 | + end |
| 247 | + else |
| 248 | + for k in 1:(length(tsteps) - 1) |
| 249 | + τ = tsteps[k] |
| 250 | + # Kᵢ = generate_K() |
| 251 | + Kᵢ = variable!(solver_opti, nᵤ, length(α)) |
| 252 | + ΔUs = A * Kᵢ' # the stepsize at each stage of the implicit method |
| 253 | + for (i, h) in enumerate(c) |
| 254 | + ΔU = ΔUs[i, :]' |
| 255 | + Uₙ = U.u[:, k] + ΔU * dt |
| 256 | + Vₙ = V.u[:, k] |
| 257 | + subject_to!(solver_opti, Kᵢ[:, i] == tₛ * f(Uₙ, Vₙ, p, τ + h * dt)) |
| 258 | + end |
| 259 | + ΔU_tot = dt * (Kᵢ * α) |
| 260 | + subject_to!(solver_opti, U.u[:, k] + ΔU_tot == U.u[:, k + 1]) |
| 261 | + end |
| 262 | + end |
| 263 | +end |
| 264 | + |
| 265 | +function add_equational_solve_constraints!() |
| 266 | + diff_eqs = substitute_differentials() |
| 267 | + add_constraint!() |
| 268 | + |
| 269 | + alg_eqs = substitute_model_vars() |
| 270 | + add_constraint!() |
| 271 | +end |
| 272 | + |
| 273 | +""" |
| 274 | +Add the solve constraints, set the solver (Ipopt, e.g.) |
| 275 | +""" |
| 276 | +function prepare_solver end |
| 277 | + |
| 278 | +function DiffEqBase.solve(prob::AbstractDynamicOptProblem, solver::AbstractCollocation) |
| 279 | + #add_solve_constraints!(prob, solver) |
| 280 | + solver = prepare_solver(prob, solver) |
| 281 | + sol = solve_prob(prob, solver) |
| 282 | + |
| 283 | + ts = get_t_values(sol) |
| 284 | + Us = get_U_values(sol) |
| 285 | + Vs = get_V_values(sol) |
| 286 | + |
| 287 | + ode_sol = DiffEqBase.build_solution(prob, solver, ts, Us) |
| 288 | + input_sol = DiffEqBase.build_solution(prob, solver, ts, Vs) |
| 289 | + |
| 290 | + if successful_solve(model) |
| 291 | + ode_sol = SciMLBase.solution_new_retcode( |
| 292 | + ode_sol, SciMLBase.ReturnCode.ConvergenceFailure) |
| 293 | + !isnothing(input_sol) && (input_sol = SciMLBase.solution_new_retcode( |
| 294 | + input_sol, SciMLBase.ReturnCode.ConvergenceFailure)) |
| 295 | + end |
| 296 | + DynamicOptSolution(model, ode_sol, input_sol) |
| 297 | +end |
0 commit comments