Skip to content

Commit 044dfb2

Browse files
enable remake for optimization problems
1 parent b384ad1 commit 044dfb2

File tree

3 files changed

+15
-12
lines changed

3 files changed

+15
-12
lines changed

src/variables.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,19 +119,14 @@ end
119119
throw(ArgumentError("$vars are missing from the variable map."))
120120
end
121121

122-
# FIXME: remove after: https://github.com/SciML/SciMLBase.jl/pull/311
123-
function SciMLBase.handle_varmap(varmap, sys::AbstractSystem; field = :states, kwargs...)
124-
out = varmap_to_vars(varmap, getfield(sys, field); kwargs...)
125-
return out
126-
end
127-
128122
"""
129123
$(SIGNATURES)
130124
131125
Intercept the call to `process_p_u0_symbolic` and process symbolic maps of `p` and/or `u0` if the
132126
user has `ModelingToolkit` loaded.
133127
"""
134-
function SciMLBase.process_p_u0_symbolic(prob::ODEProblem, p, u0)
128+
function SciMLBase.process_p_u0_symbolic(prob::Union{ODEProblem, OptimizationProblem}, p,
129+
u0)
135130
# check if a symbolic remake is possible
136131
if eltype(p) <: Pair
137132
hasproperty(prob.f, :sys) && hasfield(typeof(prob.f.sys), :ps) ||

test/odesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,8 @@ sol_dpmap = solve(prob_dpmap, Rodas5())
277277
prob = ODEProblem(sys, Pair[])
278278
prob_new = SciMLBase.remake(prob, p = Dict(sys1.a => 3.0, b => 4.0),
279279
u0 = Dict(sys1.x => 1.0))
280-
@test_broken prob_new.p == [4.0, 3.0, 1.0]
281-
@test_broken prob_new.u0 == [1.0, 0.0]
280+
@test prob_new.p == [4.0, 3.0, 1.0]
281+
@test prob_new.u0 == [1.0, 0.0]
282282
end
283283

284284
# test kwargs

test/optimizationsystem.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,23 +137,31 @@ end
137137
# nested constraints
138138
@testset "nested systems" begin
139139
@variables x y
140-
o1 = (x - 1)^2
140+
@parameters a=1
141+
o1 = (x - a)^2
141142
o2 = (y - 1 / 2)^2
142143
c1 = [
143144
x ~ 1,
144145
]
145146
c2 = [
146147
y ~ 1,
147148
]
148-
sys1 = OptimizationSystem(o1, [x], [], name = :sys1, constraints = c1)
149+
sys1 = OptimizationSystem(o1, [x], [a], name = :sys1, constraints = c1)
149150
sys2 = OptimizationSystem(o2, [y], [], name = :sys2, constraints = c2)
150151
sys = OptimizationSystem(0, [], []; name = :sys, systems = [sys1, sys2],
151152
constraints = [sys1.x + sys2.y ~ 2], checks = false)
152153
prob = OptimizationProblem(sys, [0.0, 0.0])
153-
154154
@test isequal(constraints(sys), vcat(sys1.x + sys2.y ~ 2, sys1.x ~ 1, sys2.y ~ 1))
155155
@test isequal(equations(sys), (sys1.x - 1)^2 + (sys2.y - 1 / 2)^2)
156156
@test isequal(states(sys), [sys1.x, sys2.y])
157+
158+
prob_ = remake(prob, u0=[1.0, 0.0], p=[2.0])
159+
@test isequal(prob_.u0, [1.0, 0.0])
160+
@test isequal(prob_.p, [2.0])
161+
162+
prob_ = remake(prob, u0=Dict(sys1.x => 1.0), p=Dict(sys1.a => 2.0))
163+
@test isequal(prob_.u0, [1.0, 0.0])
164+
@test isequal(prob_.p, [2.0])
157165
end
158166

159167
@testset "time dependent var" begin

0 commit comments

Comments
 (0)