Skip to content

Commit a6a05db

Browse files
add sub_kwargs parameter to LMSolver and solve! for subsolver (#245)
* add sub_kwargs parameter to LMSolver and solve! for subsolver keyword arguments * add sub_kwargs parameter to LMTRSolver and solve! for passing subsolver keyword arguments and prox evaluations
1 parent 4344f74 commit a6a05db

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

src/LMTR_alg.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ For advanced usage, first define a solver "LMTRSolver" to preallocate the memory u
140140
- `β::T = 1/eps(T)`: TODO
141141
- `χ = NormLinf(1)`: norm used to define the trust-region;
142142
- `subsolver::S = R2Solver`: subsolver used to solve the subproblem that appears at each iteration.
143+
- `sub_kwargs::NamedTuple = NamedTuple()`: a named tuple containing the keyword arguments to be sent to the subsolver. The solver will fail if invalid keyword arguments are provided to the subsolver. For example, if the subsolver is `R2Solver`, you can pass `sub_kwargs = (max_iter = 100, σmin = 1e-6,)`.
143144
144145
The algorithm stops either when `√(ξₖ/νₖ) < atol + rtol*√(ξ₀/ν₀)`, or when `ξₖ < 0` and `√(-ξₖ/νₖ) < neg_tol`, where ξₖ := f(xₖ) + h(xₖ) - φ(sₖ; xₖ) - ψ(sₖ; xₖ), and √(ξₖ/νₖ) is a stationarity measure.
145146
@@ -203,6 +204,7 @@ function SolverCore.solve!(
203204
γ::T = T(3),
204205
α::T = 1 / eps(T),
205206
β::T = 1 / eps(T),
207+
sub_kwargs::NamedTuple = NamedTuple(),
206208
) where {T, G, V}
207209
reset!(stats)
208210

@@ -265,6 +267,7 @@ function SolverCore.solve!(
265267

266268
local ξ1::T
267269
local ρk::T = zero(T)
270+
local prox_evals::Int = 0
268271

269272
residual!(nls, xk, Fk)
270273
jtprod_residual!(nls, xk, Fk, ∇fk)
@@ -282,6 +285,7 @@ function SolverCore.solve!(
282285
set_objective!(stats, fk + hk)
283286
set_solver_specific!(stats, :smooth_obj, fk)
284287
set_solver_specific!(stats, :nonsmooth_obj, hk)
288+
set_solver_specific!(stats, :prox_evals, prox_evals + 1)
285289

286290
φ1 = let Fk = Fk, ∇fk = ∇fk
287291
d -> dot(Fk, Fk) / 2 + dot(∇fk, d) # ∇fk = Jk^T Fk
@@ -341,22 +345,25 @@ function SolverCore.solve!(
341345
solve!(
342346
solver.subsolver,
343347
solver.subpb,
344-
solver.substats,
348+
solver.substats;
345349
x = s,
346350
atol = stats.iter == 0 ? 1.0e-5 : max(sub_atol, min(1.0e-1, ξ1 / 10)),
347351
Δk = ∆_effective / 10,
352+
sub_kwargs...,
348353
)
349354
else
350355
solve!(
351356
solver.subsolver,
352357
solver.subpb,
353-
solver.substats,
358+
solver.substats;
354359
x = s,
355360
atol = stats.iter == 0 ? 1.0e-5 : max(sub_atol, min(1.0e-1, ξ1 / 10)),
356361
ν = ν,
362+
sub_kwargs...,
357363
)
358364
end
359365

366+
prox_evals += solver.substats.iter
360367
s .= solver.substats.solution
361368

362369
sNorm = χ(s)
@@ -438,6 +445,7 @@ function SolverCore.solve!(
438445
set_solver_specific!(stats, :nonsmooth_obj, hk)
439446
set_iter!(stats, stats.iter + 1)
440447
set_time!(stats, time() - start_time)
448+
set_solver_specific!(stats, :prox_evals, prox_evals + 1)
441449

442450
ν = α * Δk / (1 + σmax^2 * (α * Δk + 1))
443451
@. mν∇fk = -∇fk * ν

src/LM_alg.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ For advanced usage, first define a solver "LMSolver" to preallocate the memory u
140140
- `θ::T = 1/(1 + eps(T)^(1 / 5))`: is the model decrease fraction with respect to the decrease of the Cauchy model;
141141
- `m_monotone::Int = 1`: monotonicity parameter. By default, LM is monotone but the non-monotone variant will be used if `m_monotone > 1`;
142142
- `subsolver = R2Solver`: the solver used to solve the subproblems.
143+
- `sub_kwargs::NamedTuple = NamedTuple()`: a named tuple containing the keyword arguments to be sent to the subsolver. The solver will fail if invalid keyword arguments are provided to the subsolver. For example, if the subsolver is `R2Solver`, you can pass `sub_kwargs = (max_iter = 100, σmin = 1e-6,)`.
143144
144145
The algorithm stops either when `√(ξₖ/νₖ) < atol + rtol*√(ξ₀/ν₀)`, or when `ξₖ < 0` and `√(-ξₖ/νₖ) < neg_tol`, where ξₖ := f(xₖ) + h(xₖ) - φ(sₖ; xₖ) - ψ(sₖ; xₖ), and √(ξₖ/νₖ) is a stationarity measure.
145146
@@ -202,6 +203,7 @@ function SolverCore.solve!(
202203
η2::T = T(0.9),
203204
γ::T = T(3),
204205
θ::T = 1/(1 + eps(T)^(1 / 5)),
206+
sub_kwargs::NamedTuple = NamedTuple(),
205207
) where {T, V, G}
206208
reset!(stats)
207209

@@ -334,15 +336,16 @@ function SolverCore.solve!(
334336
solver.subpb.model.σ = σk
335337
isa(solver.subsolver, R2DHSolver) && (solver.subsolver.D.d[1] = 1/ν)
336338
if isa(solver.subsolver, R2Solver) #FIXME
337-
solve!(solver.subsolver, solver.subpb, solver.substats, x = s, atol = sub_atol, ν = ν)
339+
solve!(solver.subsolver, solver.subpb, solver.substats; x = s, atol = sub_atol, ν = ν, sub_kwargs...)
338340
else
339341
solve!(
340342
solver.subsolver,
341343
solver.subpb,
342-
solver.substats,
344+
solver.substats;
343345
x = s,
344346
atol = sub_atol,
345347
σk = σk, #FIXME
348+
sub_kwargs...,
346349
)
347350
end
348351

0 commit comments

Comments
 (0)