Skip to content

Commit 1a72f34

Browse files
authored
Add float (#464)
* version bump * add float for conversion at spots
1 parent d8c67cd commit 1a72f34

File tree

8 files changed

+47
-15
lines changed

8 files changed

+47
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Roots"
22
uuid = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
3-
version = "2.2.5"
3+
version = "2.2.7"
44

55
[deps]
66
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

src/Bracketing/alefeld_potra_shi.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ function init_state(::AbstractAlefeldPotraShi, F, x₀, x₁, fx₀, fx₁; c=no
6666
end
6767

6868
if c === nothing # need c, fc to be defined if one is
69-
c = a < zero(a) < b ? _middle(a, b) : secant_step(a, b, fa, fb)
69+
c = float(a < zero(a) < b ? _middle(a, b) : secant_step(a, b, fa, fb))
7070
fc = first(F(c))
7171
end
7272

@@ -111,7 +111,7 @@ function update_state(
111111
l=NullTracks(),
112112
) where {T,S}
113113
atol, rtol = options.xabstol, options.xreltol
114-
μ, λ = oftype(rtol, 0.5), oftype(rtol, 0.7)
114+
μ, λ = oftype(float(rtol), 0.5), oftype(float(rtol), 0.7)
115115
tols = (; λ=λ, atol=atol, rtol=rtol)
116116

117117
a::T, b::T, d::T, ee::T = o.xn0, o.xn1, o.d, o.ee
@@ -209,7 +209,7 @@ struct A2425{K} <: AbstractAlefeldPotraShi end
209209
function calculateΔ(::A2425{K}, F::Callable_Function, c₀::T, ps) where {K,T}
210210
a, b, d, ee = ps.a, ps.b, ps.d, ps.ee
211211
fa, fb, fd, fee = ps.fa, ps.fb, ps.fd, ps.fee
212-
tols ==oftype(ps.rtol, 0.7), atol=ps.atol, rtol=ps.rtol)
212+
tols ==oftype(float(ps.rtol), 0.7), atol=ps.atol, rtol=ps.rtol)
213213

214214
c = a
215215
for k in 1:K
@@ -258,7 +258,7 @@ fncalls_per_step(::A57{K}) where {K} = K - 1
258258
function calculateΔ(::A57{K}, F::Callable_Function, c₀::T, ps) where {K,T}
259259
a, b, d, ee = ps.a, ps.b, ps.d, ps.ee
260260
fa, fb, fd, fee = ps.fa, ps.fb, ps.fd, ps.fee
261-
tols ==oftype(ps.rtol, 0.7), atol=ps.atol, rtol=ps.rtol)
261+
tols ==oftype(float(ps.rtol), 0.7), atol=ps.atol, rtol=ps.rtol)
262262
c, fc = a, fa
263263

264264
for k in 1:K

src/Bracketing/brent.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ function update_state(
4848
fa, fb, fc = state.fxn0, state.fxn1, state.fc
4949

5050
# next step depends on points; inverse quadratic
51-
s::T = inverse_quadratic_step(a, b, c, fa, fb, fc)
52-
(isnan(s) || isinf(s)) && (s = secant_step(a, b, fa, fb))
51+
s = float(inverse_quadratic_step(a, b, c, fa, fb, fc))
52+
(isnan(s) || isinf(s)) && (s = float(secant_step(a, b, fa, fb)))
5353

5454
# guard step
5555
u, v = (3a + b) / 4, b

src/Derivative/lith.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ function init_state(
237237
ys₀,
238238
) where {S,D,R,T}
239239
xs, ys = init_lith(L, F, x₁, fx₁, x₀, fx₀, ys₀) # [x₀,x₁,…,xₛ₋₁], ...
240+
240241
# skip unit consideration here, as won't fit within storage of ys
241242
state = LithBoonkkampIJzermanState{S,D + 1,R,T}(
242243
xs[end], # xₙ
@@ -703,17 +704,20 @@ function lmm(::LithBoonkkampIJzerman{2,0}, xs, fs)
703704
x0, x1 = xs
704705
f0, f1 = fs
705706

706-
(f0 * x1 - f1 * x0) / (f0 - f1)
707+
(f0 * x1 - f1 * x0) / (f0 - f1) |> float
707708
end
708709

709710
function lmm(::LithBoonkkampIJzerman{3,0}, xs, fs)
711+
xs, fs
710712
x0, x1, x2 = xs
711713
f0, f1, f2 = fs
712714

713715
(
714716
f0^2 * f1 * x2 - f0^2 * f2 * x1 - f0 * f1^2 * x2 + f0 * f2^2 * x1 + f1^2 * f2 * x0 -
715717
f1 * f2^2 * x0
716-
) / (f0^2 * f1 - f0^2 * f2 - f0 * f1^2 + f0 * f2^2 + f1^2 * f2 - f1 * f2^2)
718+
) / (
719+
f0^2 * f1 - f0^2 * f2 - f0 * f1^2 + f0 * f2^2 + f1^2 * f2 - f1 * f2^2
720+
) |> float
717721
end
718722

719723
function lmm(::LithBoonkkampIJzerman{4,0}, xs, fs)
@@ -745,7 +749,7 @@ function lmm(::LithBoonkkampIJzerman{4,0}, xs, fs)
745749
f0 * f2^3 * f3^2 - f0 * f2^2 * f3^3 - f1^3 * f2^2 * f3 +
746750
f1^3 * f2 * f3^2 +
747751
f1^2 * f2^3 * f3 - f1^2 * f2 * f3^3 - f1 * f2^3 * f3^2 + f1 * f2^2 * f3^3
748-
)
752+
) |> float
749753
end
750754

