-
Notifications
You must be signed in to change notification settings - Fork 10
LMModel → LLSModels and R2NModel → QuadraticModels #244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
dbbce49
3774c3d
685dabb
d984a62
8473946
9f9322f
4a0d87c
8ed914e
b7d2979
b06e017
dd6fc3b
37e9beb
8357983
589f75b
0f9f29e
13ce83c
3cef1e2
4e1e1f9
7c7aeaf
f2f89ed
a434d5a
2ea55a4
e3bd03b
48811bd
92c7efb
109c004
b5e1c58
deef93e
fb0fbb3
2d08604
8dd1b2a
f5d105e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,42 @@ | ||
| export R2N, R2NSolver, solve! | ||
|
|
||
| import SolverCore.solve! | ||
| using LinearAlgebra | ||
| using LinearOperators | ||
|
|
||
"""
    ShiftedHessian{T}(B, sigma)

Mutable wrapper standing for the operator `B + sigma * I`.

Both fields can be reassigned in place, so a caller can swap the underlying
operator `B` or change the shift `sigma` without building a new operator.
`mul!` methods are provided for the wrapper and for its lazy adjoint.

# Fields
- `B`: the underlying operator; anything supporting `size`, `eltype`,
  `adjoint` and `mul!(y, B, x)`. Left untyped because callers replace it with
  freshly built operators.
- `sigma::T`: the scalar shift.
"""
mutable struct ShiftedHessian{T}
    B::Any
    sigma::T
end

# Shape and element type are forwarded from the wrapped operator.
Base.size(op::ShiftedHessian) = size(op.B)
Base.eltype(op::ShiftedHessian) = eltype(op.B)

import LinearAlgebra: adjoint

"""
    adjoint(op::ShiftedHessian)

Return a lazy `LinearAlgebra.Adjoint` wrapper around `op`; no data is copied.
"""
adjoint(op::ShiftedHessian) = LinearAlgebra.Adjoint(op)

"""
    mul!(y, op::ShiftedHessian, x)

Overwrite `y` with `op.B * x + op.sigma * x` and return `y`.
"""
function LinearAlgebra.mul!(
    y::AbstractVector{T},
    op::ShiftedHessian{T},
    x::AbstractVector{T},
) where {T}
    mul!(y, op.B, x)
    # Fused broadcast: allocation-free and shape-checked, unlike an
    # `@inbounds` loop that indexes `x` with the indices of `y`.
    y .+= op.sigma .* x
    return y
end

"""
    mul!(y, opAd::Adjoint{<:Any, ShiftedHessian{T}}, x)

Overwrite `y` with `op.B' * x + conj(op.sigma) * x`, where `op` is the wrapped
`ShiftedHessian`, and return `y`.
"""
function LinearAlgebra.mul!(
    y::AbstractVector{T},
    opAd::Adjoint{<:Any, ShiftedHessian{T}},
    x::AbstractVector{T},
) where {T}
    op = parent(opAd)
    mul!(y, adjoint(op.B), x)
    # (B + sigma*I)' == B' + conj(sigma)*I; `conj` is a no-op for real shifts.
    y .+= conj(op.sigma) .* x
    return y
