Skip to content

Commit f118131

Browse files
Fix DAE formulation
1 parent 4868d6a commit f118131

File tree

2 files changed

+64
-85
lines changed

2 files changed

+64
-85
lines changed

lib/OptimizationODE/src/OptimizationODE.jl

Lines changed: 56 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ end
5656

5757
function SciMLBase.__init(prob::OptimizationProblem, opt::DAEOptimizer;
5858
callback=Optimization.DEFAULT_CALLBACK, progress=false, dt=nothing,
59-
maxiters=nothing, differential_vars=nothing, kwargs...)
59+
maxiters=nothing, kwargs...)
6060
return OptimizationCache(prob, opt; callback=callback, progress=progress, dt=dt,
61-
maxiters=maxiters, differential_vars=differential_vars, kwargs...)
61+
maxiters=maxiters, kwargs...)
6262
end
6363

6464
function SciMLBase.__solve(
@@ -67,15 +67,14 @@ function SciMLBase.__solve(
6767

6868
dt = get(cache.solver_args, :dt, nothing)
6969
maxit = get(cache.solver_args, :maxiters, nothing)
70-
differential_vars = get(cache.solver_args, :differential_vars, nothing)
7170
u0 = copy(cache.u0)
7271
p = cache.p # Properly handle NullParameters
7372

7473
if cache.opt isa ODEOptimizer
7574
return solve_ode(cache, dt, maxit, u0, p)
7675
else
7776
if cache.opt.solver isa SciMLBase.AbstractDAEAlgorithm
78-
return solve_dae_implicit(cache, dt, maxit, u0, p, differential_vars)
77+
return solve_dae_implicit(cache, dt, maxit, u0, p)
7978
else
8079
return solve_dae_mass_matrix(cache, dt, maxit, u0, p)
8180
end
@@ -139,43 +138,39 @@ end
139138

140139
function solve_dae_mass_matrix(cache, dt, maxit, u0, p)
141140
if cache.f.cons === nothing
142-
return solve_ode(cache, dt, maxit, u0, p)
141+
error("DAEOptimizer requires constraints. Please provide a function with `cons` defined.")
143142
end
144-
x=u0
145-
cons_vals = cache.f.cons(x, p)
146143
n = length(u0)
147-
m = length(cons_vals)
148-
u0_extended = vcat(u0, zeros(m))
149-
M = Diagonal(ones(n + m))
150-
144+
m = length(cache.ucons)
151145

146+
if m > n
147+
error("DAEOptimizer with mass matrix method requires the number of constraints to be less than or equal to the number of variables.")
148+
end
149+
M = Diagonal([ones(n-m); zeros(m)])
152150
function f_mass!(du, u, p_, t)
153-
x = @view u[1:n]
154-
λ = @view u[n+1:end]
155-
grad_f = similar(x)
156-
if cache.f.grad !== nothing
157-
cache.f.grad(grad_f, x, p_)
158-
else
159-
grad_f .= ForwardDiff.gradient(z -> cache.f.f(z, p_), x)
160-
end
161-
J = Matrix{eltype(x)}(undef, m, n)
162-
cache.f.cons_j !== nothing && cache.f.cons_j(J, x)
163-
164-
@. du[1:n] = -grad_f - (J' * λ)
165-
consv = cache.f.cons(x, p_)
166-
@. du[n+1:end] = consv
151+
cache.f.grad(du, u, p)
152+
@. du = -du
153+
consout = @view du[(n-m)+1:end]
154+
cache.f.cons(consout, u)
167155
return nothing
168156
end
169157

170-
if m == 0
171-
optf = ODEFunction(f_mass!)
172-
prob = ODEProblem(optf, u0, (0.0, 1.0), p)
173-
return solve(prob, cache.opt.solver; dt=dt, maxiters=maxit)
174-
end
175-
176-
ss_prob = SteadyStateProblem(ODEFunction(f_mass!, mass_matrix = M), u0_extended, p)
158+
ss_prob = SteadyStateProblem(ODEFunction(f_mass!, mass_matrix = M), u0, p)
177159

178-
solve_kwargs = setup_progress_callback(cache, Dict())
160+
if cache.callback !== Optimization.DEFAULT_CALLBACK
161+
condition = (u, t, integrator) -> true
162+
affect! = (integrator) -> begin
163+
u_opt = integrator.u isa AbstractArray ? integrator.u : integrator.u.u
164+
l = cache.f(integrator.u, integrator.p)
165+
cache.callback(integrator.u, l)
166+
end
167+
cb = DiscreteCallback(condition, affect!)
168+
solve_kwargs = Dict{Symbol, Any}(:callback => cb)
169+
else
170+
solve_kwargs = Dict{Symbol, Any}()
171+
end
172+
173+
solve_kwargs[:progress] = cache.progress
179174
if maxit !== nothing; solve_kwargs[:maxiters] = maxit; end
180175
if dt !== nothing; solve_kwargs[:dt] = dt; end
181176

@@ -189,61 +184,48 @@ function solve_dae_mass_matrix(cache, dt, maxit, u0, p)
189184
retcode = sol.retcode)
190185
end
191186

192-
193-
function solve_dae_implicit(cache, dt, maxit, u0, p, differential_vars)
187+
function solve_dae_implicit(cache, dt, maxit, u0, p)
194188
if cache.f.cons === nothing
195-
return solve_ode(cache, dt, maxit, u0, p)
189+
error("DAEOptimizer requires constraints. Please provide a function with `cons` defined.")
196190
end
197-
x=u0
198-
cons_vals = cache.f.cons(x, p)
191+
199192
n = length(u0)
200-
m = length(cons_vals)
201-
u0_ext = vcat(u0, zeros(m))
202-
du0_ext = zeros(n + m)
193+
m = length(cache.ucons)
203194

204-
if differential_vars === nothing
205-
differential_vars = vcat(fill(true, n), fill(false, m))
206-
else
207-
if length(differential_vars) == n
208-
differential_vars = vcat(differential_vars, fill(false, m))
209-
elseif length(differential_vars) == n + m
210-
# use as is
211-
else
212-
error("differential_vars length must be number of variables ($n) or extended size ($(n+m))")
213-
end
195+
if m > n
196+
error("DAEOptimizer with mass matrix method requires the number of constraints to be less than or equal to the number of variables.")
214197
end
215198

216199
function dae_residual!(res, du, u, p_, t)
217-
x = @view u[1:n]
218-
λ = @view u[n+1:end]
219-
du_x = @view du[1:n]
220-
grad_f = similar(x)
221-
cache.f.grad(grad_f, x, p_)
222-
J = zeros(m, n)
223-
cache.f.cons_j !== nothing && cache.f.cons_j(J, x)
224-
225-
@. res[1:n] = du_x + grad_f + J' * λ
226-
consv = cache.f.cons(x, p_)
227-
@. res[n+1:end] = consv
200+
cache.f.grad(res, u, p)
201+
@. res = du-res
202+
consout = @view res[(n-m)+1:end]
203+
cache.f.cons(consout, u)
228204
return nothing
229205
end
230206

231-
if m == 0
232-
optf = ODEFunction(dae_residual!, differential_vars = differential_vars)
233-
prob = ODEProblem(optf, du0_ext, (0.0, 1.0), p)
234-
return solve(prob, HighOrderDescent(); dt=dt, maxiters=maxit)
235-
end
236-
237207
tspan = (0.0, 10.0)
238-
prob = DAEProblem(dae_residual!, du0_ext, u0_ext, tspan, p;
239-
differential_vars = differential_vars)
208+
du0 = zero(u0)
209+
prob = DAEProblem(dae_residual!, du0, u0, tspan, p)
210+
211+
if cache.callback !== Optimization.DEFAULT_CALLBACK
212+
condition = (u, t, integrator) -> true
213+
affect! = (integrator) -> begin
214+
u_opt = integrator.u isa AbstractArray ? integrator.u : integrator.u.u
215+
l = cache.f(integrator.u, integrator.p)
216+
cache.callback(integrator.u, l)
217+
end
218+
cb = DiscreteCallback(condition, affect!)
219+
solve_kwargs = Dict{Symbol, Any}(:callback => cb)
220+
else
221+
solve_kwargs = Dict{Symbol, Any}()
222+
end
223+
224+
solve_kwargs[:progress] = cache.progress
240225

241-
solve_kwargs = setup_progress_callback(cache, Dict())
242226
if maxit !== nothing; solve_kwargs[:maxiters] = maxit; end
243227
if dt !== nothing; solve_kwargs[:dt] = dt; end
244-
if hasfield(typeof(cache.opt.solver), :initializealg)
245-
solve_kwargs[:initializealg] = BrownFullBasicInit()
246-
end
228+
solve_kwargs[:initializealg] = ShampineCollocationInit()
247229

248230
sol = solve(prob, cache.opt.solver; solve_kwargs...)
249231
u_ext = sol.u

lib/OptimizationODE/test/runtests.jl

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ end
3535

3636
function constraint_func(res, x, p)
3737
res[1] = x[1] + x[2] - 1.0 # x[1] + x[2] = 1
38-
return x[1] + x[2] - 1.0
3938
end
4039

4140
function constraint_jac!(jac, x, p)
@@ -107,7 +106,7 @@ end
107106
end
108107

109108
function constrained_objective_grad!(g, x, p)
110-
g .= 2 .* x .* p[1]
109+
g .= 2 .* x
111110
return nothing
112111
end
113112

@@ -122,7 +121,7 @@ end
122121
return nothing
123122
end
124123

125-
x0 = [1.0, 0.0] # reasonable initial guess
124+
x0 = [1.0, 1.0] # reasonable initial guess
126125
p = [1.0] # enforce x₁ - x₂ = 1
127126

128127
optf = OptimizationFunction(constrained_objective;
@@ -131,21 +130,19 @@ end
131130
cons_j = constraint_jac!)
132131

133132
@testset "Equality Constrained - Mass Matrix Method" begin
134-
prob = OptimizationProblem(optf, x0, p)
133+
prob = OptimizationProblem(optf, x0, p, lcons = [-10.0], ucons = [10.0])
135134
opt = DAEMassMatrix()
136135
sol = solve(prob, opt; dt=0.01, maxiters=1_000_000)
137136

138137
@test sol.retcode == ReturnCode.Success || sol.retcode == ReturnCode.Default
139-
@test isapprox(sol.u[1] - sol.u[2], 1.0; atol = 1e-2)
140-
@test isapprox(sol.u, [0.5, -0.5]; atol = 1e-2)
138+
@test isapprox(sol.u[1] + sol.u[2], 1.0; atol = 1e-2)
139+
@test_broken isapprox(sol.u, [0.5, 0.5]; atol = 1e-2)
141140
end
142141

143-
@testset "Equality Constrained - Index Method" begin
144-
prob = OptimizationProblem(optf, x0, p)
142+
@testset "Equality Constrained - Fully Implicit Method" begin
143+
prob = OptimizationProblem(optf, x0, p, lcons = [-10.0], ucons = [10.0])
145144
opt = DAEOptimizer(IDA())
146-
differential_vars = [true, true, false] # x vars = differential, λ = algebraic
147-
sol = solve(prob, opt; dt=0.01, maxiters=1_000_000,
148-
differential_vars = differential_vars)
145+
sol = solve(prob, opt; dt=0.01, maxiters=1_000_000)
149146

150147
@test sol.retcode == ReturnCode.Success || sol.retcode == ReturnCode.Default
151148
@test isapprox(sol.u[1] - sol.u[2], 1.0; atol = 1e-2)

0 commit comments

Comments
 (0)