Skip to content

Commit b5288ad

Browse files
fix: do not drop zeros when constructing prototype buffers
1 parent 3f1c176 commit b5288ad

File tree

7 files changed

+24
-24
lines changed

7 files changed

+24
-24
lines changed

lib/OptimizationBase/ext/OptimizationMTKExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function OptimizationBase.instantiate_function(
3131
hess = (H, θ, args...) -> f.hess(H, θ, mtkprob.p, args...)
3232

3333
hv = function (H, θ, v, args...)
34-
res = (eltype(θ)).(f.hess_prototype)
34+
res = similar(f.hess_prototype, eltype(θ))
3535
hess(res, θ, args...)
3636
H .= res * v
3737
end
@@ -83,7 +83,7 @@ function OptimizationBase.instantiate_function(
8383
hess = (H, θ, args...) -> f.hess(H, θ, mtkprob.p, args...)
8484

8585
hv = function (H, θ, v, args...)
86-
res = (eltype(θ)).(f.hess_prototype)
86+
res = similar(f.hess_prototype, eltype(θ))
8787
hess(res, θ, args...)
8888
H .= res * v
8989
end

lib/OptimizationBase/src/function.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ function OptimizationBase.instantiate_function(
5656
cons_vjp = f.cons_vjp === nothing ? nothing : (res, x) -> f.cons_vjp(res, x, p)
5757
cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, p)
5858
hess_prototype = f.hess_prototype === nothing ? nothing :
59-
convert.(eltype(x), f.hess_prototype)
59+
similar(f.hess_prototype, eltype(x))
6060
cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing :
61-
convert.(eltype(x), f.cons_jac_prototype)
61+
similar(f.cons_jac_prototype, eltype(x))
6262
cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing :
63-
[convert.(eltype(x), f.cons_hess_prototype[i])
63+
[similar(f.cons_hess_prototype[i], eltype(x))
6464
for i in 1:num_cons]
6565
expr = symbolify(f.expr)
6666
cons_expr = symbolify.(f.cons_expr)
@@ -90,11 +90,11 @@ function OptimizationBase.instantiate_function(
9090
cons_vjp = f.cons_vjp === nothing ? nothing : (res, x) -> f.cons_vjp(res, x, cache.p)
9191
cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, cache.p)
9292
hess_prototype = f.hess_prototype === nothing ? nothing :
93-
convert.(eltype(cache.u0), f.hess_prototype)
93+
similar(f.hess_prototype, eltype(cache.u0))
9494
cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing :
95-
convert.(eltype(cache.u0), f.cons_jac_prototype)
95+
similar(f.cons_jac_prototype, eltype(cache.u0))
9696
cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing :
97-
[convert.(eltype(cache.u0), f.cons_hess_prototype[i])
97+
[similar(f.cons_hess_prototype[i], eltype(cache.u0))
9898
for i in 1:num_cons]
9999
expr = symbolify(f.expr)
100100
cons_expr = symbolify.(f.cons_expr)
@@ -196,11 +196,11 @@ function OptimizationBase.instantiate_function(
196196
end
197197
end
198198
hess_prototype = f.hess_prototype === nothing ? nothing :
199-
convert.(eltype(x), f.hess_prototype)
199+
similar(f.hess_prototype, eltype(x))
200200
cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing :
201-
convert.(eltype(x), f.cons_jac_prototype)
201+
similar(f.cons_jac_prototype, eltype(x))
202202
cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing :
203-
[convert.(eltype(x), f.cons_hess_prototype[i])
203+
[similar(f.cons_hess_prototype[i], eltype(x))
204204
for i in 1:num_cons]
205205
expr = symbolify(f.expr)
206206
cons_expr = symbolify.(f.cons_expr)

lib/OptimizationIpopt/src/cache.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,13 @@ function IpoptCache(prob, opt;
9494
J = if isnothing(f.cons_jac_prototype)
9595
zeros(T, num_cons, n)
9696
else
97-
convert.(T, f.cons_jac_prototype)
97+
similar(f.cons_jac_prototype, T)
9898
end
9999
lagh = !isnothing(f.lag_hess_prototype)
100100
H = if lagh # lag hessian takes precedence
101-
convert.(T, f.lag_hess_prototype)
101+
similar(f.lag_hess_prototype, T)
102102
elseif !isnothing(f.hess_prototype)
103-
convert.(T, f.hess_prototype)
103+
similar(f.hess_prototype, T)
104104
else
105105
zeros(T, n, n)
106106
end
@@ -109,7 +109,7 @@ function IpoptCache(prob, opt;
109109
elseif isnothing(f.cons_hess_prototype)
110110
Matrix{T}[zeros(T, n, n) for i in 1:num_cons]
111111
else
112-
[convert.(T, f.cons_hess_prototype[i]) for i in 1:num_cons]
112+
[similar(f.cons_hess_prototype[i], T) for i in 1:num_cons]
113113
end
114114
lcons = prob.lcons === nothing ? fill(T(-Inf), num_cons) : prob.lcons
115115
ucons = prob.ucons === nothing ? fill(T(Inf), num_cons) : prob.ucons

lib/OptimizationMOI/src/nlp.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,17 +123,17 @@ function MOIOptimizationNLPCache(prob::OptimizationProblem,
123123
end
124124
T = eltype(prob.u0)
125125
n = length(prob.u0)
126-
127126
J = if isnothing(f.cons_jac_prototype)
128127
zeros(T, num_cons, n)
129128
else
130-
convert.(T, f.cons_jac_prototype)
129+
similar(f.cons_jac_prototype, T)
131130
end
132131
lagh = !isnothing(f.lag_hess_prototype)
132+
133133
H = if lagh # lag hessian takes precedence
134-
convert.(T, f.lag_hess_prototype)
134+
similar(f.lag_hess_prototype, T)
135135
elseif !isnothing(f.hess_prototype)
136-
convert.(T, f.hess_prototype)
136+
similar(f.hess_prototype, T)
137137
else
138138
zeros(T, n, n)
139139
end
@@ -142,7 +142,7 @@ function MOIOptimizationNLPCache(prob::OptimizationProblem,
142142
elseif isnothing(f.cons_hess_prototype)
143143
Matrix{T}[zeros(T, n, n) for i in 1:num_cons]
144144
else
145-
[convert.(T, f.cons_hess_prototype[i]) for i in 1:num_cons]
145+
[similar(f.cons_hess_prototype[i], T) for i in 1:num_cons]
146146
end
147147
lcons = prob.lcons === nothing ? fill(T(-Inf), num_cons) : prob.lcons
148148
ucons = prob.ucons === nothing ? fill(T(Inf), num_cons) : prob.ucons

lib/OptimizationOptimJL/src/OptimizationOptimJL.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
213213
isnothing(cache.f.hess_prototype) ?
214214
Optim.NLSolversBase.alloc_H(cache.u0,
215215
real(zero(u0_type))) :
216-
convert.(u0_type, cache.f.hess_prototype))
216+
similar(cache.f.hess_prototype, u0_type))
217217
end
218218

219219
opt_args = __map_optimizer_args(cache, cache.opt, callback = _cb,
@@ -414,7 +414,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
414414
isnothing(cache.f.hess_prototype) ?
415415
Optim.NLSolversBase.alloc_H(cache.u0,
416416
real(zero(u0_type))) :
417-
convert.(u0_type, cache.f.hess_prototype))
417+
similar(cache.f.hess_prototype, u0_type))
418418
else
419419
Optim.OnceDifferentiable(_loss, gg, fg!, cache.u0,
420420
real(zero(u0_type)),

src/auglag.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
128128
function aug_grad(G, θ, p)
129129
cache.f.grad(G, θ, p)
130130
if !isnothing(cache.f.cons_jac_prototype)
131-
J = Float64.(cache.f.cons_jac_prototype)
131+
J = similar(cache.f.cons_jac_prototype, Float64)
132132
else
133133
J = zeros((length(cache.lcons), length(θ)))
134134
end

src/lbfgsb.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
148148
function aug_grad(G, θ)
149149
cache.f.grad(G, θ)
150150
if !isnothing(cache.f.cons_jac_prototype)
151-
J = Float64.(cache.f.cons_jac_prototype)
151+
J = similar(cache.f.cons_jac_prototype, Float64)
152152
else
153153
J = zeros((length(cache.lcons), length(θ)))
154154
end

0 commit comments

Comments
 (0)