Skip to content

Commit 3c591c0

Browse files
committed
feat: PyomoDynamicOtpPRoblem
1 parent a275b33 commit 3c591c0

File tree

4 files changed

+47
-40
lines changed

4 files changed

+47
-40
lines changed

ext/MTKInfiniteOptExt.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ function MTK.prepare_and_optimize!(prob::JuMPDynamicOptProblem, solver::JuMPColl
186186
add_solve_constraints!(prob, solver.tableau)
187187
set_optimizer(model, solver.solver)
188188
optimize!(model)
189+
model
189190
end
190191

191192
function MTK.prepare_and_optimize!(prob::InfiniteOptDynamicOptProblem, solver::InfiniteOptCollocation; verbose = false, kwargs...)
@@ -194,26 +195,26 @@ function MTK.prepare_and_optimize!(prob::InfiniteOptDynamicOptProblem, solver::I
194195
set_derivative_method(model[:t], solver.derivative_method)
195196
set_optimizer(model, solver.solver)
196197
optimize!(model)
198+
model
197199
end
198200

199-
function MTK.get_V_values(m::InfiniteOptModel)
200-
nt = length(supports(m.model[:t]))
201-
if !isempty(m.V)
202-
V_vals = value.(m.V)
201+
function MTK.get_V_values(m::InfiniteModel)
202+
nt = length(supports(m[:t]))
203+
if !isempty(m[:V])
204+
V_vals = value.(m[:V])
203205
V_vals = [[V_vals[i][j] for i in 1:length(V_vals)] for j in 1:nt]
204206
else
205207
nothing
206208
end
207209
end
208-
function MTK.get_U_values(m::InfiniteOptModel)
209-
nt = length(supports(m.model[:t]))
210-
U_vals = value.(m.U)
210+
function MTK.get_U_values(m::InfiniteModel)
211+
nt = length(supports(m[:t]))
212+
U_vals = value.(m[:U])
211213
U_vals = [[U_vals[i][j] for i in 1:length(U_vals)] for j in 1:nt]
212214
end
213-
MTK.get_t_values(model) = value(model.tₛ) * supports(model.model[:t])
215+
MTK.get_t_values(m::InfiniteModel) = value(m[:tₛ]) * supports(m[:t])
214216

215-
function MTK.successful_solve(m::InfiniteOptModel)
216-
model = m.model
217+
function MTK.successful_solve(model::InfiniteModel)
217218
tstatus = termination_status(model)
218219
pstatus = primal_status(model)
219220
!has_values(model) &&

ext/MTKPyomoDynamicOptExt.jl

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,26 @@ using Pyomo
44
using DiffEqBase
55
using UnPack
66
using NaNMath
7+
using Setfield
78
const MTK = ModelingToolkit
89

910
struct PyomoDynamicOptModel
1011
model::ConcreteModel
1112
U::PyomoVar
1213
V::PyomoVar
13-
tₛ::Union{Int, PyomoVar}
14+
tₛ::PyomoVar
1415
is_free_final::Bool
16+
solver_model::Union{Nothing, ConcreteModel}
1517
dU::PyomoVar
1618
model_sym::Union{Num, Symbolics.BasicSymbolic}
1719
t_sym::Union{Num, Symbolics.BasicSymbolic}
18-
idx_sym::Union{Num, Symbolics.BasicSymbolic}
20+
uidx_sym::Union{Num, Symbolics.BasicSymbolic}
21+
vidx_sym::Union{Num, Symbolics.BasicSymbolic}
1922

2023
function PyomoDynamicOptModel(model, U, V, tₛ, is_free_final)
21-
@variables MODEL_SYM::Symbolics.symstruct(PyomoDynamicOptModel) IDX_SYM::Int T_SYM
24+
@variables MODEL_SYM::Symbolics.symstruct(ConcreteModel) U_IDX_SYM::Int V_IDX_SYM::Int T_SYM
2225
model.dU = dae.DerivativeVar(U, wrt = model.t, initialize = 0)
23-
new(model, U, V, tₛ, is_free_final, PyomoVar(model.dU), MODEL_SYM, T_SYM, IDX_SYM)
26+
new(model, U, V, tₛ, is_free_final, nothing, PyomoVar(model.dU), MODEL_SYM, T_SYM, U_IDX_SYM, V_IDX_SYM)
2427
end
2528
end
2629

@@ -39,6 +42,9 @@ struct PyomoDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
3942
end
4043
end
4144

45+
pysym_getproperty(s, name::Symbol) = Symbolics.wrap(SymbolicUtils.term(_getproperty, s, Val{name}(), type = Symbolics.Struct{PyomoVar}))
46+
_getproperty(s, name::Val{fieldname}) where fieldname = getproperty(s, fieldname)
47+
4248
function MTK.PyomoDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
4349
dt = nothing, steps = nothing,
4450
guesses = Dict(), kwargs...)
@@ -68,24 +74,22 @@ function MTK.generate_input_variable!(m::ConcreteModel, c0, nc, ts)
6874
end
6975

7076
function MTK.generate_timescale!(m::ConcreteModel, guess, is_free_t)
71-
m.tₛ = is_free_t ? PyomoVar(pyomo.Var(initialize = guess, bounds = (0, Inf))) : 1
77+
m.tₛ = is_free_t ? PyomoVar(pyomo.Var(initialize = guess, bounds = (0, Inf))) : PyomoVar(Pyomo.Py(1))
7278
end
7379

74-
function MTK.add_constraint!(pmodel::PyomoDynamicOptModel, cons)
75-
@unpack model, model_sym, idx_sym, t_sym = pmodel
76-
@show model.dU
80+
function MTK.add_constraint!(pmodel::PyomoDynamicOptModel, cons; n_idxs = 1)
81+
@unpack model, model_sym, t_sym = pmodel
7782
expr = if cons isa Equation
7883
cons.lhs - cons.rhs == 0
7984
elseif cons.relational_op === Symbolics.geq
8085
cons.lhs - cons.rhs 0
8186
else
8287
cons.lhs - cons.rhs 0
8388
end
84-
constraint_f = Symbolics.build_function(expr, model_sym, idx_sym, t_sym, expression = Val{false})
85-
@show typeof(constraint_f)
86-
@show typeof(Pyomo.pyfunc(constraint_f))
87-
cons_sym = gensym()
88-
setproperty!(model, cons_sym, pyomo.Constraint(model.u_idxs, model.t, rule = Pyomo.pyfunc(constraint_f)))
89+
f_expr = Symbolics.build_function(expr, model_sym, t_sym)
90+
cons_sym = Symbol("cons", hash(cons))
91+
constraint_f = eval(:(cons_sym = $f_expr))
92+
setproperty!(model, cons_sym, pyomo.Constraint(model.t, rule = Pyomo.pyfunc(constraint_f)))
8993
end
9094

9195
function MTK.set_objective!(m::PyomoDynamicOptModel, expr)
@@ -107,12 +111,12 @@ end
107111
MTK.process_integral_bounds(model, integral_span, tspan) = integral_span
108112

109113
function MTK.lowered_derivative(m::PyomoDynamicOptModel, i)
110-
mdU = Symbolics.symbolic_getproperty(m.model_sym, :dU).val
114+
mdU = Symbolics.value(pysym_getproperty(m.model_sym, :dU))
111115
Symbolics.unwrap(mdU[i, m.t_sym])
112116
end
113117

114118
function MTK.lowered_var(m::PyomoDynamicOptModel, uv, i, t)
115-
X = Symbolics.symbolic_getproperty(m.model_sym, uv).val
119+
X = Symbolics.value(pysym_getproperty(m.model_sym, uv))
116120
var = t isa Union{Num, Symbolics.Symbolic} ? X[i, m.t_sym] : X[i, t]
117121
Symbolics.unwrap(var)
118122
end
@@ -125,21 +129,21 @@ end
125129
MTK.PyomoCollocation(solver, derivative_method = LagrangeRadau(5)) = PyomoCollocation(solver, derivative_method)
126130

127131
function MTK.prepare_and_optimize!(prob::PyomoDynamicOptProblem, collocation; verbose, kwargs...)
128-
m = prob.wrapped_model.model
132+
solver_m = prob.wrapped_model.model.clone()
129133
dm = collocation.derivative_method
130134
discretizer = TransformationFactory(dm)
131135
ncp = Pyomo.is_finite_difference(dm) ? 1 : dm.np
132-
discretizer.apply_to(m, wrt = m.t, nfe = m.steps, scheme = Pyomo.scheme_string(dm))
136+
discretizer.apply_to(solver_m, wrt = solver_m.t, nfe = solver_m.steps, scheme = Pyomo.scheme_string(dm))
133137
solver = SolverFactory(string(collocation.solver))
134-
solver.solve(m, tee = true)
135-
Main.xx[] = solver
138+
solver.solve(solver_m, tee = true)
139+
solver_m
136140
end
137141

138-
MTK.get_U_values(m::PyomoDynamicOptModel) = [pyomo.value(m.model.U[i]) for i in m.model.U.index_set()]
139-
MTK.get_V_values(m::PyomoDynamicOptModel) = [pyomo.value(m.model.V[i]) for i in m.model.V.index_set()]
140-
MTK.get_t_values(m::PyomoDynamicOptModel) = Pyomo.get_results(m.model, :t)
142+
MTK.get_U_values(m::ConcreteModel) = [[pyomo.value(m.U[i, t]) for i in m.u_idxs] for t in m.t]
143+
MTK.get_V_values(m::ConcreteModel) = [[pyomo.value(m.V[i, t]) for i in m.v_idxs] for t in m.t]
144+
MTK.get_t_values(m::ConcreteModel) = [t for t in m.t]
141145

142-
function MTK.successful_solve(m::PyomoDynamicOptModel)
146+
function MTK.successful_solve(m::ConcreteModel)
143147
ss = m.solver.status
144148
tc = m.solver.termination_condition
145149
if ss == opt.SolverStatus.ok && (tc == opt.TerminationStatus.optimal || tc == opt.TerminationStatus.locallyOptimal)

src/systems/optimal_control_interface.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,7 @@ function add_equational_constraints!(model, sys, pmap, tspan)
395395
diff_eqs = substitute_params(pmap, diff_eqs)
396396
diff_eqs = substitute_differentials(model, sys, diff_eqs)
397397
for eq in diff_eqs
398+
@show typeof(eq.lhs)
398399
add_constraint!(model, eq.lhs ~ eq.rhs * model.tₛ)
399400
end
400401

@@ -411,7 +412,7 @@ function substitute_differentials(model, sys, eqs)
411412
t = get_iv(sys)
412413
D = Differential(t)
413414
diffsubmap = Dict([D(lowered_var(model, :U, i, t)) => lowered_derivative(model, i) for i in 1:length(unknowns(sys))])
414-
map(c -> Symbolics.substitute(c, diffsubmap), eqs)
415+
eqs = map(c -> Symbolics.substitute(c, diffsubmap), eqs)
415416
end
416417

417418
function substitute_toterm(vars, exprs)
@@ -451,21 +452,21 @@ function successful_solve end
451452
- kwargs are used for other options. For example, the `plugin_options` and `solver_options` will propagated to the Opti object in CasADi.
452453
"""
453454
function DiffEqBase.solve(prob::AbstractDynamicOptProblem, solver::AbstractCollocation; verbose = false, kwargs...)
454-
prepare_and_optimize!(prob, solver; verbose, kwargs...)
455+
solved_model = prepare_and_optimize!(prob, solver; verbose, kwargs...)
455456

456-
ts = get_t_values(prob.wrapped_model)
457-
Us = get_U_values(prob.wrapped_model)
458-
Vs = get_V_values(prob.wrapped_model)
457+
ts = get_t_values(solved_model)
458+
Us = get_U_values(solved_model)
459+
Vs = get_V_values(solved_model)
459460
is_free_final(prob.wrapped_model) && (ts .+ prob.tspan[1])
460461

461462
ode_sol = DiffEqBase.build_solution(prob, solver, ts, Us)
462463
input_sol = isnothing(Vs) ? nothing : DiffEqBase.build_solution(prob, solver, ts, Vs)
463464

464-
if !successful_solve(prob.wrapped_model)
465+
if !successful_solve(solved_model)
465466
ode_sol = SciMLBase.solution_new_retcode(
466467
ode_sol, SciMLBase.ReturnCode.ConvergenceFailure)
467468
!isnothing(input_sol) && (input_sol = SciMLBase.solution_new_retcode(
468469
input_sol, SciMLBase.ReturnCode.ConvergenceFailure))
469470
end
470-
DynamicOptSolution(prob.wrapped_model.model, ode_sol, input_sol)
471+
DynamicOptSolution(solved_model, ode_sol, input_sol)
471472
end

test/extensions/dynamic_optimization.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using OrdinaryDiffEqSDIRK, OrdinaryDiffEqVerner, OrdinaryDiffEqTsit5, OrdinaryDi
66
using Ipopt
77
using DataInterpolations
88
using CasADi
9+
using Pyomo
910

1011
import DiffEqBase: solve
1112
const M = ModelingToolkit

0 commit comments

Comments
 (0)