Commit 1afe6c2

Remove num_cons from OptimizationFunction
1 parent 8b2bcab commit 1afe6c2

File tree

3 files changed: 20 additions & 21 deletions


src/function.jl

Lines changed: 17 additions & 18 deletions
@@ -22,7 +22,7 @@ function default_chunk_size(len)
     end
 end

-function instantiate_function(f, x, ::AbstractADType, p)
+function instantiate_function(f, x, ::AbstractADType, p, num_cons = 0)
     grad = f.grad === nothing ? nothing : (G,x)->f.grad(G,x,p)
     hess = f.hess === nothing ? nothing : (H,x)->f.hess(H,x,p)
     hv = f.hv === nothing ? nothing : (H,x,v)->f.hv(H,x,v,p)
@@ -34,10 +34,10 @@ function instantiate_function(f, x, ::AbstractADType, p)
                                 typeof(hess),typeof(hv),typeof(cons),
                                 typeof(cons_j),typeof(cons_h)}(f.f,
                                 DiffEqBase.NoAD(),grad,hess,hv,cons,
-                                cons_j,cons_h,f.num_cons)
+                                cons_j,cons_h)
 end

-function instantiate_function(f, x, ::AutoForwardDiff{_chunksize}, p) where _chunksize
+function instantiate_function(f, x, ::AutoForwardDiff{_chunksize}, p, num_cons = 0) where _chunksize

     chunksize = _chunksize === nothing ? default_chunk_size(length(x)) : _chunksize

@@ -86,7 +86,7 @@ function instantiate_function(f, x, ::AutoForwardDiff{_chunksize}, p) where _chu

     if cons !== nothing && f.cons_h === nothing
         cons_h = function (res, θ)
-            for i in 1:f.num_cons
+            for i in 1:num_cons
                 hess_config_cache = ForwardDiff.HessianConfig(x -> cons(x)[i], θ,ForwardDiff.Chunk{chunksize}())
                 ForwardDiff.hessian!(res[i], (x) -> cons(x)[i], θ, hess_config_cache,Val{false}())
             end
@@ -95,11 +95,11 @@ function instantiate_function(f, x, ::AutoForwardDiff{_chunksize}, p) where _chu
         cons_h = f.cons_h
     end

-    return OptimizationFunction{true,AutoForwardDiff,typeof(f.f),typeof(grad),typeof(hess),typeof(hv),typeof(cons),typeof(cons_j),typeof(cons_h)}(f.f,AutoForwardDiff(),grad,hess,hv,cons,cons_j,cons_h,f.num_cons)
+    return OptimizationFunction{true,AutoForwardDiff,typeof(f.f),typeof(grad),typeof(hess),typeof(hv),typeof(cons),typeof(cons_j),typeof(cons_h)}(f.f,AutoForwardDiff(),grad,hess,hv,cons,cons_j,cons_h)
 end

-function instantiate_function(f, x, ::AutoZygote, p)
-    f.num_cons != 0 && error("AutoZygote does not currently support constraints")
+function instantiate_function(f, x, ::AutoZygote, p, num_cons = 0)
+    num_cons != 0 && error("AutoZygote does not currently support constraints")

     _f = θ -> f(θ,p)[1]
     if f.grad === nothing
@@ -135,11 +135,11 @@ function instantiate_function(f, x, ::AutoZygote, p)
         hv = f.hv
     end

-    return OptimizationFunction{false,AutoZygote,typeof(f),typeof(grad),typeof(hess),typeof(hv),Nothing,Nothing,Nothing}(f,AutoZygote(),grad,hess,hv,nothing,nothing,nothing,0)
+    return OptimizationFunction{false,AutoZygote,typeof(f),typeof(grad),typeof(hess),typeof(hv),Nothing,Nothing,Nothing}(f,AutoZygote(),grad,hess,hv,nothing,nothing,nothing)
 end

-function instantiate_function(f, x, ::AutoReverseDiff, p=DiffEqBase.NullParameters())
-    f.num_cons != 0 && error("AutoReverseDiff does not currently support constraints")
+function instantiate_function(f, x, ::AutoReverseDiff, p=DiffEqBase.NullParameters(), num_cons = 0)
+    num_cons != 0 && error("AutoReverseDiff does not currently support constraints")

     _f = θ -> f.f(θ,p)[1]

@@ -177,12 +177,12 @@ function instantiate_function(f, x, ::AutoReverseDiff, p=DiffEqBase.NullParamete
         hv = f.hv
     end

-    return OptimizationFunction{false,AutoReverseDiff,typeof(f),typeof(grad),typeof(hess),typeof(hv),Nothing,Nothing,Nothing}(f,AutoReverseDiff(),grad,hess,hv,nothing,nothing,nothing,0)
+    return OptimizationFunction{false,AutoReverseDiff,typeof(f),typeof(grad),typeof(hess),typeof(hv),Nothing,Nothing,Nothing}(f,AutoReverseDiff(),grad,hess,hv,nothing,nothing,nothing)
 end


-function instantiate_function(f, x, ::AutoTracker, p)
-    f.num_cons != 0 && error("AutoTracker does not currently support constraints")
+function instantiate_function(f, x, ::AutoTracker, p, num_cons = 0)
+    num_cons != 0 && error("AutoTracker does not currently support constraints")
     _f = θ -> f.f(θ,p)[1]

     if f.grad === nothing
@@ -204,12 +204,11 @@ function instantiate_function(f, x, ::AutoTracker, p)
     end


-    return OptimizationFunction{false,AutoTracker,typeof(f),typeof(grad),typeof(hess),typeof(hv),Nothing,Nothing,Nothing}(f,AutoTracker(),grad,hess,hv,nothing,nothing,nothing,0)
+    return OptimizationFunction{false,AutoTracker,typeof(f),typeof(grad),typeof(hess),typeof(hv),Nothing,Nothing,Nothing}(f,AutoTracker(),grad,hess,hv,nothing,nothing,nothing)
 end

-function instantiate_function(f, x, adtype::AutoFiniteDiff, p)
-
-    f.num_cons != 0 && error("AutoFiniteDiff does not currently support constraints")
+function instantiate_function(f, x, adtype::AutoFiniteDiff, p, num_cons = 0)
+    num_cons != 0 && error("AutoFiniteDiff does not currently support constraints")
     _f = θ -> f.f(θ,p)[1]

     if f.grad === nothing
@@ -234,5 +233,5 @@ function instantiate_function(f, x, adtype::AutoFiniteDiff, p)
         hv = f.hv
     end

-    return OptimizationFunction{false,AutoFiniteDiff,typeof(f),typeof(grad),typeof(hess),typeof(hv),Nothing,Nothing,Nothing}(f,adtype,grad,hess,hv,nothing,nothing,nothing,0)
+    return OptimizationFunction{false,AutoFiniteDiff,typeof(f),typeof(grad),typeof(hess),typeof(hv),Nothing,Nothing,Nothing}(f,adtype,grad,hess,hv,nothing,nothing,nothing)
 end
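
With this change, every instantiate_function backend method takes the number of constraints as a trailing positional argument (defaulting to 0) instead of reading it from f.num_cons. A minimal sketch of a direct call to this internal API under the new signature; the rosenbrock objective, the cons closure, and the use of nothing as the parameter object are illustrative assumptions, not part of the commit:

using GalacticOptim

rosenbrock(x, p) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2   # example objective
cons = (x, p) -> [x[1]^2 + x[2]^2]                          # a single constraint

optf = OptimizationFunction(rosenbrock, GalacticOptim.AutoForwardDiff(); cons = cons)
# the constraint count is now supplied by the caller (1, matching cons above)
f = GalacticOptim.instantiate_function(optf, zeros(2), GalacticOptim.AutoForwardDiff(), nothing, 1)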

src/solve.jl

Lines changed: 1 addition & 1 deletion
@@ -215,7 +215,7 @@ function __solve(prob::OptimizationProblem, opt::Optim.ConstrainedOptimizer;cb =
         cb_call
     end

-    f = instantiate_function(prob.f,prob.u0,prob.f.adtype,prob.p)
+    f = instantiate_function(prob.f,prob.u0,prob.f.adtype,prob.p,prob.ucons === nothing ? 0 : length(prob.ucons))

     f.cons_j ===nothing && error("Use OptimizationFunction to pass the derivatives or automatically generate them with one of the autodiff backends")
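
On the solver side, the constraint count is now derived from the bounds attached to the problem rather than from the function object. A small illustration of that inference, assuming the two-constraint setup used in the test file below:

prob = OptimizationProblem(optprob, x0, lcons = [-Inf, -Inf], ucons = [Inf, Inf])
num_cons = prob.ucons === nothing ? 0 : length(prob.ucons)   # 2 here; 0 when no ucons is given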

test/rosenbrock.jl

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@ sol = solve(prob, NelderMead())
 @test 10*sol.minimum < l1

 cons= (x,p) -> [x[1]^2 + x[2]^2]
-optprob = OptimizationFunction(rosenbrock, GalacticOptim.AutoForwardDiff();cons= cons, num_cons = 1)
+optprob = OptimizationFunction(rosenbrock, GalacticOptim.AutoForwardDiff();cons= cons)

 prob = OptimizationProblem(optprob, x0)
 sol = solve(prob, BFGS())
@@ -54,7 +54,7 @@ function con2_c(x,p)
     [x[1]^2 + x[2]^2, x[2]*sin(x[1])-x[1]]
 end

-optprob = OptimizationFunction(rosenbrock, GalacticOptim.AutoForwardDiff();cons= con2_c, num_cons = 2)
+optprob = OptimizationFunction(rosenbrock, GalacticOptim.AutoForwardDiff();cons= con2_c)
 prob = OptimizationProblem(optprob, x0, lcons = [-Inf,-Inf], ucons = [Inf,Inf])
 sol = solve(prob, IPNewton())
 @test 10*sol.minimum < l1
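
Since num_cons defaults to 0, unconstrained usage is unaffected; only the keyword argument on OptimizationFunction goes away. A short sketch of the unconstrained pattern, assuming rosenbrock, x0, and the Optim solvers are already in scope as earlier in this test file:

optprob = OptimizationFunction(rosenbrock, GalacticOptim.AutoForwardDiff())
prob = OptimizationProblem(optprob, x0)
sol = solve(prob, BFGS())   # no constraints, so the default num_cons = 0 applies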
