Skip to content

Commit 2fa9b52

Browse files
authored
Merge pull request #1926 from SciML/myb/opt
Add structural_simplify support for OptimizationSystem
2 parents 9de96de + ccce0c3 commit 2fa9b52

File tree

3 files changed

+46
-13
lines changed

3 files changed

+46
-13
lines changed

src/systems/abstractsystem.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,15 +1031,17 @@ This will convert all `inputs` to parameters and allow them to be unconnected, i
10311031
simplification will allow models where `n_states = n_equations - n_inputs`.
10321032
"""
10331033
function structural_simplify(sys::AbstractSystem, io = nothing; simplify = false,
1034-
simplify_constants = true, kwargs...)
1034+
simplify_constants = true, check_consistency = true, kwargs...)
10351035
sys = expand_connections(sys)
10361036
sys isa DiscreteSystem && return sys
10371037
state = TearingState(sys)
10381038
has_io = io !== nothing
10391039
has_io && markio!(state, io...)
10401040
state, input_idxs = inputs_to_parameters!(state, io)
10411041
sys, ag = alias_elimination!(state; kwargs...)
1042-
check_consistency(state, ag)
1042+
if check_consistency
1043+
ModelingToolkit.check_consistency(state, ag)
1044+
end
10431045
sys = dummy_derivative(sys, state, ag; simplify)
10441046
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
10451047
@set! sys.observed = topsort_equations(observed(sys), fullstates)

src/systems/optimization/optimizationsystem.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,3 +543,32 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0,
543543
end
544544
end
545545
end
546+
547+
function structural_simplify(sys::OptimizationSystem; kwargs...)
548+
sys = flatten(sys)
549+
cons = constraints(sys)
550+
econs = Equation[]
551+
icons = similar(cons, 0)
552+
for e in cons
553+
if e isa Equation
554+
push!(econs, e)
555+
else
556+
push!(icons, e)
557+
end
558+
end
559+
nlsys = NonlinearSystem(econs, states(sys), parameters(sys); name = :___tmp_nlsystem)
560+
snlsys = structural_simplify(nlsys; check_consistency = false, kwargs...)
561+
obs = observed(snlsys)
562+
subs = Dict(eq.lhs => eq.rhs for eq in observed(snlsys))
563+
seqs = equations(snlsys)
564+
sizehint!(icons, length(icons) + length(seqs))
565+
for eq in seqs
566+
push!(icons, substitute(eq, subs))
567+
end
568+
newsts = setdiff(states(sys), keys(subs))
569+
@set! sys.constraints = icons
570+
@set! sys.observed = [observed(sys); obs]
571+
@set! sys.op = substitute(equations(sys), subs)
572+
@set! sys.states = newsts
573+
return sys
574+
end

test/optimizationsystem.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,25 +69,27 @@ end
6969
end
7070

7171
@testset "equality constraint" begin
72-
@variables x y
72+
@variables x y z
7373
@parameters a b
74-
loss = (a - x)^2 + b * (y - x^2)^2
75-
cons = [1.0 ~ x^2 + y^2]
76-
@named sys = OptimizationSystem(loss, [x, y], [a, b], constraints = cons)
77-
prob = OptimizationProblem(sys, [x => 0.0, y => 0.0], [a => 1.0, b => 1.0],
74+
loss = (a - x)^2 + b * z^2
75+
cons = [1.0 ~ x^2 + y^2
76+
z ~ y - x^2]
77+
@named sys = OptimizationSystem(loss, [x, y, z], [a, b], constraints = cons)
78+
sys = structural_simplify(sys)
79+
prob = OptimizationProblem(sys, [x => 0.0, y => 0.0, z => 0.0], [a => 1.0, b => 1.0],
7880
grad = true, hess = true)
7981
sol = solve(prob, IPNewton())
8082
@test sol.minimum < 1.0
81-
@test sol.u[0.808, 0.589] atol=1e-3
82-
@test sol[x]^2 + sol[y]^2 1.0
83+
@test sol.u[0.808, -0.064] atol=1e-3
84+
@test_broken sol[x]^2 + sol[y]^2 1.0
8385
sol = solve(prob, Ipopt.Optimizer(); print_level = 0)
8486
@test sol.minimum < 1.0
85-
@test sol.u[0.808, 0.589] atol=1e-3
86-
@test sol[x]^2 + sol[y]^2 1.0
87+
@test sol.u[0.808, -0.064] atol=1e-3
88+
@test_broken sol[x]^2 + sol[y]^2 1.0
8789
sol = solve(prob, AmplNLWriter.Optimizer(Ipopt_jll.amplexe))
8890
@test sol.minimum < 1.0
89-
@test sol.u[0.808, 0.589] atol=1e-3
90-
@test sol[x]^2 + sol[y]^2 1.0
91+
@test sol.u[0.808, -0.064] atol=1e-3
92+
@test_broken sol[x]^2 + sol[y]^2 1.0
9193
end
9294

9395
@testset "rosenbrock" begin

0 commit comments

Comments
 (0)