diff --git a/lib/BracketingNonlinearSolve/src/itp.jl b/lib/BracketingNonlinearSolve/src/itp.jl index 153981faf..cbf5818bf 100644 --- a/lib/BracketingNonlinearSolve/src/itp.jl +++ b/lib/BracketingNonlinearSolve/src/itp.jl @@ -49,7 +49,7 @@ end function ITP(; scaled_k1::Real = 0.2, k2::Real = 2, n0::Int = 10) scaled_k1 < 0 && error("Hyper-parameter κ₁ should not be negative") n0 < 0 && error("Hyper-parameter n₀ should not be negative") - if k2 < 1 || k2 > (1.5 + sqrt(5) / 2) + if !(1 <= k2 <= 1.5 + sqrt(5) / 2) throw(ArgumentError("Hyper-parameter κ₂ should be between 1 and 1 + ϕ where \ ϕ ≈ 1.618... is the golden ratio")) end @@ -63,7 +63,7 @@ function CommonSolve.solve( @assert !SciMLBase.isinplace(prob) "`ITP` only supports out-of-place problems." f = Base.Fix2(prob.f, prob.p) - left, right = prob.tspan + left, right = minmax(promote(prob.tspan...)...) fl, fr = f(left), f(right) abstol = NonlinearSolveBase.get_tolerance( @@ -93,45 +93,34 @@ function CommonSolve.solve( ϵ = abstol k2 = alg.k2 - k1 = alg.scaled_k1 * abs(right - left)^(1 - k2) + span = right - left + k1 = alg.scaled_k1 * span^(1 - k2) # k1 > 0 n0 = alg.n0 - mid = (left + right) / 2 - x_f = left + (right - left) * (fl / (fl - fr)) - xt = left - xp = left - r = zero(left) # minmax radius - δ = zero(left) # truncation error - σ = one(mid) - n_h = exponent(abs(right - left) / (2 * ϵ)) + n_h = exponent(span / (2 * ϵ)) ϵ_s = ϵ * exp2(n_h + n0) + T0 = zero(fl) i = 1 while i ≤ maxiters - span = abs(right - left) + span = right - left + mid = (left + right) / 2 r = ϵ_s - (span / 2) - δ = k1 * span^k2 - x_f = left + (right - left) * (fl / (fl - fr)) # Interpolation Step + x_f = left + span * (fl / (fl - fr)) # Interpolation Step + δ = max(k1 * span^k2, eps(x_f)) diff = mid - x_f - σ = sign(diff) - xt = ifelse(δ ≤ diff, x_f + σ * δ, mid) # Truncation Step - xp = ifelse(abs(xt - mid) ≤ r, xt, mid - σ * r) # Projection Step + xt = ifelse(δ ≤ abs(diff), x_f + copysign(δ, diff), mid) # Truncation Step - if abs((left - right) / 2) < ϵ + xp = ifelse(abs(xt - mid) ≤ r, xt, mid - copysign(r, diff)) # Projection Step + if span < 2ϵ return SciMLBase.build_solution( prob, alg, xt, f(xt); retcode = ReturnCode.Success, left, right ) end - - # update - tmin, tmax = minmax(xt, xp) - xp ≥ tmax && (xp = prevfloat(tmax)) - xp ≤ tmin && (xp = nextfloat(tmin)) yp = f(xp) yps = yp * sign(fr) - T0 = zero(yps) if yps > T0 right, fr = xp, yp elseif yps < T0 @@ -143,10 +132,9 @@ function CommonSolve.solve( end i += 1 - mid = (left + right) / 2 ϵ_s /= 2 - if Impl.nextfloat_tdir(left, prob.tspan...) == right + if nextfloat(left) == right return SciMLBase.build_solution( prob, alg, right, fr; retcode = ReturnCode.FloatingPointLimit, left, right ) diff --git a/lib/BracketingNonlinearSolve/test/rootfind_tests.jl b/lib/BracketingNonlinearSolve/test/rootfind_tests.jl index 6a490d6fa..9161bdfe6 100644 --- a/lib/BracketingNonlinearSolve/test/rootfind_tests.jl +++ b/lib/BracketingNonlinearSolve/test/rootfind_tests.jl @@ -53,7 +53,7 @@ end ϵ = eps(Float64) # least possible tol for all methods @testset for alg in (Bisection(), Falsi(), ITP(), nothing) - @testset for abstol in [0.1, 0.01, 0.001, 0.0001, 1e-5, 1e-6, 1e-7] + @testset for abstol in [0.1, 0.01, 0.001, 0.0001, 1e-5, 1e-6] sol = solve(prob, alg; abstol) result_tol = abs(sol.u - sqrt(2)) @test result_tol < abstol