Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
15 changes: 9 additions & 6 deletions src/trust-region/tron-trust-region.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,10 @@ function aredpred!(
Δm::T,
x_trial::V,
step::V,
slope::T,
slope::T;
kwargs...,
) where {T, V}
ared, pred, tr.good_grad = aredpred_common(nlp, f, f_trial, Δm, x_trial, step, tr.gt, slope)
ared, pred, tr.good_grad = aredpred_common(nlp, f, f_trial, Δm, x_trial, step, tr.gt, slope; kwargs...)
γ = f_trial - f - slope
tr.quad_min = γ <= 0 ? tr.increase_factor : max(tr.large_decrease_factor, -slope / γ / 2)
return ared, pred
Expand All @@ -108,10 +109,11 @@ function aredpred!(
Δm::T,
x_trial::V,
step::V,
slope::T,
slope::T;
kwargs...
) where {T, V}
Fx = similar(x_trial, nls.nls_meta.nequ)
return aredpred!(tr, nls, Fx, f, f_trial, Δm, x_trial, step, slope)
return aredpred!(tr, nls, Fx, f, f_trial, Δm, x_trial, step, slope; kwargs...)
end

function aredpred!(
Expand All @@ -123,9 +125,10 @@ function aredpred!(
Δm::T,
x_trial::V,
step::V,
slope::T,
slope::T;
kwargs...
) where {T, V}
ared, pred, tr.good_grad = aredpred_common(nls, Fx, f, f_trial, Δm, x_trial, step, tr.gt, slope)
ared, pred, tr.good_grad = aredpred_common(nls, Fx, f, f_trial, Δm, x_trial, step, tr.gt, slope; kwargs...)
γ = f_trial - f - slope
tr.quad_min = γ <= 0 ? tr.increase_factor : max(tr.large_decrease_factor, -slope / γ / 2)
return ared, pred
Expand Down
27 changes: 18 additions & 9 deletions src/trust-region/trust-region.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ function aredpred_common(
x_trial::V,
step::V,
g_trial::V,
slope::T,
slope::T;
use_only_objgrad::Bool = false,
) where {T, V}
absf = abs(f)
ϵ = eps(T)
Expand All @@ -54,7 +55,11 @@ function aredpred_common(
ared = f_trial - f + max(one(T), absf) * 10 * ϵ
if (abs(Δm) < 10_000 * ϵ) || (abs(ared) < 10_000 * ϵ * absf)
# correct for roundoff error
grad!(nlp, x_trial, g_trial)
if use_only_objgrad
objgrad!(nlp, x_trial, g_trial)
else
grad!(nlp, x_trial, g_trial)
end
good_grad = true
slope_trial = dot(g_trial, step)
ared = (slope_trial + slope) / 2
Expand All @@ -71,7 +76,8 @@ function aredpred_common(
x_trial::V,
step::V,
g_trial::V,
slope::T,
slope::T;
kwargs...,
) where {T, V}
absf = abs(f)
ϵ = eps(T)
Expand Down Expand Up @@ -110,9 +116,10 @@ function aredpred!(
Δm::T,
x_trial::V,
step::V,
slope::T,
slope::T;
kwargs...
) where {T, V}
ared, pred, tr.good_grad = aredpred_common(nlp, f, f_trial, Δm, x_trial, step, tr.gt, slope)
ared, pred, tr.good_grad = aredpred_common(nlp, f, f_trial, Δm, x_trial, step, tr.gt, slope; kwargs...)
return ared, pred
end

Expand All @@ -124,10 +131,11 @@ function aredpred!(
Δm::T,
x_trial::V,
step::V,
slope::T,
slope::T;
kwargs...
) where {T, V}
Fx = similar(x_trial, nls.nls_meta.nequ)
ared, pred, tr.good_grad = aredpred_common(nls, Fx, f, f_trial, Δm, x_trial, step, tr.gt, slope)
ared, pred, tr.good_grad = aredpred_common(nls, Fx, f, f_trial, Δm, x_trial, step, tr.gt, slope; kwargs...)
return ared, pred
end

Expand All @@ -140,9 +148,10 @@ function aredpred!(
Δm::T,
x_trial::V,
step::V,
slope::T,
slope::T;
kwargs...
) where {T, V}
ared, pred, tr.good_grad = aredpred_common(nls, Fx, f, f_trial, Δm, x_trial, step, tr.gt, slope)
ared, pred, tr.good_grad = aredpred_common(nls, Fx, f, f_trial, Δm, x_trial, step, tr.gt, slope; kwargs...)
return ared, pred
end

Expand Down
Loading