Skip to content

Commit ddd89d7

Browse files
fix: update InfiniteOpt extension to new semantics
1 parent 28d4de5 commit ddd89d7

File tree

1 file changed

+23
-21
lines changed

1 file changed

+23
-21
lines changed

ext/MTKInfiniteOptExt.jl

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ function MTK.JuMPDynamicOptProblem(sys::System, u0map, tspan, pmap;
6363
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
6464
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
6565

66-
pmap = Dict{Any, Any}(pmap)
66+
pmap = MTK.recursive_unwrap(MTK.AnyDict(pmap))
67+
MTK.evaluate_varmap!(pmap, keys(pmap))
6768
steps, is_free_t = MTK.process_tspan(tspan, dt, steps)
6869
model = init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t)
6970

@@ -89,7 +90,8 @@ function MTK.InfiniteOptDynamicOptProblem(sys::System, u0map, tspan, pmap;
8990
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
9091
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
9192

92-
pmap = Dict{Any, Any}(pmap)
93+
pmap = MTK.recursive_unwrap(MTK.AnyDict(pmap))
94+
MTK.evaluate_varmap!(pmap, keys(pmap))
9395
steps, is_free_t = MTK.process_tspan(tspan, dt, steps)
9496
model = init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t)
9597

@@ -150,29 +152,28 @@ function set_jump_bounds!(model, sys, pmap)
150152
for (i, u) in enumerate(unknowns(sys))
151153
if MTK.hasbounds(u)
152154
lo, hi = MTK.getbounds(u)
153-
set_lower_bound(U[i], Symbolics.fixpoint_sub(lo, pmap))
154-
set_upper_bound(U[i], Symbolics.fixpoint_sub(hi, pmap))
155+
set_lower_bound(U[i], Symbolics.fast_substitute(lo, pmap))
156+
set_upper_bound(U[i], Symbolics.fast_substitute(hi, pmap))
155157
end
156158
end
157159

158160
V = model[:V]
159161
for (i, v) in enumerate(MTK.unbound_inputs(sys))
160162
if MTK.hasbounds(v)
161163
lo, hi = MTK.getbounds(v)
162-
set_lower_bound(V[i], Symbolics.fixpoint_sub(lo, pmap))
163-
set_upper_bound(V[i], Symbolics.fixpoint_sub(hi, pmap))
164+
set_lower_bound(V[i], Symbolics.fast_substitute(lo, pmap))
165+
set_upper_bound(V[i], Symbolics.fast_substitute(hi, pmap))
164166
end
165167
end
166168
end
167169

168170
function add_jump_cost_function!(model::InfiniteModel, sys, tspan, pmap; is_free_t = false)
169-
jcosts = MTK.get_costs(sys)
170-
consolidate = MTK.get_consolidate(sys)
171-
if isnothing(jcosts) || isempty(jcosts)
171+
jcosts = cost(sys)
172+
if Symbolics._iszero(jcosts)
172173
@objective(model, Min, 0)
173174
return
174175
end
175-
jcosts = substitute_jump_vars(model, sys, pmap, jcosts; is_free_t)
176+
jcosts = substitute_jump_vars(model, sys, pmap, [jcosts]; is_free_t)[1]
176177
tₛ = is_free_t ? model[:tf] : 1
177178

178179
# Substitute integral
@@ -187,17 +188,18 @@ function add_jump_cost_function!(model::InfiniteModel, sys, tspan, pmap; is_free
187188
hi = haskey(pmap, hi) ? 1 : MTK.value(hi)
188189
intmap[int] = tₛ * InfiniteOpt.(arg, model[:t], lo, hi)
189190
end
190-
jcosts = map(c -> Symbolics.substitute(c, intmap), jcosts)
191-
@objective(model, Min, consolidate(jcosts))
191+
jcosts = Symbolics.substitute(jcosts, intmap)
192+
@objective(model, Min, MTK.value(jcosts))
192193
end
193194

194195
function add_user_constraints!(model::InfiniteModel, sys, pmap; is_free_t = false)
195-
conssys = MTK.get_constraintsystem(sys)
196-
jconstraints = isnothing(conssys) ? nothing : MTK.get_constraints(conssys)
196+
jconstraints = MTK.get_constraints(sys)
197197
(isnothing(jconstraints) || isempty(jconstraints)) && return nothing
198+
cons_dvs, cons_ps = MTK.process_constraint_system(
199+
jconstraints, Set(unknowns(sys)), parameters(sys), MTK.get_iv(sys); validate = false)
198200

199201
if is_free_t
200-
for u in MTK.get_unknowns(conssys)
202+
for u in cons_dvs
201203
x = MTK.operation(u)
202204
t = only(arguments(u))
203205
if (MTK.symbolic_type(t) === MTK.NotSymbolic())
@@ -206,7 +208,7 @@ function add_user_constraints!(model::InfiniteModel, sys, pmap; is_free_t = fals
206208
end
207209
end
208210

209-
auxmap = Dict([u => MTK.default_toterm(MTK.value(u)) for u in unknowns(conssys)])
211+
auxmap = Dict([u => MTK.default_toterm(MTK.value(u)) for u in cons_dvs])
210212
jconstraints = substitute_jump_vars(model, sys, pmap, jconstraints; auxmap, is_free_t)
211213

212214
# Substitute to-term'd variables
@@ -235,25 +237,25 @@ function substitute_jump_vars(model, sys, pmap, exprs; auxmap = Dict(), is_free_
235237
x_ops = [MTK.operation(MTK.unwrap(st)) for st in sts]
236238
c_ops = [MTK.operation(MTK.unwrap(ct)) for ct in cts]
237239

238-
exprs = map(c -> Symbolics.fixpoint_sub(c, auxmap), exprs)
239-
exprs = map(c -> Symbolics.fixpoint_sub(c, Dict(pmap)), exprs)
240+
exprs = map(c -> Symbolics.fast_substitute(c, auxmap), exprs)
241+
exprs = map(c -> Symbolics.fast_substitute(c, Dict(pmap)), exprs)
240242
if is_free_t
241243
tf = model[:tf]
242244
free_t_map = Dict([[x(tf) => U[i](1) for (i, x) in enumerate(x_ops)];
243245
[c(tf) => V[i](1) for (i, c) in enumerate(c_ops)]])
244-
exprs = map(c -> Symbolics.fixpoint_sub(c, free_t_map), exprs)
246+
exprs = map(c -> Symbolics.fast_substitute(c, free_t_map), exprs)
245247
end
246248

247249
# for variables like x(t)
248250
whole_interval_map = Dict([[v => U[i] for (i, v) in enumerate(sts)];
249251
[v => V[i] for (i, v) in enumerate(cts)]])
250-
exprs = map(c -> Symbolics.fixpoint_sub(c, whole_interval_map), exprs)
252+
exprs = map(c -> Symbolics.fast_substitute(c, whole_interval_map), exprs)
251253

252254
# for variables like x(1.0)
253255
fixed_t_map = Dict([[x_ops[i] => U[i] for i in 1:length(U)];
254256
[c_ops[i] => V[i] for i in 1:length(V)]])
255257

256-
exprs = map(c -> Symbolics.fixpoint_sub(c, fixed_t_map), exprs)
258+
exprs = map(c -> Symbolics.fast_substitute(c, fixed_t_map), exprs)
257259
exprs
258260
end
259261

0 commit comments

Comments
 (0)