end
|
|
||
|
|
||
| mutable struct R2NSolver{ | ||
| T <: Real, | ||
|
|
@@ -28,6 +64,12 @@ mutable struct R2NSolver{ | |
| subsolver::ST | ||
| subpb::PB | ||
| substats::GenericExecutionStats{T, V, V, T} | ||
| # Pre-allocated components for QuadraticModel recreation | ||
| Id::LinearOperator # Identity operator | ||
| x0_quad::V # Zero vector for QuadraticModel x0 | ||
| reg_hess::LinearOperator # regularized Hessian operator | ||
| reg_hess_wrapper::ShiftedHessian{T} # mutable wrapper (B, sigma) | ||
| reg_hess_op::LinearOperator # LinearOperator that captures the wrapper | ||
| end | ||
|
|
||
| function R2NSolver( | ||
|
|
@@ -68,7 +110,25 @@ function R2NSolver( | |
| shifted(reg_nlp.h, xk) | ||
|
|
||
| Bk = hess_op(reg_nlp.model, x0) | ||
| sub_nlp = R2NModel(Bk, ∇fk, T(1), x0) | ||
| # Create quadratic model: min ∇f^T s + 1/2 s^T B s + σ/2 ||s||^2 | ||
| # QuadraticModel represents: min c^T x + 1/2 x^T H x + c0 | ||
| # So we need c = ∇fk, H = Bk + σI, c0 = 0 | ||
| σ = T(1) | ||
| n = length(∇fk) | ||
| Id = opEye(T, n) # Identity operator | ||
arnavk23 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| x0_quad = zeros(T, n) # Pre-allocate x0 for QuadraticModel | ||
| # Create a mutable wrapper around the Hessian so we can update sigma/B without | ||
| # allocating a new operator every iteration. | ||
| reg_hess_wrapper = ShiftedHessian{T}(Bk, T(1)) | ||
| # Create a LinearOperator that calls mul! on the wrapper. This operator is | ||
| # allocated once and keeps a reference to the mutable wrapper, so future | ||
| # updates can mutate the wrapper without reallocating the operator. | ||
arnavk23 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| reg_hess_op = LinearOperator{T}(n, n, false, false, | ||
| (y, x) -> mul!(y, reg_hess_wrapper, x), | ||
| (y, x) -> mul!(y, adjoint(reg_hess_wrapper), x), | ||
| (y, x) -> mul!(y, adjoint(reg_hess_wrapper), x), | ||
| ) | ||
| sub_nlp = QuadraticModel(∇fk, reg_hess_op, c0 = zero(T), x0 = x0_quad, name = "R2N-subproblem") | ||
arnavk23 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| subpb = RegularizedNLPModel(sub_nlp, ψ) | ||
| substats = RegularizedExecutionStats(subpb) | ||
| subsolver = subsolver(subpb) | ||
|
|
@@ -93,6 +153,10 @@ function R2NSolver( | |
| subsolver, | ||
| subpb, | ||
| substats, | ||
| Id, | ||
| x0_quad, | ||
| reg_hess_op, | ||
| reg_hess_wrapper, | ||
| ) | ||
|
Comment on lines
117
to
120
|
||
| end | ||
|
|
||
|
|
@@ -195,6 +259,14 @@ function R2N(reg_nlp::AbstractRegularizedNLPModel; kwargs...) | |
| return stats | ||
| end | ||
|
|
||
"""
    update_quadratic_model!(qm::QuadraticModel, c::AbstractVector, H = nothing)

Refresh an existing `QuadraticModel` in place and return it.

The entries of `c` are copied into `qm.data.c`, and `qm.counters.neval_hess`
is reset to zero. `H` is accepted for call-site compatibility but is ignored:
the caller is expected to mutate the Hessian wrapper itself.

# Throws
- `DimensionMismatch` if `c` and `qm.data.c` have different lengths.
"""
function update_quadratic_model!(qm::QuadraticModel, c::AbstractVector, H = nothing)
    # `copyto!` would silently copy only a prefix when `c` is too short, so
    # check the sizes explicitly before overwriting the stored gradient.
    length(c) == length(qm.data.c) || throw(
        DimensionMismatch(
            "gradient has length $(length(c)), model expects $(length(qm.data.c))",
        ),
    )
    copyto!(qm.data.c, c)
    qm.counters.neval_hess = 0
    return qm
end
|
Comment on lines
+254
to
+255
|
||
|
|
||
| function SolverCore.solve!( | ||
| solver::R2NSolver{T, G, V}, | ||
| reg_nlp::AbstractRegularizedNLPModel{T, V}, | ||
|
|
@@ -292,12 +364,18 @@ function SolverCore.solve!( | |
| quasiNewtTest = isa(nlp, QuasiNewtonModel) | ||
| λmax::T = T(1) | ||
| found_λ = true | ||
| solver.subpb.model.B = hess_op(nlp, xk) | ||
| # Update the Hessian and update the QuadraticModel | ||
| Bk_new = hess_op(nlp, xk) | ||
| σ = T(1) | ||
| # Update the existing ShiftedHessian wrapper in-place to avoid allocations | ||
| solver.reg_hess_wrapper.B = Bk_new | ||
| solver.reg_hess_wrapper.sigma = σ | ||
| update_quadratic_model!(solver.subpb.model, solver.∇fk) | ||
|
|
||
| if opnorm_maxiter ≤ 0 | ||
| λmax, found_λ = opnorm(solver.subpb.model.B) | ||
| λmax, found_λ = opnorm(Bk_new) | ||
| else | ||
| λmax = power_method!(solver.subpb.model.B, solver.v0, solver.subpb.model.v, opnorm_maxiter) | ||
| λmax = power_method!(Bk_new, solver.v0, solver.subpb.model.data.v, opnorm_maxiter) | ||
| end | ||
| found_λ || error("operator norm computation failed") | ||
|
|
||
|
|
@@ -327,7 +405,7 @@ function SolverCore.solve!( | |
| end | ||
|
|
||
| mk = let ψ = ψ, solver = solver | ||
| d -> obj(solver.subpb.model, d, skip_sigma = true) + ψ(d)::T | ||
| d -> obj_skip_sigma(solver.subpb.model, d) + ψ(d)::T | ||
arnavk23 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| end | ||
|
|
||
| prox!(s1, ψ, mν∇fk, ν₁) | ||
|
|
@@ -361,7 +439,12 @@ function SolverCore.solve!( | |
| while !done | ||
| sub_atol = stats.iter == 0 ? 1.0e-3 : min(sqrt_ξ1_νInv ^ (1.5), sqrt_ξ1_νInv * 1e-3) | ||
|
|
||
| solver.subpb.model.σ = σk | ||
| # Update QuadraticModel with updated regularization parameter | ||
| Bk_current = hess_op(nlp, xk) | ||
| # mutate wrapper in-place | ||
| solver.reg_hess_wrapper.B = Bk_current | ||
| solver.reg_hess_wrapper.sigma = σk | ||
| update_quadratic_model!(solver.subpb.model, solver.∇fk) | ||
| isa(solver.subsolver, R2DHSolver) && (solver.subsolver.D.d[1] = 1/ν₁) | ||
| if isa(solver.subsolver, R2Solver) #FIXME | ||
arnavk23 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| solve!( | ||
|
|
@@ -445,12 +528,17 @@ function SolverCore.solve!( | |
| push!(nlp, s, solver.y) | ||
| qn_copy!(nlp, solver, stats) | ||
| end | ||
| solver.subpb.model.B = hess_op(nlp, xk) | ||
| # Update the Hessian and update the QuadraticModel | ||
| Bk_new = hess_op(nlp, xk) | ||
| σ = T(1) | ||
| solver.reg_hess_wrapper.B = Bk_new | ||
| solver.reg_hess_wrapper.sigma = σ | ||
| update_quadratic_model!(solver.subpb.model, solver.∇fk) | ||
arnavk23 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| if opnorm_maxiter ≤ 0 | ||
| λmax, found_λ = opnorm(solver.subpb.model.B) | ||
| λmax, found_λ = opnorm(Bk_new) | ||
| else | ||
| λmax = power_method!(solver.subpb.model.B, solver.v0, solver.subpb.model.v, opnorm_maxiter) | ||
| λmax = power_method!(Bk_new, solver.v0, solver.subpb.model.data.v, opnorm_maxiter) | ||
| end | ||
| found_λ || error("operator norm computation failed") | ||
| end | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.