Skip to content

Commit 8ee551f

Browse files
author
oscarddssmith
committed
more ITP simplification
1 parent 5e83e35 commit 8ee551f

File tree

1 file changed

+20
-25
lines changed
  • lib/BracketingNonlinearSolve/src

1 file changed

+20
-25
lines changed

lib/BracketingNonlinearSolve/src/itp.jl

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -49,21 +49,21 @@ end
4949
function ITP(; scaled_k1::Real = 0.2, k2::Real = 2, n0::Int = 10)
5050
scaled_k1 < 0 && error("Hyper-parameter κ₁ should not be negative")
5151
n0 < 0 && error("Hyper-parameter n₀ should not be negative")
52-
if k2 < 1 || k2 > (1.5 + sqrt(5) / 2)
52+
if !(1 <= k2 <= 1.5 + sqrt(5) / 2)
5353
throw(ArgumentError("Hyper-parameter κ₂ should be between 1 and 1 + ϕ where \
5454
ϕ ≈ 1.618... is the golden ratio"))
5555
end
5656
return ITP(scaled_k1, k2, n0)
5757
end
5858

59-
function CommonSolve.solve(
59+
@muladd function CommonSolve.solve(
6060
prob::IntervalNonlinearProblem, alg::ITP, args...;
6161
maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs...
6262
)
6363
@assert !SciMLBase.isinplace(prob) "`ITP` only supports out-of-place problems."
6464

6565
f = Base.Fix2(prob.f, prob.p)
66-
left, right = prob.tspan
66+
left, right = minmax(prob.tspan)
6767
fl, fr = f(left), f(right)
6868

6969
abstol = NonlinearSolveBase.get_tolerance(
@@ -93,42 +93,38 @@ function CommonSolve.solve(
9393

9494
ϵ = abstol
9595
k2 = alg.k2
96-
k1 = alg.scaled_k1 * abs(right - left)^(1 - k2)
96+
span = right - left
97+
k1 = alg.scaled_k1 * span^(1 - k2) # k1 > 0
9798
n0 = alg.n0
98-
n_h = ceil(log2(abs(right - left) / (2 * ϵ)))
99-
mid = (left + right) / 2
100-
x_f = left + (right - left) * (fl / (fl - fr))
101-
xt = left
102-
xp = left
103-
r = zero(left) # minmax radius
104-
δ = zero(left) # truncation error
105-
σ = 1.0
106-
ϵ_s = ϵ * 2^(n_h + n0)
99+
n_h = exponent(span / (2 * ϵ)) + 1
100+
ϵ_s = ϵ * exp2(n_h + n0)
107101

108102
i = 1
109103
while i maxiters
110-
span = abs(right - left)
104+
span = right - left
105+
mid = (left + right) / 2
111106
r = ϵ_s - (span / 2)
112107
δ = k1 * span^k2
113108

114-
x_f = left + (right - left) * (fl / (fl - fr)) # Interpolation Step
109+
x_f = left + span * (fl / (fl - fr)) # Interpolation Step
115110

116-
diff = mid - x_f
117-
σ = sign(diff)
118-
xt = ifelse diff, x_f + σ * δ, mid) # Truncation Step
111+
diff = abs(mid - x_f)
112+
xt = ifelse diff, x_f + copysign(δ, diff), mid) # Truncation Step
119113

120-
xp = ifelse(abs(xt - mid) r, xt, mid - σ * r) # Projection Step
114+
xp = ifelse(abs(xt - mid) r, xt, mid - copysign(r, diff)) # Projection Step
121115

122-
if abs((left - right) / 2) < ϵ
116+
if span < 2ϵ
123117
return SciMLBase.build_solution(
124118
prob, alg, xt, f(xt); retcode = ReturnCode.Success, left, right
125119
)
126120
end
127121

128122
# update
129-
tmin, tmax = minmax(xt, xp)
130-
xp tmax && (xp = prevfloat(tmax))
131-
xp tmin && (xp = nextfloat(tmin))
123+
if isless(xt, xp)
124+
xp = prevloat(xp)
125+
else
126+
xp = nextfloat(xp)
127+
end
132128
yp = f(xp)
133129
yps = yp * sign(fr)
134130
T0 = zero(yps)
@@ -143,10 +139,9 @@ function CommonSolve.solve(
143139
end
144140

145141
i += 1
146-
mid = (left + right) / 2
147142
ϵ_s /= 2
148143

149-
if Impl.nextfloat_tdir(left, prob.tspan...) == right
144+
if nextfloat(left) == right
150145
return SciMLBase.build_solution(
151146
prob, alg, right, fr; retcode = ReturnCode.FloatingPointLimit, left, right
152147
)

0 commit comments

Comments
 (0)