Skip to content

Commit a465c7c

Browse files
author
oscarddssmith
committed
more ITP simplification
1 parent 14499b2 commit a465c7c

File tree

1 file changed

+24
-33
lines changed
  • lib/BracketingNonlinearSolve/src

1 file changed

+24
-33
lines changed

lib/BracketingNonlinearSolve/src/itp.jl

Lines changed: 24 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ end
4949
function ITP(; scaled_k1::Real = 0.2, k2::Real = 2, n0::Int = 10)
5050
scaled_k1 < 0 && error("Hyper-parameter κ₁ should not be negative")
5151
n0 < 0 && error("Hyper-parameter n₀ should not be negative")
52-
if k2 < 1 || k2 > (1.5 + sqrt(5) / 2)
52+
if !(1 <= k2 <= 1.5 + sqrt(5) / 2)
5353
throw(ArgumentError("Hyper-parameter κ₂ should be between 1 and 1 + ϕ where \
5454
ϕ ≈ 1.618... is the golden ratio"))
5555
end
@@ -63,22 +63,23 @@ function CommonSolve.solve(
6363
@assert !SciMLBase.isinplace(prob) "`ITP` only supports out-of-place problems."
6464

6565
f = Base.Fix2(prob.f, prob.p)
66-
left, right = prob.tspan
66+
left, right = minmax(promote(prob.tspan...)...)
6767
fl, fr = f(left), f(right)
6868

6969
abstol = NonlinearSolveBase.get_tolerance(
7070
left, abstol, promote_type(eltype(left), eltype(right))
7171
)
7272

73+
stats = SciMLBase.NLStats(2,0,0,0,0)
7374
if iszero(fl)
7475
return SciMLBase.build_solution(
75-
prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, left, right
76+
prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, left, right, stats
7677
)
7778
end
7879

7980
if iszero(fr)
8081
return SciMLBase.build_solution(
81-
prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight, left, right
82+
prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight, left, right, stats
8283
)
8384
end
8485

@@ -87,73 +88,63 @@ function CommonSolve.solve(
8788
@warn "The interval is not an enclosing interval, opposite signs at the \
8889
boundaries are required."
8990
return SciMLBase.build_solution(
90-
prob, alg, left, fl; retcode = ReturnCode.InitialFailure, left, right
91+
prob, alg, left, fl; retcode = ReturnCode.InitialFailure, left, right, stats
9192
)
9293
end
9394

9495
ϵ = abstol
9596
k2 = alg.k2
96-
k1 = alg.scaled_k1 * abs(right - left)^(1 - k2)
97+
span = right - left
98+
k1 = alg.scaled_k1 * span^(1 - k2) # k1 > 0
9799
n0 = alg.n0
98-
mid = (left + right) / 2
99-
x_f = left + (right - left) * (fl / (fl - fr))
100-
xt = left
101-
xp = left
102-
r = zero(left) # minmax radius
103-
δ = zero(left) # truncation error
104-
σ = one(mid)
105-
n_h = exponent(abs(right - left) / (2 * ϵ))
100+
n_h = exponent(span / (2 * ϵ))
106101
ϵ_s = ϵ * exp2(n_h + n0)
102+
T0 = zero(fl)
107103

108104
i = 1
109105
while i maxiters
110-
span = abs(right - left)
106+
stats.nsteps += 1
107+
span = right - left
108+
mid = (left + right) / 2
111109
r = ϵ_s - (span / 2)
112-
δ = k1 * span^k2
113110

114-
x_f = left + (right - left) * (fl / (fl - fr)) # Interpolation Step
111+
x_f = left + span * (fl / (fl - fr)) # Interpolation Step
115112

113+
δ = max(k1 * span^k2, eps(x_f))
116114
diff = mid - x_f
117-
σ = sign(diff)
118-
xt = ifelse diff, x_f + σ * δ, mid) # Truncation Step
119115

120-
xp = ifelse(abs(xt - mid) r, xt, mid - σ * r) # Projection Step
116+
xt = ifelse(δ abs(diff), x_f + copysign(δ, diff), mid) # Truncation Step
121117

122-
if abs((left - right) / 2) < ϵ
118+
xp = ifelse(abs(xt - mid) r, xt, mid - copysign(r, diff)) # Projection Step
119+
if span < 2ϵ
123120
return SciMLBase.build_solution(
124-
prob, alg, xt, f(xt); retcode = ReturnCode.Success, left, right
121+
prob, alg, xt, f(xt); retcode = ReturnCode.Success, left, right, stats
125122
)
126123
end
127-
128-
# update
129-
tmin, tmax = minmax(xt, xp)
130-
xp tmax && (xp = prevfloat(tmax))
131-
xp tmin && (xp = nextfloat(tmin))
132124
yp = f(xp)
125+
stats.nf += 1
133126
yps = yp * sign(fr)
134-
T0 = zero(yps)
135127
if yps > T0
136128
right, fr = xp, yp
137129
elseif yps < T0
138130
left, fl = xp, yp
139131
else
140132
return SciMLBase.build_solution(
141-
prob, alg, xp, yps; retcode = ReturnCode.Success, left, right
133+
prob, alg, xp, yps; retcode = ReturnCode.Success, left, right, stats
142134
)
143135
end
144136

145137
i += 1
146-
mid = (left + right) / 2
147138
ϵ_s /= 2
148139

149-
if Impl.nextfloat_tdir(left, prob.tspan...) == right
140+
if nextfloat(left) == right
150141
return SciMLBase.build_solution(
151-
prob, alg, right, fr; retcode = ReturnCode.FloatingPointLimit, left, right
142+
prob, alg, right, fr; retcode = ReturnCode.FloatingPointLimit, left, right, stats
152143
)
153144
end
154145
end
155146

156147
return SciMLBase.build_solution(
157-
prob, alg, left, fl; retcode = ReturnCode.MaxIters, left, right
148+
prob, alg, left, fl; retcode = ReturnCode.MaxIters, left, right, stats
158149
)
159150
end

0 commit comments

Comments
 (0)