Skip to content

Commit 80f981f

Browse files
add sub_kwargs parameter to LMTRSolver and solve! for passing subsolver keyword arguments and prox evaluations
1 parent b2b4e33 commit 80f981f

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

src/LMTR_alg.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ For advanced usage, first define a solver "LMSolver" to preallocate the memory u
140140
- `β::T = 1/eps(T)`: TODO
141141
- `χ = NormLinf(1)`: norm used to define the trust-region;`
142142
- `subsolver::S = R2Solver`: subsolver used to solve the subproblem that appears at each iteration.
143+
- `sub_kwargs::NamedTuple = NamedTuple()`: a named tuple containing the keyword arguments to be sent to the subsolver. The solver will fail if invalid keyword arguments are provided to the subsolver. For example, if the subsolver is `R2Solver`, you can pass `sub_kwargs = (max_iter = 100, σmin = 1e-6,)`.
143144
144145
The algorithm stops either when `√(ξₖ/νₖ) < atol + rtol*√(ξ₀/ν₀) ` or `ξₖ < 0` and `√(-ξₖ/νₖ) < neg_tol` where ξₖ := f(xₖ) + h(xₖ) - φ(sₖ; xₖ) - ψ(sₖ; xₖ), and √(ξₖ/νₖ) is a stationarity measure.
145146
@@ -203,6 +204,7 @@ function SolverCore.solve!(
203204
γ::T = T(3),
204205
α::T = 1 / eps(T),
205206
β::T = 1 / eps(T),
207+
sub_kwargs::NamedTuple = NamedTuple(),
206208
) where {T, G, V}
207209
reset!(stats)
208210

@@ -265,6 +267,7 @@ function SolverCore.solve!(
265267

266268
local ξ1::T
267269
local ρk::T = zero(T)
270+
local prox_evals::Int = 0
268271

269272
residual!(nls, xk, Fk)
270273
jtprod_residual!(nls, xk, Fk, ∇fk)
@@ -282,6 +285,7 @@ function SolverCore.solve!(
282285
set_objective!(stats, fk + hk)
283286
set_solver_specific!(stats, :smooth_obj, fk)
284287
set_solver_specific!(stats, :nonsmooth_obj, hk)
288+
set_solver_specific!(stats, :prox_evals, prox_evals + 1)
285289

286290
φ1 = let Fk = Fk, ∇fk = ∇fk
287291
d -> dot(Fk, Fk) / 2 + dot(∇fk, d) # ∇fk = Jk^T Fk
@@ -341,22 +345,25 @@ function SolverCore.solve!(
341345
solve!(
342346
solver.subsolver,
343347
solver.subpb,
344-
solver.substats,
348+
solver.substats;
345349
x = s,
346350
atol = stats.iter == 0 ? 1.0e-5 : max(sub_atol, min(1.0e-1, ξ1 / 10)),
347351
Δk = ∆_effective / 10,
352+
sub_kwargs...,
348353
)
349354
else
350355
solve!(
351356
solver.subsolver,
352357
solver.subpb,
353-
solver.substats,
358+
solver.substats;
354359
x = s,
355360
atol = stats.iter == 0 ? 1.0e-5 : max(sub_atol, min(1.0e-1, ξ1 / 10)),
356361
ν = ν,
362+
sub_kwargs...,
357363
)
358364
end
359365

366+
prox_evals += solver.substats.iter
360367
s .= solver.substats.solution
361368

362369
sNorm = χ(s)
@@ -438,6 +445,7 @@ function SolverCore.solve!(
438445
set_solver_specific!(stats, :nonsmooth_obj, hk)
439446
set_iter!(stats, stats.iter + 1)
440447
set_time!(stats, time() - start_time)
448+
set_solver_specific!(stats, :prox_evals, prox_evals + 1)
441449

442450
ν = α * Δk / (1 + σmax^2 ** Δk + 1))
443451
@. mν∇fk = -∇fk * ν

0 commit comments

Comments
 (0)