Skip to content

Commit 458aae8

Browse files
improve type stability in LM
1 parent 02b6cce commit 458aae8

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

src/LM_alg.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ mutable struct LMSolver{
99
T <: Real,
1010
G <: ShiftedProximableFunction,
1111
V <: AbstractVector{T},
12-
M <: AbstractLinearOperator{T},
1312
ST <: AbstractOptimizationSolver,
1413
PB <: AbstractRegularizedNLPModel,
1514
} <: AbstractOptimizationSolver
@@ -18,7 +17,6 @@ mutable struct LMSolver{
1817
mν∇fk::V
1918
Fk::V
2019
Fkn::V
21-
Jk::M
2220
ψ::G
2321
xkn::V
2422
s::V
@@ -45,7 +43,6 @@ function LMSolver(
4543
mν∇fk = similar(x0)
4644
Fk = similar(x0, reg_nls.model.nls_meta.nequ)
4745
Fkn = similar(Fk)
48-
Jk = jac_op_residual(reg_nls.model, xk)
4946
xkn = similar(x0)
5047
s = similar(x0)
5148
has_bnds = any(l_bound .!= T(-Inf)) || any(u_bound .!= T(Inf)) || subsolver == TRDHSolver
@@ -63,18 +60,24 @@ function LMSolver(
6360
has_bnds ? shifted(reg_nls.h, xk, l_bound_m_x, u_bound_m_x, reg_nls.selected) :
6461
shifted(reg_nls.h, xk)
6562

66-
sub_nlp = LMModel(Jk, Fk, T(1), x0)
63+
jprod! = let nls = reg_nls.model
64+
(x, v, Jv) -> jprod_residual!(nls, x, v, Jv)
65+
end
66+
jt_prod! = let nls = reg_nls.model
67+
(x, v, Jtv) -> jtprod_residual!(nls, x, v, Jtv)
68+
end
69+
70+
sub_nlp = LMModel(jprod!, jt_prod!, Fk, T(1), xk)
6771
subpb = RegularizedNLPModel(sub_nlp, ψ)
6872
substats = RegularizedExecutionStats(subpb)
6973
subsolver = subsolver(subpb)
7074

71-
return LMSolver(
75+
return LMSolver{T, typeof(ψ), V, typeof(subsolver), typeof(subpb)}(
7276
xk,
7377
∇fk,
7478
mν∇fk,
7579
Fk,
7680
Fkn,
77-
Jk,
7881
ψ,
7982
xkn,
8083
s,
@@ -124,13 +127,11 @@ function SolverCore.solve!(
124127

125128
Fk = solver.Fk
126129
Fkn = solver.Fkn
127-
Jk = solver.Jk
128130
∇fk = solver.∇fk
129131
mν∇fk = solver.mν∇fk
130132
ψ = solver.ψ
131133
xkn = solver.xkn
132134
s = solver.s
133-
134135
has_bnds = solver.has_bnds
135136
if has_bnds
136137
l_bound = solver.l_bound
@@ -144,7 +145,7 @@ function SolverCore.solve!(
144145
hk = @views h(xk[selected])
145146
if hk == Inf
146147
verbose > 0 && @info "LM: finding initial guess where nonsmooth term is finite"
147-
prox!(xk, h, xk, one(eltype(x0)))
148+
prox!(xk, h, xk, one(T))
148149
hk = @views h(xk[selected])
149150
hk < Inf || error("prox computation must be erroneous")
150151
verbose > 0 && @debug "LM: found point where h has value" hk
@@ -174,11 +175,12 @@ function SolverCore.solve!(
174175
local ρk::T = zero(T)
175176

176177
residual!(nls, xk, Fk)
177-
Jk = jac_op_residual(nls, xk)
178+
#solver.subpb.model.J = jac_op_residual!(nls, xk, Jv, Jtv)
178179
jtprod_residual!(nls, xk, Fk, ∇fk)
179180
fk = dot(Fk, Fk) / 2
180181

181-
σmax, found_σ = opnorm(Jk)
182+
#σmax, found_σ = opnorm(solver.subpb.model.J)
183+
σmax, found_σ = one(T), true
182184
found_σ || error("operator norm computation failed")
183185
ν = θ / (σmax^2 + σk) # ‖J'J + σₖ I‖ = ‖J‖² + σₖ
184186
sqrt_ξ1_νInv = one(T)
@@ -306,12 +308,13 @@ function SolverCore.solve!(
306308

307309
# update gradient & Hessian
308310
shift!(ψ, xk)
309-
Jk = jac_op_residual(nls, xk)
311+
#solver.subpb.model.J = jac_op_residual!(nls, xk, Jv, Jtv)
310312
jtprod_residual!(nls, xk, Fk, ∇fk)
311313

312314
# update opnorm if not linear least squares
313315
if nonlinear == true
314-
σmax, found_σ = opnorm(Jk)
316+
#σmax, found_σ = opnorm(solver.subpb.model.J)
317+
σmax, found_σ = one(T), true
315318
found_σ || error("operator norm computation failed")
316319
end
317320
end

0 commit comments

Comments
 (0)