751755
function lmm(::LithBoonkkampIJzerman{5,0}, xs, fs)
@@ -901,7 +905,7 @@ function lmm(::LithBoonkkampIJzerman{5,0}, xs, fs)
901905
f1 * f2^4 * f3^2 * f4^3 +
902906
f1 * f2^3 * f3^4 * f4^2 - f1 * f2^3 * f3^2 * f4^4 - f1 * f2^2 * f3^4 * f4^3 +
903907
f1 * f2^2 * f3^3 * f4^4
904-
)
908+
) |> float
905909
end
906910

907911
function lmm(::LithBoonkkampIJzerman{6,0}, xs, fs)

src/convergence.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,13 @@ function is_small_Δx(
235235
return δ Δₓ
236236
end
237237

238+
239+
isnan_x(M::AbstractBracketingMethod, state) = isnan(state.xn1) || isnan(state.xn0)
240+
isnan_x(M::AbstractNonBracketingMethod, state) = isnan(state.xn1)
241+
242+
isinf_x(M::AbstractBracketingMethod, state) = isinf(state.xn1) || isinf(state.xn0)
243+
isinf_x(M::AbstractNonBracketingMethod, state) = isinf(state.xn1)
244+
238245
isnan_f(M::AbstractBracketingMethod, state) = isnan(state.fxn1) || isnan(state.fxn0)
239246
isnan_f(M::AbstractNonBracketingMethod, state) = isnan(state.fxn1)
240247

@@ -271,6 +278,8 @@ In `decide_convergence`, stopped values (and `:x_converged` when `strict=false`)
271278
function assess_convergence(M::Any, state::AbstractUnivariateZeroState, options)
272279
# return convergence_flag, boolean
273280
is_exact_zero_f(M, state, options) && return (:exact_zero, true)
281+
isnan_x(M, state) && return (:nan, true)
282+
isinf_x(M, state) && return (:inf, true)
274283
isnan_f(M, state) && return (:nan, true)
275284
isinf_f(M, state) && return (:inf, true)
276285
is_approx_zero_f(M, state, options) && return (:f_converged, true)
@@ -330,7 +339,23 @@ function decide_convergence(
330339
val (:f_converged, :exact_zero, :converged) && return xn1
331340

332341
## XXX this could be problematic
333-
val == :nan && return xn1
342+
if val == :nan
343+
# return if Δx small
344+
Δₓ = abs(xn1 - xn0)
345+
δₐ, δᵣ = options.xabstol, options.xreltol
346+
u = min(abs(xn0), abs(xn1))
347+
δₓ = max(δₐ, 2 * abs(u) * δᵣ) # needs non-zero δₐ to stop near 0
348+
Δₓ δₓ && return xn1
349+
350+
# or if abs(fxn0) small
351+
ϵₐ, ϵᵣ = options.abstol, options.reltol
352+
Δ = max(_unitless(ϵₐ), _unitless(xn0) * ϵᵣ)
353+
abs(state.fxn0) Δ * oneunit(state.fxn0) && return xn0
354+
355+
# else
356+
return nan(T) * xn1
357+
# return xn1
358+
end
334359
val == :inf_nan && return xn1
335360

336361
## stopping is a heuristic, x_converged can mask issues

src/find_zeros.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ Base.show(io::IO, alpha::Interval) = print(io, "($(alpha.a), $(alpha.b))")
6868

6969
# check if f(a) is non zero using tolerances max(atol, eps()), rtol
7070
function _non_zero(fa, a::T, atol, rtol) where {T}
71-
abs(fa) >= max(atol, abs(a) * rtol * oneunit(fa) / oneunit(a), oneunit(fa) * eps(T))
71+
a, r = atol, abs(a) * rtol * oneunit(fa) / oneunit(a), oneunit(fa) * eps(T)
72+
return abs(fa) >= max(promote(a,r)...)
7273
end
7374

7475
# After splitting by zeros we have intervals (zm, zn) this is used to shrink

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,12 @@ function quad_vertex(c, fc, b, fb, a, fa)
111111
end
112112

113113
## inverse quadratic
114-
function inverse_quadratic_step(a::T, b, c, fa, fb, fc) where {T}
114+
function inverse_quadratic_step(a::T, b, c, fa::S, fb, fc) where {T,S}
115115
s = zero(T)
116116
s += a * fb * fc / (fa - fb) / (fa - fc) # quad step
117117
s += b * fa * fc / (fb - fa) / (fb - fc)
118118
s += c * fa * fb / (fc - fa) / (fc - fb)
119-
s
119+
float(s)
120120
end
121121

122122
## Different functions for approximating f'(xn)

test/test_find_zero.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,12 @@ struct Order3_Test <: Roots.AbstractSecantMethod end
3636
@test find_zero(sin, 3.0, m) pi
3737
@test find_zero(sin, big(3), m) pi
3838
@test find_zero(sin, big(3.0), m) pi
39+
@test find_zero(sin, π, m) pi
3940
@test find_zero(x -> x^2 - 2.0f0, 2.0f0, m) sqrt(2) # issue 421
4041
@test isnan(solve(ZeroProblem(x -> x^2 + 2, 0.5f0)))
4142
end
4243

44+
4345
## defaults for method argument
4446
@test find_zero(sin, 3.0) pi # order0()
4547
@test @inferred(find_zero(sin, (3, 4))) π # Bisection()

0 commit comments

Comments
 (0)