Skip to content

Commit 0c83fdf

Browse files
Merge pull request #1972 from ValentinKaisermayer/vk-enable-remake-for-opt-problem
Enable symbolic remake for optimization problems
2 parents 48ac4f3 + 27826c5 commit 0c83fdf

File tree

5 files changed

+40
-16
lines changed

5 files changed

+40
-16
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ NaNMath = "0.3, 1"
7171
RecursiveArrayTools = "2.3"
7272
Reexport = "0.2, 1"
7373
RuntimeGeneratedFunctions = "0.4.3, 0.5"
74-
SciMLBase = "1.75.0"
74+
SciMLBase = "1.76.1"
7575
Setfield = "0.7, 0.8, 1"
7676
SimpleNonlinearSolve = "0.1.0"
7777
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"

src/variables.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,19 +119,16 @@ end
119119
throw(ArgumentError("$vars are missing from the variable map."))
120120
end
121121

122-
# FIXME: remove after: https://github.com/SciML/SciMLBase.jl/pull/311
123-
function SciMLBase.handle_varmap(varmap, sys::AbstractSystem; field = :states, kwargs...)
124-
out = varmap_to_vars(varmap, getfield(sys, field); kwargs...)
125-
return out
126-
end
127-
128122
"""
129123
$(SIGNATURES)
130124
131125
Intercept the call to `process_p_u0_symbolic` and process symbolic maps of `p` and/or `u0` if the
132126
user has `ModelingToolkit` loaded.
133127
"""
134-
function SciMLBase.process_p_u0_symbolic(prob::ODEProblem, p, u0)
128+
function SciMLBase.process_p_u0_symbolic(prob::Union{SciMLBase.AbstractDEProblem,
129+
NonlinearProblem, OptimizationProblem},
130+
p,
131+
u0)
135132
# check if a symbolic remake is possible
136133
if eltype(p) <: Pair
137134
hasproperty(prob.f, :sys) && hasfield(typeof(prob.f.sys), :ps) ||

test/nonlinearsystem.jl

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,7 @@ eq = [v1 ~ sin(2pi * t * h)
189189
@named sys = ODESystem(eq)
190190
@test length(equations(structural_simplify(sys))) == 0
191191

192-
#1504
193-
let
192+
@testset "Issue: 1504" begin
194193
@variables u[1:4]
195194

196195
eqs = [u[1] ~ 1,
@@ -213,3 +212,23 @@ eqs = [0 ~ a * x]
213212
testdict = Dict([:test => 1])
214213
@named sys = NonlinearSystem(eqs, [x], [a], metadata = testdict)
215214
@test get_metadata(sys) == testdict
215+
216+
@testset "Remake" begin
217+
@parameters a=1.0 b=1.0 c=1.0
218+
@constants h = 1
219+
@variables x y z
220+
221+
eqs = [0 ~ a * (y - x) * h,
222+
0 ~ x * (b - z) - y,
223+
0 ~ x * y - c * z]
224+
@named sys = NonlinearSystem(eqs, [x, y, z], [a, b, c], defaults = Dict(x => 2.0))
225+
prob = NonlinearProblem(sys, ones(length(states(sys))))
226+
227+
prob_ = remake(prob, u0 = [1.0, 2.0, 3.0], p = [1.1, 1.2, 1.3])
228+
@test prob_.u0 == [1.0, 2.0, 3.0]
229+
@test prob_.p == [1.1, 1.2, 1.3]
230+
231+
prob_ = remake(prob, u0 = Dict(y => 2.0), p = Dict(a => 2.0))
232+
@test prob_.u0 == [1.0, 2.0, 1.0]
233+
@test prob_.p == [2.0, 1.0, 1.0]
234+
end

test/odesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,8 @@ sol_dpmap = solve(prob_dpmap, Rodas5())
277277
prob = ODEProblem(sys, Pair[])
278278
prob_new = SciMLBase.remake(prob, p = Dict(sys1.a => 3.0, b => 4.0),
279279
u0 = Dict(sys1.x => 1.0))
280-
@test_broken prob_new.p == [4.0, 3.0, 1.0]
281-
@test_broken prob_new.u0 == [1.0, 0.0]
280+
@test prob_new.p == [4.0, 3.0, 1.0]
281+
@test prob_new.u0 == [1.0, 0.0]
282282
end
283283

284284
# test kwargs

test/optimizationsystem.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,23 +137,31 @@ end
137137
# nested constraints
138138
@testset "nested systems" begin
139139
@variables x y
140-
o1 = (x - 1)^2
140+
@parameters a = 1
141+
o1 = (x - a)^2
141142
o2 = (y - 1 / 2)^2
142143
c1 = [
143144
x ~ 1,
144145
]
145146
c2 = [
146147
y ~ 1,
147148
]
148-
sys1 = OptimizationSystem(o1, [x], [], name = :sys1, constraints = c1)
149+
sys1 = OptimizationSystem(o1, [x], [a], name = :sys1, constraints = c1)
149150
sys2 = OptimizationSystem(o2, [y], [], name = :sys2, constraints = c2)
150151
sys = OptimizationSystem(0, [], []; name = :sys, systems = [sys1, sys2],
151152
constraints = [sys1.x + sys2.y ~ 2], checks = false)
152153
prob = OptimizationProblem(sys, [0.0, 0.0])
153-
154154
@test isequal(constraints(sys), vcat(sys1.x + sys2.y ~ 2, sys1.x ~ 1, sys2.y ~ 1))
155-
@test isequal(equations(sys), (sys1.x - 1)^2 + (sys2.y - 1 / 2)^2)
155+
@test isequal(equations(sys), (sys1.x - sys1.a)^2 + (sys2.y - 1 / 2)^2)
156156
@test isequal(states(sys), [sys1.x, sys2.y])
157+
158+
prob_ = remake(prob, u0 = [1.0, 0.0], p = [2.0])
159+
@test isequal(prob_.u0, [1.0, 0.0])
160+
@test isequal(prob_.p, [2.0])
161+
162+
prob_ = remake(prob, u0 = Dict(sys1.x => 1.0), p = Dict(sys1.a => 2.0))
163+
@test isequal(prob_.u0, [1.0, 0.0])
164+
@test isequal(prob_.p, [2.0])
157165
end
158166

159167
@testset "time dependent var" begin

0 commit comments

Comments
 (0)