From 2bc835a6c11ebe3acbc97aaa086af9a8f9841303 Mon Sep 17 00:00:00 2001 From: Maxence Gollier Date: Tue, 30 Sep 2025 17:55:09 -0400 Subject: [PATCH] add skip_sigma kwarg --- src/LMModel.jl | 3 ++- src/LM_alg.jl | 2 +- src/R2N.jl | 2 +- src/R2NModel.jl | 3 ++- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/LMModel.jl b/src/LMModel.jl index 845c1680..3af161ae 100644 --- a/src/LMModel.jl +++ b/src/LMModel.jl @@ -39,11 +39,12 @@ function LMModel(J::Jac, F::V, σ::T, xk::V) where {T, V, Jac} return LMModel(J, F, v, xk, σ, meta, Counters()) end -function NLPModels.obj(nlp::LMModel, x::AbstractVector{T}) where {T} +function NLPModels.obj(nlp::LMModel, x::AbstractVector{T}; skip_sigma = false) where {T} @lencheck nlp.meta.nvar x increment!(nlp, :neval_obj) mul!(nlp.v, nlp.J, x) nlp.v .+= nlp.F + skip_sigma && return dot(nlp.v, nlp.v)/2 return (dot(nlp.v, nlp.v) + nlp.σ * dot(x, x)) / 2 end diff --git a/src/LM_alg.jl b/src/LM_alg.jl index aefbce48..11ba77cf 100644 --- a/src/LM_alg.jl +++ b/src/LM_alg.jl @@ -300,7 +300,7 @@ function SolverCore.solve!( end mk = let ψ = ψ, solver = solver - d -> obj(solver.subpb.model, d) + ψ(d) + d -> obj(solver.subpb.model, d, skip_sigma = true) + ψ(d) end prox!(s, ψ, mν∇fk, ν) diff --git a/src/R2N.jl b/src/R2N.jl index 26a8eea0..b5b768cb 100644 --- a/src/R2N.jl +++ b/src/R2N.jl @@ -327,7 +327,7 @@ function SolverCore.solve!( end mk = let ψ = ψ, solver = solver - d -> obj(solver.subpb.model, d) + ψ(d)::T + d -> obj(solver.subpb.model, d, skip_sigma = true) + ψ(d)::T end prox!(s1, ψ, mν∇fk, ν₁) diff --git a/src/R2NModel.jl b/src/R2NModel.jl index a4b2f990..3ce4d779 100644 --- a/src/R2NModel.jl +++ b/src/R2NModel.jl @@ -34,10 +34,11 @@ function R2NModel(B::G, ∇f::V, σ::T, x0::V) where {T, V, G} return R2NModel(B::G, ∇f::V, v::V, σ::T, meta, Counters()) end -function NLPModels.obj(nlp::R2NModel, x::AbstractVector) +function NLPModels.obj(nlp::R2NModel, x::AbstractVector; skip_sigma = false) @lencheck nlp.meta.nvar x increment!(nlp, :neval_obj) mul!(nlp.v, nlp.B, x) + skip_sigma && return dot(nlp.v, x)/2 + dot(nlp.∇f, x) return dot(nlp.v, x)/2 + dot(nlp.∇f, x) + nlp.σ * dot(x, x) / 2 end