Skip to content

Commit 01095ae

Browse files
SebastianM-C and claude committed
Add cons_h_weighted! to handle σ=0 in Lagrangian Hessian
When computing the Lagrangian Hessian with lag_h!, the case σ=0 requires special handling, since it reduces to just the weighted sum of constraint Hessians (Σᵢ λᵢ∇²cᵢ) without the objective contribution. Previously, this case would fail when cons_h was not explicitly requested but lag_h was, because the constraint Hessian preparations were not created.

This commit:
- Always creates constraint Hessian preparations when either cons_h or lag_h is true
- Adds a cons_h_weighted!(H, θ, λ) function to compute the weighted sum directly into H
- Updates lag_h! to use cons_h_weighted! when σ=0

This fixes the edge case in OptimizationMadNLP where the solver could hit σ=0 during iterations, particularly with exact Hessian and sparse KKT systems.

Applies to both OptimizationDIExt and OptimizationZygoteExt.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
1 parent 6c9977e commit 01095ae

File tree

2 files changed

+80
-13
lines changed

2 files changed

+80
-13
lines changed

lib/OptimizationBase/ext/OptimizationZygoteExt.jl

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -186,20 +186,54 @@ function OptimizationBase.instantiate_function(
186186

187187
conshess_sparsity = f.cons_hess_prototype
188188
conshess_colors = f.cons_hess_colorvec
189-
if cons !== nothing && cons_h == true && f.cons_h === nothing
189+
190+
# Prepare constraint Hessian preparations if needed by lag_h or cons_h
191+
if cons !== nothing && f.cons_h === nothing && (cons_h == true || lag_h == true)
190192
prep_cons_hess = [prepare_hessian(
191193
cons_oop, soadtype, x, Constant(i), strict = Val(false))
192194
for i in 1:num_cons]
195+
else
196+
prep_cons_hess = nothing
197+
end
198+
199+
# Generate cons_h! functions
200+
if cons !== nothing && f.cons_h === nothing && prep_cons_hess !== nothing
201+
# Standard cons_h! that returns array of matrices
202+
if cons_h == true
203+
cons_h! = function (H, θ)
204+
for i in 1:num_cons
205+
hessian!(cons_oop, H[i], prep_cons_hess[i], soadtype, θ, Constant(i))
206+
end
207+
end
208+
else
209+
cons_h! = nothing
210+
end
211+
212+
# Weighted sum dispatch for cons_h! (always created if prep_cons_hess exists)
213+
# This is used by lag_h! when σ=0
214+
cons_h_weighted! = function (H::AbstractMatrix, θ, λ)
215+
# Compute weighted sum: H = Σᵢ λᵢ∇²cᵢ
216+
H .= zero(eltype(H))
217+
218+
# Create a single temporary matrix to reuse for all constraints
219+
Hi = similar(H)
193220

194-
function cons_h!(H, θ)
195221
for i in 1:num_cons
196-
hessian!(cons_oop, H[i], prep_cons_hess[i], soadtype, θ, Constant(i))
222+
if λ[i] != zero(eltype(λ))
223+
# Compute constraint's Hessian into temporary matrix
224+
hessian!(cons_oop, Hi, prep_cons_hess[i], soadtype, θ, Constant(i))
225+
# Add weighted Hessian to result using in-place operation
226+
# H += λ[i] * Hi
227+
@. H += λ[i] * Hi
228+
end
197229
end
198230
end
199231
elseif cons !== nothing && cons_h == true
200232
cons_h! = (res, θ) -> f.cons_h(res, θ, p)
233+
cons_h_weighted! = nothing
201234
else
202235
cons_h! = nothing
236+
cons_h_weighted! = nothing
203237
end
204238

205239
lag_hess_prototype = f.lag_hess_prototype
@@ -212,8 +246,8 @@ function OptimizationBase.instantiate_function(
212246

213247
function lag_h!(H::AbstractMatrix, θ, σ, λ)
214248
if σ == zero(eltype(θ))
215-
cons_h!(H, θ)
216-
H *= λ
249+
# When σ=0, use the weighted sum function
250+
cons_h_weighted!(H, θ, λ)
217251
else
218252
hessian!(lagrangian, H, lag_extras, soadtype, θ,
219253
Constant(σ), Constant(λ), Constant(p))
@@ -512,8 +546,8 @@ function OptimizationBase.instantiate_function(
512546

513547
function lag_h!(H::AbstractMatrix, θ, σ, λ)
514548
if σ == zero(eltype(θ))
515-
cons_h!(H, θ)
516-
H *= λ
549+
# When σ=0, use the weighted sum function
550+
cons_h_weighted!(H, θ, λ)
517551
else
518552
hessian!(lagrangian, H, lag_extras, soadtype, θ,
519553
Constant(σ), Constant(λ), Constant(p))

lib/OptimizationBase/src/OptimizationDIExt.jl

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -177,20 +177,53 @@ function instantiate_function(
177177

178178
conshess_sparsity = f.cons_hess_prototype
179179
conshess_colors = f.cons_hess_colorvec
180-
# Generate cons_h! if explicitly requested OR if lag_h needs it
180+
181+
# Prepare constraint Hessian preparations if needed by lag_h or cons_h
181182
if f.cons !== nothing && f.cons_h === nothing && (cons_h == true || lag_h == true)
182183
prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i))
183184
for i in 1:num_cons]
185+
else
186+
prep_cons_hess = nothing
187+
end
188+
189+
# Generate cons_h! functions
190+
if f.cons !== nothing && f.cons_h === nothing && prep_cons_hess !== nothing
191+
# Standard cons_h! that returns array of matrices
192+
if cons_h == true
193+
cons_h! = function (H, θ)
194+
for i in 1:num_cons
195+
hessian!(cons_oop, H[i], prep_cons_hess[i], soadtype, θ, Constant(i))
196+
end
197+
end
198+
else
199+
cons_h! = nothing
200+
end
201+
202+
# Weighted sum dispatch for cons_h! (always created if prep_cons_hess exists)
203+
# This is used by lag_h! when σ=0
204+
cons_h_weighted! = function (H::AbstractMatrix, θ, λ)
205+
# Compute weighted sum: H = Σᵢ λᵢ∇²cᵢ
206+
H .= zero(eltype(H))
207+
208+
# Create a single temporary matrix to reuse for all constraints
209+
Hi = similar(H)
184210

185-
function cons_h!(H, θ)
186211
for i in 1:num_cons
187-
hessian!(cons_oop, H[i], prep_cons_hess[i], soadtype, θ, Constant(i))
212+
if λ[i] != zero(eltype(λ))
213+
# Compute constraint's Hessian into temporary matrix
214+
hessian!(cons_oop, Hi, prep_cons_hess[i], soadtype, θ, Constant(i))
215+
# Add weighted Hessian to result using in-place operation
216+
# H += λ[i] * Hi
217+
@. H += λ[i] * Hi
218+
end
188219
end
189220
end
190-
elseif (cons_h == true || lag_h == true) && f.cons !== nothing
221+
elseif cons_h == true && f.cons !== nothing
191222
cons_h! = (res, θ) -> f.cons_h(res, θ, p)
223+
cons_h_weighted! = nothing
192224
else
193225
cons_h! = nothing
226+
cons_h_weighted! = nothing
194227
end
195228

196229
lag_hess_prototype = f.lag_hess_prototype
@@ -203,8 +236,8 @@ function instantiate_function(
203236

204237
function lag_h!(H::AbstractMatrix, θ, σ, λ)
205238
if σ == zero(eltype(θ))
206-
cons_h!(H, θ)
207-
H *= λ
239+
# When σ=0, use the weighted sum function
240+
cons_h_weighted!(H, θ, λ)
208241
else
209242
hessian!(lagrangian, H, lag_prep, soadtype, θ,
210243
Constant(σ), Constant(λ), Constant(p))

0 commit comments

Comments
 (0)