Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 14 additions & 26 deletions lib/BracketingNonlinearSolve/src/itp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ end
function ITP(; scaled_k1::Real = 0.2, k2::Real = 2, n0::Int = 10)
scaled_k1 < 0 && error("Hyper-parameter κ₁ should not be negative")
n0 < 0 && error("Hyper-parameter n₀ should not be negative")
if k2 < 1 || k2 > (1.5 + sqrt(5) / 2)
if !(1 <= k2 <= 1.5 + sqrt(5) / 2)
throw(ArgumentError("Hyper-parameter κ₂ should be between 1 and 1 + ϕ where \
ϕ ≈ 1.618... is the golden ratio"))
end
Expand All @@ -63,7 +63,7 @@ function CommonSolve.solve(
@assert !SciMLBase.isinplace(prob) "`ITP` only supports out-of-place problems."

f = Base.Fix2(prob.f, prob.p)
left, right = prob.tspan
left, right = minmax(promote(prob.tspan...)...)
fl, fr = f(left), f(right)

abstol = NonlinearSolveBase.get_tolerance(
Expand Down Expand Up @@ -93,45 +93,34 @@ function CommonSolve.solve(

ϵ = abstol
k2 = alg.k2
k1 = alg.scaled_k1 * abs(right - left)^(1 - k2)
span = right - left
k1 = alg.scaled_k1 * span^(1 - k2) # k1 > 0
n0 = alg.n0
mid = (left + right) / 2
x_f = left + (right - left) * (fl / (fl - fr))
xt = left
xp = left
r = zero(left) # minmax radius
δ = zero(left) # truncation error
σ = one(mid)
n_h = exponent(abs(right - left) / (2 * ϵ))
n_h = exponent(span / (2 * ϵ))
ϵ_s = ϵ * exp2(n_h + n0)
T0 = zero(fl)

i = 1
while i ≤ maxiters
span = abs(right - left)
span = right - left
mid = (left + right) / 2
r = ϵ_s - (span / 2)
δ = k1 * span^k2

x_f = left + (right - left) * (fl / (fl - fr)) # Interpolation Step
x_f = left + span * (fl / (fl - fr)) # Interpolation Step

δ = max(k1 * span^k2, eps(x_f))
diff = mid - x_f
σ = sign(diff)
xt = ifelse(δ ≤ diff, x_f + σ * δ, mid) # Truncation Step

xp = ifelse(abs(xt - mid) ≤ r, xt, mid - σ * r) # Projection Step
xt = ifelse(δ ≤ abs(diff), x_f + copysign(δ, diff), mid) # Truncation Step

if abs((left - right) / 2) < ϵ
xp = ifelse(abs(xt - mid) ≤ r, xt, mid - copysign(r, diff)) # Projection Step
if span < 2ϵ
return SciMLBase.build_solution(
prob, alg, xt, f(xt); retcode = ReturnCode.Success, left, right
)
end

# update
tmin, tmax = minmax(xt, xp)
xp ≥ tmax && (xp = prevfloat(tmax))
xp ≤ tmin && (xp = nextfloat(tmin))
yp = f(xp)
yps = yp * sign(fr)
T0 = zero(yps)
if yps > T0
right, fr = xp, yp
elseif yps < T0
Expand All @@ -143,10 +132,9 @@ function CommonSolve.solve(
end

i += 1
mid = (left + right) / 2
ϵ_s /= 2

if Impl.nextfloat_tdir(left, prob.tspan...) == right
if nextfloat(left) == right
return SciMLBase.build_solution(
prob, alg, right, fr; retcode = ReturnCode.FloatingPointLimit, left, right
)
Expand Down
2 changes: 1 addition & 1 deletion lib/BracketingNonlinearSolve/test/rootfind_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ end
ϵ = eps(Float64) # least possible tol for all methods

@testset for alg in (Bisection(), Falsi(), ITP(), nothing)
@testset for abstol in [0.1, 0.01, 0.001, 0.0001, 1e-5, 1e-6, 1e-7]
@testset for abstol in [0.1, 0.01, 0.001, 0.0001, 1e-5, 1e-6]
sol = solve(prob, alg; abstol)
result_tol = abs(sol.u - sqrt(2))
@test result_tol < abstol
Expand Down
Loading