Skip to content

Commit d7ef835

Browse files
Merge pull request #1069 from SebastianM-C/bugfix
Bugfix
2 parents ba726da + 19a3806 commit d7ef835

File tree

4 files changed

+60
-4
lines changed

4 files changed

+60
-4
lines changed

lib/OptimizationBase/ext/OptimizationZygoteExt.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ function OptimizationBase.instantiate_function(
208208
lag_extras = prepare_hessian(
209209
lagrangian, soadtype, x, Constant(one(eltype(x))),
210210
Constant(ones(eltype(x), num_cons)), Constant(p), strict = Val(false))
211-
lag_hess_prototype = zeros(Bool, num_cons, length(x))
211+
lag_hess_prototype = zeros(Bool, length(x), length(x))
212212

213213
function lag_h!(H::AbstractMatrix, θ, σ, λ)
214214
if σ == zero(eltype(θ))
@@ -288,6 +288,18 @@ function OptimizationBase.instantiate_function(
288288
f, x, adtype, p, num_cons; kwargs...)
289289
end
290290

291+
function OptimizationBase.instantiate_function(
292+
f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
293+
adtype::DifferentiationInterface.SecondOrder{
294+
<:ADTypes.AbstractADType, <:ADTypes.AutoZygote},
295+
num_cons = 0; kwargs...)
296+
x = cache.u0
297+
p = cache.p
298+
299+
return OptimizationBase.instantiate_function(
300+
f, x, adtype, p, num_cons; kwargs...)
301+
end
302+
291303
function OptimizationBase.instantiate_function(
292304
f::OptimizationFunction{true}, x,
293305
adtype::ADTypes.AutoSparse{<:Union{ADTypes.AutoZygote,
@@ -575,4 +587,15 @@ function OptimizationBase.instantiate_function(
575587
return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons; kwargs...)
576588
end
577589

590+
function OptimizationBase.instantiate_function(
591+
f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
592+
adtype::ADTypes.AutoSparse{<:DifferentiationInterface.SecondOrder{
593+
<:ADTypes.AbstractADType, <:ADTypes.AutoZygote}},
594+
num_cons = 0; kwargs...)
595+
x = cache.u0
596+
p = cache.p
597+
598+
return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons; kwargs...)
599+
end
600+
578601
end

lib/OptimizationBase/src/OptimizationDIExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ function instantiate_function(
198198
lag_prep = prepare_hessian(
199199
lagrangian, soadtype, x, Constant(one(eltype(x))),
200200
Constant(ones(eltype(x), num_cons)), Constant(p))
201-
lag_hess_prototype = zeros(Bool, num_cons, length(x))
201+
lag_hess_prototype = zeros(Bool, length(x), length(x))
202202

203203
function lag_h!(H::AbstractMatrix, θ, σ, λ)
204204
if σ == zero(eltype(θ))
@@ -457,7 +457,7 @@ function instantiate_function(
457457
lag_prep = prepare_hessian(
458458
lagrangian, soadtype, x, Constant(one(eltype(x))),
459459
Constant(ones(eltype(x), num_cons)), Constant(p))
460-
lag_hess_prototype = zeros(Bool, num_cons, length(x))
460+
lag_hess_prototype = zeros(Bool, length(x), length(x))
461461

462462
function lag_h!(θ, σ, λ)
463463
if σ == zero(eltype(θ))

lib/OptimizationBase/src/cache.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt;
6262
g = SciMLBase.requiresgradient(opt), h = SciMLBase.requireshessian(opt),
6363
hv = SciMLBase.requireshessian(opt), fg = SciMLBase.allowsfg(opt),
6464
fgh = SciMLBase.allowsfgh(opt), cons_j = SciMLBase.requiresconsjac(opt), cons_h = SciMLBase.requiresconshess(opt),
65-
cons_vjp = SciMLBase.allowsconsjvp(opt), cons_jvp = SciMLBase.allowsconsjvp(opt), lag_h = SciMLBase.requireslagh(opt))
65+
cons_vjp = SciMLBase.allowsconsvjp(opt), cons_jvp = SciMLBase.allowsconsjvp(opt), lag_h = SciMLBase.requireslagh(opt))
6666

6767
if structural_analysis
6868
obj_res, cons_res = symify_cache(f, prob, num_cons, manifold)

lib/OptimizationBase/test/adtests.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,17 @@ optprob.cons_h(H3, x0)
144144
optprob.lag_h(H4, x0, σ, μ)
145145
@test H4 ≈ σ * H2 + μ[1] * H3[1] rtol=1e-6
146146

147+
# Test that the AD-generated lag_hess_prototype has correct dimensions
148+
@test !isnothing(optprob.lag_hess_prototype)
149+
@test size(optprob.lag_hess_prototype) == (length(x0), length(x0)) # Should be n×n, not num_cons×n
150+
151+
# Test that we can actually use it as a buffer
152+
if !isnothing(optprob.lag_hess_prototype)
153+
H_proto = similar(optprob.lag_hess_prototype, Float64)
154+
optprob.lag_h(H_proto, x0, σ, μ)
155+
@test H_proto ≈ σ * H2 + μ[1] * H3[1] rtol=1e-6
156+
end
157+
147158
G2 = Array{Float64}(undef, 2)
148159
H2 = Array{Float64}(undef, 2, 2)
149160

@@ -257,6 +268,17 @@ optprob.cons_h(H3, x0)
257268
optprob.lag_h(H4, x0, σ, μ)
258269
@test H4 ≈ σ * H2 + μ[1] * H3[1] rtol=1e-6
259270

271+
# Test that the AD-generated lag_hess_prototype has correct dimensions
272+
@test !isnothing(optprob.lag_hess_prototype)
273+
@test size(optprob.lag_hess_prototype) == (length(x0), length(x0)) # Should be n×n, not num_cons×n
274+
275+
# Test that we can actually use it as a buffer (this would fail with the bug)
276+
if !isnothing(optprob.lag_hess_prototype)
277+
H_proto = similar(optprob.lag_hess_prototype, Float64)
278+
optprob.lag_h(H_proto, x0, σ, μ)
279+
@test H_proto ≈ σ * H2 + μ[1] * H3[1] rtol=1e-6
280+
end
281+
260282
G2 = Array{Float64}(undef, 2)
261283
H2 = Array{Float64}(undef, 2, 2)
262284

@@ -490,6 +512,17 @@ end
490512
optprob.lag_h(H4, x0, σ, μ)
491513
@test H4 ≈ σ * H1 + sum(μ .* H3) rtol=1e-6
492514

515+
# Test that the AD-generated lag_hess_prototype has correct dimensions
516+
@test !isnothing(optprob.lag_hess_prototype)
517+
@test size(optprob.lag_hess_prototype) == (length(x0), length(x0)) # Should be n×n, not num_cons×n
518+
519+
# Test that we can actually use it as a buffer (this would fail with the bug)
520+
if !isnothing(optprob.lag_hess_prototype)
521+
H_proto = similar(optprob.lag_hess_prototype, Float64)
522+
optprob.lag_h(H_proto, x0, σ, μ)
523+
@test H_proto ≈ σ * H1 + sum(μ .* H3) rtol=1e-6
524+
end
525+
493526
G2 = Array{Float64}(undef, 2)
494527
H2 = Array{Float64}(undef, 2, 2)
495528

0 commit comments

Comments
 (0)