Skip to content

Commit 257203e

Browse files
MaxenceGollierdpo
authored and committed
add qnupdate as a keyword argument
1 parent 0feb810 commit 257203e

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

src/R2N.jl

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ mutable struct R2NSolver{
1212
xk::V
1313
∇fk::V
1414
∇fk⁻::V
15+
y::V
1516
mν∇fk::V
1617
ψ::G
1718
xkn::V
@@ -40,6 +41,7 @@ function R2NSolver(
4041
xk = similar(x0)
4142
∇fk = similar(x0)
4243
∇fk⁻ = similar(x0)
44+
y = similar(x0)
4345
mν∇fk = similar(x0)
4446
xkn = similar(x0)
4547
s = similar(x0)
@@ -70,6 +72,7 @@ function R2NSolver(
7072
xk,
7173
∇fk,
7274
∇fk⁻,
75+
y,
7376
mν∇fk,
7477
ψ,
7578
xkn,
@@ -154,6 +157,12 @@ Notably, you can access, and modify, the following:
154157
- `stats.solver_specific[:nonsmooth_obj]`: current value of the nonsmooth part of the objective function;
155158
- `stats.status`: current status of the algorithm. Should be `:unknown` unless the algorithm has attained a stopping criterion. Changing this to anything other than `:unknown` will stop the algorithm, but you should use `:user` to properly indicate the intention;
156159
- `stats.elapsed_time`: elapsed time in seconds.
160+
Similarly to the callback, when using a quasi-Newton approximation, two functions, `qn_update_y!(nlp, solver, stats)` and `qn_copy!(nlp, solver, stats)` are called at each update of the approximation.
161+
Namely, the former computes the `y` vector for which the pair `(s, y)` is pushed into the approximation.
162+
By default, `y := ∇fk - ∇fk⁻`, i.e., the current gradient minus the previous one.
163+
The latter allows the user to tell which values should be copied for the next iteration.
164+
By default, only the gradient is copied: `∇fk⁻ .= ∇fk`.
165+
This might be useful when using R2N in a constrained optimization context, when the gradient of the Lagrangian function is pushed at each iteration rather than the gradient of the objective function.
157166
"""
158167
function R2N(
159168
nlp::AbstractNLPModel{T, V},
@@ -200,6 +209,8 @@ function SolverCore.solve!(
200209
reg_nlp::AbstractRegularizedNLPModel{T, V},
201210
stats::GenericExecutionStats{T, V};
202211
callback = (args...) -> nothing,
212+
qn_update_y!::Function = _qn_grad_update_y!,
213+
qn_copy!::Function = _qn_grad_copy!,
203214
x::V = reg_nlp.model.meta.x0,
204215
atol::T = eps(T),
205216
rtol::T = eps(T),
@@ -283,7 +294,7 @@ function SolverCore.solve!(
283294

284295
fk = obj(nlp, xk)
285296
grad!(nlp, xk, ∇fk)
286-
∇fk⁻ .= ∇fk
297+
qn_copy!(nlp, solver, stats)
287298

288299
quasiNewtTest = isa(nlp, QuasiNewtonModel)
289300
λmax::T = T(1)
@@ -416,15 +427,14 @@ function SolverCore.solve!(
416427
grad!(nlp, xk, ∇fk)
417428

418429
if quasiNewtTest
419-
@. ∇fk⁻ = ∇fk - ∇fk⁻
420-
push!(nlp, s, ∇fk⁻)
430+
qn_update_y!(nlp, solver, stats)
431+
push!(nlp, s, solver.y)
432+
qn_copy!(nlp, solver, stats)
421433
end
422434
solver.subpb.model.B = hess_op(nlp, xk)
423435

424436
λmax, found_λ = opnorm(solver.subpb.model.B)
425437
found_λ || error("operator norm computation failed")
426-
427-
∇fk⁻ .= ∇fk
428438
end
429439

430440
if η2 ρk < Inf
@@ -500,3 +510,11 @@ function SolverCore.solve!(
500510
set_residuals!(stats, zero(eltype(xk)), sqrt_ξ1_νInv)
501511
return stats
502512
end
513+
514+
"""
    _qn_grad_update_y!(nlp, solver, stats)

Default value of the `qn_update_y!` keyword argument of `solve!`.

Overwrite `solver.y` in place with `solver.∇fk - solver.∇fk⁻` and return `solver.y`.
`nlp` and `stats` are not used here; they are accepted so that this function has the
same call signature as a user-supplied `qn_update_y!`.
"""
function _qn_grad_update_y!(
  nlp::AbstractNLPModel{T, V},
  solver::R2NSolver{T, G, V},
  stats::GenericExecutionStats,
) where {T, V, G}
  y, g, g⁻ = solver.y, solver.∇fk, solver.∇fk⁻
  # Fused in-place broadcast: no temporary is allocated for the difference.
  y .= g .- g⁻
  return y
end
517+
518+
"""
    _qn_grad_copy!(nlp, solver, stats)

Default value of the `qn_copy!` keyword argument of `solve!`.

Copy `solver.∇fk` into `solver.∇fk⁻` in place and return `solver.∇fk⁻`.
`nlp` and `stats` are not used here; they are accepted so that this function has the
same call signature as a user-supplied `qn_copy!`.
"""
function _qn_grad_copy!(
  nlp::AbstractNLPModel{T, V},
  solver::R2NSolver{T, G, V},
  stats::GenericExecutionStats,
) where {T, V, G}
  g⁻, g = solver.∇fk⁻, solver.∇fk
  # In-place broadcast assignment: the existing storage of ∇fk⁻ is reused.
  g⁻ .= g
  return g⁻
end

0 commit comments

Comments
 (0)