
Commit 526a8a9

Merge branch 'daenew' of https://github.com/ParasPuneetSingh/Optimization.jl into daenew
2 parents 7658f9b + 11bd665 commit 526a8a9

3 files changed: +35 additions, -125 deletions


lib/OptimizationODE/Project.toml

Lines changed: 4 additions & 0 deletions
@@ -6,10 +6,14 @@ version = "0.1.0"
 [deps]
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
 SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
+NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
+Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
+DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
 
 [compat]
 ForwardDiff = "1"
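For orientation, a minimal sketch of where the newly declared dependencies surface in this package (the import lines below are illustrative assumptions, not copied from the commit):

using LinearAlgebra: Diagonal      # mass matrix construction in solve_dae_mass_matrix
using Sundials: IDA                # solver wrapped by DAEIndexing()
using DifferentialEquations        # exercised by the test suite
using NonlinearSolve               # declared as a dependency of the DAE code paths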

lib/OptimizationODE/src/OptimizationODE.jl

Lines changed: 12 additions & 94 deletions
@@ -23,7 +23,7 @@ struct DAEOptimizer{T}
     solver::T
 end
 
-DAEMassMatrix() = DAEOptimizer(Rodas5())
+DAEMassMatrix() = DAEOptimizer(Rosenbrock23(autodiff = false))
 DAEIndexing() = DAEOptimizer(IDA())
 
 
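As a usage illustration (assuming both constructors are exported by OptimizationODE), the two DAE entry points after this change wrap the solvers shown above:

using OptimizationODE

opt_mass  = DAEMassMatrix()   # DAEOptimizer(Rosenbrock23(autodiff = false)), mass-matrix formulation
opt_index = DAEIndexing()     # DAEOptimizer(IDA()), index-based (Sundials) formulation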
@@ -63,32 +63,6 @@ function SciMLBase.__init(prob::OptimizationProblem, opt::DAEOptimizer;
 end
 
 
-function solve_constrained_root(cache, u0, p)
-    n = length(u0)
-    cons_vals = cache.f.cons(u0, p)
-    m = length(cons_vals)
-    function resid!(res, u)
-        temp = similar(u)
-        f_mass!(temp, u, p, 0.0)
-        res .= temp
-    end
-    u0_ext = vcat(u0, zeros(m))
-    prob_nl = NonlinearProblem(resid!, u0_ext, p)
-    sol_nl = solve(prob_nl, Newton(); tol = 1e-8, maxiters = 100000,
-        callback = cache.callback, progress = get(cache.solver_args, :progress, false))
-    u_ext = sol_nl.u
-    return u_ext[1:n], sol_nl.retcode
-end
-
-
-function get_solver_type(opt::DAEOptimizer)
-    if opt.solver isa Union{Rodas5, RadauIIA5, ImplicitEuler, Trapezoid}
-        return :mass_matrix
-    else
-        return :indexing
-    end
-end
-
 function handle_parameters(p)
     if p isa SciMLBase.NullParameters
         return Float64[]
@@ -110,45 +84,6 @@ function setup_progress_callback(cache, solve_kwargs)
     return solve_kwargs
 end
 
-function finite_difference_jacobian(f, x; ϵ = 1e-8)
-    n = length(x)
-    fx = f(x)
-    if fx === nothing
-        return zeros(eltype(x), 0, n)
-    elseif isa(fx, Number)
-        J = zeros(eltype(fx), 1, n)
-        for j in 1:n
-            xj = copy(x)
-            xj[j] += ϵ
-            diff = f(xj)
-            if diff === nothing
-                diffval = zero(eltype(fx))
-            else
-                diffval = diff - fx
-            end
-            J[1, j] = diffval / ϵ
-        end
-        return J
-    else
-        m = length(fx)
-        J = zeros(eltype(fx), m, n)
-        for j in 1:n
-            xj = copy(x)
-            xj[j] += ϵ
-            fxj = f(xj)
-            if fxj === nothing
-                @inbounds for i in 1:m
-                    J[i, j] = -fx[i] / ϵ
-                end
-            else
-                @inbounds for i in 1:m
-                    J[i, j] = (fxj[i] - fx[i]) / ϵ
-                end
-            end
-        end
-        return J
-    end
-end
 
 function SciMLBase.__solve(
     cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}
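The hand-rolled finite_difference_jacobian helper removed above is covered by ForwardDiff instead, mirroring the change in runtests.jl further down; a minimal sketch:

using ForwardDiff

f(x) = [x[1]^2 + x[2], x[1] * x[2]]
x = [1.0, 2.0]
J = ForwardDiff.jacobian(f, x)   # dense 2x2 Jacobian, replaces finite_difference_jacobian(f, x)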
@@ -163,8 +98,7 @@ function SciMLBase.__solve(
     if cache.opt isa ODEOptimizer
         return solve_ode(cache, dt, maxit, u0, p)
     else
-        solver_method = get_solver_type(cache.opt)
-        if solver_method == :mass_matrix
+        if cache.opt.solver == Rosenbrock23(autodiff = false)
             return solve_dae_mass_matrix(cache, dt, maxit, u0, p)
         else
             return solve_dae_indexing(cache, dt, maxit, u0, p, differential_vars)
@@ -240,8 +174,8 @@ function solve_dae_mass_matrix(cache, dt, maxit, u0, p)
     n = length(u0)
     m = length(cons_vals)
     u0_extended = vcat(u0, zeros(m))
-    M = zeros(n + m, n + m)
-    M[1:n, 1:n] = I(n)
+    M = Diagonal(ones(n + m))
+
 
     function f_mass!(du, u, p_, t)
         x = @view u[1:n]
@@ -253,31 +187,18 @@ function solve_dae_mass_matrix(cache, dt, maxit, u0, p)
             grad_f .= ForwardDiff.gradient(z -> cache.f.f(z, p_), x)
         end
         J = Matrix{eltype(x)}(undef, m, n)
-        if cache.f.cons_j !== nothing
-            cache.f.cons_j(J, x)
-        else
-            J .= finite_difference_jacobian(z -> cache.f.cons(z, p_), x)
-        end
+        cache.f.cons_j !== nothing && cache.f.cons_j(J, x)
+
         @. du[1:n] = -grad_f - (J' * λ)
         consv = cache.f.cons(x, p_)
-        if consv === nothing
-            fill!(du[n+1:end], zero(eltype(x)))
-        else
-            if isa(consv, Number)
-                @assert m == 1
-                du[n+1] = consv
-            else
-                @assert length(consv) == m
-                @. du[n+1:end] = consv
-            end
-        end
+        @. du[n+1:end] = consv
         return nothing
     end
 
     if m == 0
-        optf = ODEFunction(f_mass!, mass_matrix = I(n))
+        optf = ODEFunction(f_mass!)
         prob = ODEProblem(optf, u0, (0.0, 1.0), p)
-        return solve(prob, HighOrderDescent(); dt=dt, maxiters=maxit)
+        return solve(prob, cache.opt.solver; dt=dt, maxiters=maxit)
     end
 
     ss_prob = SteadyStateProblem(ODEFunction(f_mass!, mass_matrix = M), u0_extended, p)
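For reference, a hedged reading of what f_mass! encodes (an interpretation of the code above, not text from the commit): with J the constraint Jacobian and λ the multiplier block, a steady state of the extended system satisfies

    ∇f(x) + J(x)' * λ = 0    # stationarity of the Lagrangian
    cons(x) = 0              # primal feasibility

which are the KKT stationarity conditions that the SteadyStateProblem constructed just above is driven towards.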
@@ -327,11 +248,8 @@ function solve_dae_indexing(cache, dt, maxit, u0, p, differential_vars)
         grad_f = similar(x)
         cache.f.grad(grad_f, x, p_)
         J = zeros(m, n)
-        if cache.f.cons_j !== nothing
-            cache.f.cons_j(J, x)
-        else
-            J .= finite_difference_jacobian(z -> cache.f.cons(z,p_), x)
-        end
+        cache.f.cons_j !== nothing && cache.f.cons_j(J, x)
+
         @. res[1:n] = du_x + grad_f + J' * λ
         consv = cache.f.cons(x, p_)
         @. res[n+1:end] = consv
@@ -364,4 +282,4 @@ function solve_dae_indexing(cache, dt, maxit, u0, p, differential_vars)
     end
 
 
-end
+end
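Taken together, a minimal usage sketch consistent with the test suite below (the dt and maxiters keywords mirror the inner solve calls in this diff, and their forwarding from the top-level solve is an assumption; prob_constrained is a placeholder name):

using Optimization, OptimizationODE

# Unconstrained warm-up, mirroring the NullParameters test below.
optf = OptimizationFunction((x, p) -> sum(x .^ 2); grad = (g, x, p) -> (g .= 2.0 .* x))
prob = OptimizationProblem(optf, [0.0, 0.0], Float64[])
sol  = solve(prob, ODEGradientDescent(); dt = 0.01, maxiters = 10_000)

# The DAE optimizers are passed the same way; their OptimizationFunction is
# expected to also carry cons (and ideally cons_j), since solve_dae_mass_matrix
# and solve_dae_indexing query cache.f.cons and cache.f.cons_j directly.
# sol_dae = solve(prob_constrained, DAEMassMatrix(); dt = 0.01, maxiters = 10_000)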

lib/OptimizationODE/test/runtests.jl

Lines changed: 19 additions & 31 deletions
@@ -5,40 +5,40 @@ using LinearAlgebra, ForwardDiff
 using OrdinaryDiffEq, DifferentialEquations, SteadyStateDiffEq, Sundials
 
 # Test helper functions
-function rosenbrock(x, p,args...)
+function rosenbrock(x, p)
     return (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
 end
 
-function rosenbrock_grad!(grad, x, p,args...)
+function rosenbrock_grad!(grad, x, p)
     grad[1] = -2.0 * (p[1] - x[1]) - 4.0 * p[2] * (x[2] - x[1]^2) * x[1]
     grad[2] = 2.0 * p[2] * (x[2] - x[1]^2)
 end
 
-function quadratic(x, p,args...)
+function quadratic(x, p)
     return (x[1] - p[1])^2 + (x[2] - p[2])^2
 end
 
-function quadratic_grad!(grad, x, p,args...)
+function quadratic_grad!(grad, x, p)
     grad[1] = 2.0 * (x[1] - p[1])
     grad[2] = 2.0 * (x[2] - p[2])
 end
 
 # Constrained optimization problem
-function constrained_objective(x, p,args...)
+function constrained_objective(x, p)
     return x[1]^2 + x[2]^2
 end
 
-function constrained_objective_grad!(grad, x, p,args...)
+function constrained_objective_grad!(grad, x, p)
     grad[1] = 2.0 * x[1]
     grad[2] = 2.0 * x[2]
 end
 
-function constraint_func(res, x, p,args...)
+function constraint_func(res, x, p)
     res[1] = x[1] + x[2] - 1.0 # x[1] + x[2] = 1
     return x[1] + x[2] - 1.0
 end
 
-function constraint_jac!(jac, x, p,args...)
+function constraint_jac!(jac, x, p)
     jac[1, 1] = 1.0
     jac[1, 2] = -1.0
 end
@@ -102,21 +102,21 @@ end
     # Minimize f(x) = x₁² + x₂²
     # Subject to x₁ - x₂ = 1
 
-    function constrained_objective(x, p,args...)
+    function constrained_objective(x, p)
         return x[1]^2 + x[2]^2
     end
 
-    function constrained_objective_grad!(g, x, p, args...)
+    function constrained_objective_grad!(g, x, p)
         g .= 2 .* x .* p[1]
         return nothing
     end
 
     # Constraint: x₁ - x₂ - p[1] = 0 (p[1] = 1 → x₁ - x₂ = 1)
-    function constraint_func(x, p, args...)
+    function constraint_func(x, p)
         return x[1] - x[2] - p[1]
     end
 
-    function constraint_jac!(J, x,args...)
+    function constraint_jac!(J, x)
         J[1, 1] = 1.0
         J[1, 2] = -1.0
         return nothing
@@ -159,8 +159,8 @@ end
     x0 = [0.0, 0.0]
     p=Float64[] # No parameters provided
     # Create a problem with NullParameters
-    optf = OptimizationFunction((x, p, args...) -> sum(x.^2),
-                                grad=(grad, x, p, args...) -> (grad .= 2.0 .* x))
+    optf = OptimizationFunction((x, p) -> sum(x.^2),
+                                grad=(grad, x, p) -> (grad .= 2.0 .* x))
     prob = OptimizationProblem(optf, x0,p) # No parameters provided
 
     opt = ODEGradientDescent()
@@ -233,26 +233,14 @@ end
         x = [1.0, 2.0]
         f(x) = [x[1]^2 + x[2], x[1] * x[2]]
 
-        J = OptimizationODE.finite_difference_jacobian(f, x)
+        J = ForwardDiff.jacobian(f, x)
 
         expected_J = [2.0 1.0; 2.0 1.0]
 
         @test isapprox(J, expected_J, atol=1e-6)
     end
 end
-
-@testset "Solver Type Detection" begin
-    @testset "Mass Matrix Solvers" begin
-        opt = DAEMassMatrix()
-        @test OptimizationODE.get_solver_type(opt) == :mass_matrix
-    end
-
-    @testset "Index Method Solvers" begin
-        opt = DAEIndexing()
-        @test OptimizationODE.get_solver_type(opt) == :indexing
-    end
-end
-
+
 @testset "Error Handling and Edge Cases" begin
     @testset "Empty Constraints" begin
         x0 = [1.5, 0.5]
@@ -274,8 +262,8 @@ end
         x0 = [0.5]
         p = [1.0]
 
-        single_var_func(x, p,args...) = (x[1] - p[1])^2
-        single_var_grad!(grad, x, p,args...) = (grad[1] = 2.0 * (x[1] - p[1]))
+        single_var_func(x, p) = (x[1] - p[1])^2
+        single_var_grad!(grad, x, p) = (grad[1] = 2.0 * (x[1] - p[1]))
 
         optf = OptimizationFunction(single_var_func; grad=single_var_grad!)
         prob = OptimizationProblem(optf, x0, p)
@@ -287,4 +275,4 @@ end
         @test isapprox(sol.u[1], p[1], atol=1e-1)
     end
 end
-end
+end
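To run this suite locally, the standard Pkg entry points apply (the develop path assumes the working directory is the Optimization.jl repository root):

using Pkg
Pkg.develop(path = "lib/OptimizationODE")
Pkg.test("OptimizationODE")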
