|
1 | 1 | abstract type AbstractDynamicOptProblem{uType, tType, isinplace} <: |
2 | 2 | SciMLBase.AbstractODEProblem{uType, tType, isinplace} end |
3 | 3 |
|
| 4 | +abstract type AbstractCollocation end |
| 5 | + |
4 | 6 | struct DynamicOptSolution |
5 | 7 | model::Any |
6 | 8 | sol::ODESolution |
@@ -181,3 +183,148 @@ function process_tspan(tspan, dt, steps) |
181 | 183 | return length(tspan[1]:dt:tspan[2]), false |
182 | 184 | end |
183 | 185 | end |
| 186 | + |
| 187 | +function process_DynamicOptProblem(prob_type::AbstractDynamicOptProblem, model_type, sys::ODESystem, u0map, tspan, pmap; |
| 188 | + dt = nothing, |
| 189 | + steps = nothing, |
| 190 | + guesses = Dict(), kwargs...) |
| 191 | + |
| 192 | + MTK.warn_overdetermined(sys, u0map) |
| 193 | + _u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses)) |
| 194 | + f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap; |
| 195 | + t = tspan !== nothing ? tspan[1] : tspan, kwargs...) |
| 196 | + |
| 197 | + stidxmap = Dict([v => i for (i, v) in enumerate(states)]) |
| 198 | + u0map = Dict([MTK.default_toterm(MTK.value(k)) => v for (k, v) in u0map]) |
| 199 | + u0_idxs = has_alg_eqs(sys) ? collect(1:length(states)) : |
| 200 | + [stidxmap[MTK.default_toterm(k)] for (k, v) in u0map] |
| 201 | + pmap = Dict{Any, Any}(pmap) |
| 202 | + steps, is_free_t = MTK.process_tspan(tspan, dt, steps) |
| 203 | + |
| 204 | + ctrls = MTK.unbound_inputs(sys) |
| 205 | + states = unknowns(sys) |
| 206 | + |
| 207 | + model = generate_internal_model(model_type) |
| 208 | + U = generate_U(model, u0) |
| 209 | + V = generate_V() |
| 210 | + tₛ = generate_timescale() |
| 211 | + fullmodel = model_type(model, U, V, tₛ) |
| 212 | + |
| 213 | + set_variable_bounds!(fullmodel, sys, pmap) |
| 214 | + add_cost_function!(fullmodel, sys, tspan, pmap; is_free_t) |
| 215 | + add_user_constraints!(fullmodel, sys, tspan, pmap; is_free_t) |
| 216 | + add_initial_constraints!(fullmodel, u0, u0_idxs) |
| 217 | + |
| 218 | + prob_type(f, u0, tspan, p, fullmodel, kwargs...) |
| 219 | +end |
| 220 | + |
| 221 | +function add_cost_function!() |
| 222 | + jcosts = copy(MTK.get_costs(sys)) |
| 223 | + consolidate = MTK.get_consolidate(sys) |
| 224 | + if isnothing(jcosts) || isempty(jcosts) |
| 225 | + minimize!(opti, MX(0)) |
| 226 | + return |
| 227 | + end |
| 228 | + |
| 229 | + jcosts = substitute_model_vars(model, sys, pmap, jcosts; is_free_t) |
| 230 | + jcosts = substitute_free_final_vars(model, sys, pmap, jcosts; is_free_t) |
| 231 | + jcosts = substitute_fixed_t_vars(model, sys, pmap, jcosts; is_free_t) |
| 232 | + jcosts = substitute_integral() |
| 233 | +end |
| 234 | + |
| 235 | +function add_user_constraints!() |
| 236 | + conssys = MTK.get_constraintsystem(sys) |
| 237 | + jconstraints = isnothing(conssys) ? nothing : MTK.get_constraints(conssys) |
| 238 | + (isnothing(jconstraints) || isempty(jconstraints)) && return nothing |
| 239 | + |
| 240 | + auxmap = Dict([u => MTK.default_toterm(MTK.value(u)) for u in unknowns(conssys)]) |
| 241 | + jconstraints = substitute_model_vars(model, sys, pmap, jconstraints; auxmap, is_free_t) |
| 242 | + |
| 243 | + for c in jconstraints |
| 244 | + if cons isa Equation |
| 245 | + add_constraint!() |
| 246 | + elseif cons.relational_op === Symbolics.geq |
| 247 | + add_constraint!() |
| 248 | + else |
| 249 | + add_constraint!() |
| 250 | + end |
| 251 | + end |
| 252 | +end |
| 253 | + |
| 254 | +function generate_U end |
| 255 | +function generate_V end |
| 256 | +function generate_timescale end |
| 257 | + |
| 258 | +function add_initial_constraints! end |
| 259 | +function add_constraint! end |
| 260 | + |
| 261 | +function add_collocation_solve_constraints!(prob, tableau) |
| 262 | + nᵤ = size(U.u, 1) |
| 263 | + nᵥ = size(V.u, 1) |
| 264 | + |
| 265 | + if is_explicit(tableau) |
| 266 | + K = MX[] |
| 267 | + for k in 1:(length(tsteps) - 1) |
| 268 | + τ = tsteps[k] |
| 269 | + for (i, h) in enumerate(c) |
| 270 | + ΔU = sum([A[i, j] * K[j] for j in 1:(i - 1)], init = MX(zeros(nᵤ))) |
| 271 | + Uₙ = U.u[:, k] + ΔU * dt |
| 272 | + Vₙ = V.u[:, k] |
| 273 | + Kₙ = tₛ * f(Uₙ, Vₙ, p, τ + h * dt) # scale the time |
| 274 | + push!(K, Kₙ) |
| 275 | + end |
| 276 | + ΔU = dt * sum([α[i] * K[i] for i in 1:length(α)]) |
| 277 | + subject_to!(solver_opti, U.u[:, k] + ΔU == U.u[:, k + 1]) |
| 278 | + empty!(K) |
| 279 | + end |
| 280 | + else |
| 281 | + for k in 1:(length(tsteps) - 1) |
| 282 | + τ = tsteps[k] |
| 283 | + # Kᵢ = generate_K() |
| 284 | + Kᵢ = variable!(solver_opti, nᵤ, length(α)) |
| 285 | + ΔUs = A * Kᵢ' # the stepsize at each stage of the implicit method |
| 286 | + for (i, h) in enumerate(c) |
| 287 | + ΔU = ΔUs[i, :]' |
| 288 | + Uₙ = U.u[:, k] + ΔU * dt |
| 289 | + Vₙ = V.u[:, k] |
| 290 | + subject_to!(solver_opti, Kᵢ[:, i] == tₛ * f(Uₙ, Vₙ, p, τ + h * dt)) |
| 291 | + end |
| 292 | + ΔU_tot = dt * (Kᵢ * α) |
| 293 | + subject_to!(solver_opti, U.u[:, k] + ΔU_tot == U.u[:, k + 1]) |
| 294 | + end |
| 295 | + end |
| 296 | +end |
| 297 | + |
| 298 | +function add_equational_solve_constraints!() |
| 299 | + diff_eqs = substitute_differentials() |
| 300 | + add_constraint!() |
| 301 | + |
| 302 | + alg_eqs = substitute_model_vars() |
| 303 | + add_constraint!() |
| 304 | +end |
| 305 | + |
| 306 | +""" |
| 307 | +Add the solve constraints, set the solver (Ipopt, e.g.) |
| 308 | +""" |
| 309 | +function prepare_solver end |
| 310 | + |
| 311 | +function DiffEqBase.solve(prob::AbstractDynamicOptProblem, solver::AbstractCollocation) |
| 312 | + #add_solve_constraints!(prob, solver) |
| 313 | + solver = prepare_solver(prob, solver) |
| 314 | + sol = solve_prob(prob, solver) |
| 315 | + |
| 316 | + ts = get_t_values(sol) |
| 317 | + Us = get_U_values(sol) |
| 318 | + Vs = get_V_values(sol) |
| 319 | + |
| 320 | + ode_sol = DiffEqBase.build_solution(prob, solver, ts, Us) |
| 321 | + input_sol = DiffEqBase.build_solution(prob, solver, ts, Vs) |
| 322 | + |
| 323 | + if successful_solve(model) |
| 324 | + ode_sol = SciMLBase.solution_new_retcode( |
| 325 | + ode_sol, SciMLBase.ReturnCode.ConvergenceFailure) |
| 326 | + !isnothing(input_sol) && (input_sol = SciMLBase.solution_new_retcode( |
| 327 | + input_sol, SciMLBase.ReturnCode.ConvergenceFailure)) |
| 328 | + end |
| 329 | + DynamicOptSolution(model, ode_sol, input_sol) |
| 330 | +end |
0 commit comments