Skip to content

Commit 570940c

Browse files
authored
close #446; address type instability in init_options (#447)
* close #446; address type instability in init_options * second check
1 parent a2e58ee commit 570940c

File tree

6 files changed

+72
-9
lines changed

6 files changed

+72
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Roots"
22
uuid = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
3-
version = "2.2.0"
3+
version = "2.2.1"
44

55
[deps]
66
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

src/Bracketing/alefeld_potra_shi.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,22 @@ function init_state(::AbstractAlefeldPotraShi, F, x₀, x₁, fx₀, fx₁; c=no
8181
AbstractAlefeldPotraShiState(promote(b, a, d, ee)..., promote(fb, fa, fd, fe)...)
8282
end
8383

84+
# avoid type-stability issue due to dynamic dispatch based on kwargs
85+
function init_options(
86+
M::AbstractAlefeldPotraShi,
87+
state::AbstractUnivariateZeroState{T,S};
88+
kwargs...,
89+
) where {T,S}
90+
d = kwargs
91+
defs = default_tolerances(M, T, S)
92+
δₐ = get(d, :xatol, get(d, :xabstol, defs[1]))
93+
δᵣ = get(d, :xrtol, get(d, :xreltol, defs[2]))
94+
maxiters = get(d, :maxiters, get(d, :maxevals, get(d, :maxsteps, defs[5])))
95+
strict = get(d, :strict, defs[6])
96+
Roots.FExactOptions(δₐ, δᵣ, maxiters, strict)
97+
end
98+
99+
84100
# fn calls w/in calculateΔ
85101
# 1 is default, but this should be adjusted for different methods
86102
fncalls_per_step(::AbstractAlefeldPotraShi) = 1

src/Bracketing/bracketing.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ end
3232
function default_tolerances(::AbstractBracketingMethod, ::Type{T}, ::Type{S}) where {T,S}
3333
xatol = eps(real(T))^3 * oneunit(real(T))
3434
xrtol = eps(real(T)) # unitless
35-
atol = 0 * oneunit(real(S))
36-
rtol = 0 * one(real(S))
35+
atol = zero(oneunit(real(S)))
36+
rtol = zero(one(real(S)))
3737
maxevals = 60
3838
strict = true
3939
(xatol, xrtol, atol, rtol, maxevals, strict)

src/Bracketing/itp.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,25 @@ function init_state(M::ITP, F, x₀, x₁, fx₀, fx₁)
6767
ITPState(promote(b, a)..., promote(fb, fa)..., 0, ϵ2n₁₂, a)
6868
end
6969

70+
function init_options(
71+
M::ITP,
72+
state::AbstractUnivariateZeroState{T,S};
73+
kwargs...,
74+
) where {T,S}
75+
76+
d = kwargs
77+
defs = default_tolerances(M, T, S)
78+
δₐ = get(d, :xatol, get(d, :xabstol, defs[1]))
79+
δᵣ = get(d, :xrtol, get(d, :xreltol, defs[2]))
80+
ϵₐ = get(d, :atol, get(d, :abstol, defs[3]))
81+
ϵᵣ = get(d, :rtol, get(d, :reltol, defs[4]))
82+
maxiters = get(d, :maxiters, get(d, :maxevals, get(d, :maxsteps, defs[5])))
83+
strict = get(d, :strict, defs[6])
84+
85+
return UnivariateZeroOptions(δₐ, δᵣ, ϵₐ, ϵᵣ, maxiters, strict)
86+
end
87+
88+
7089
function update_state(M::ITP, F, o::ITPState{T,S,R}, options, l=NullTracks()) where {T,S,R}
7190
a, b = o.xn0, o.xn1
7291
fa, fb = o.fxn0, o.fxn1

src/convergence.jl

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,24 +36,45 @@ init_options(
3636
kwargs...,
3737
) where {T,S} = init_options(M, T, S; kwargs...)
3838

39+
# this function is an issue (#446) it is type unstable.
40+
# this is a fall back now, but in #446 more
41+
# specific choices based on M are made.
3942
function init_options(M, T=Float64, S=Float64; kwargs...)
4043
d = kwargs
41-
4244
defs = default_tolerances(M, T, S)
4345
δₐ = get(d, :xatol, get(d, :xabstol, defs[1]))
4446
δᵣ = get(d, :xrtol, get(d, :xreltol, defs[2]))
4547
ϵₐ = get(d, :atol, get(d, :abstol, defs[3]))
4648
ϵᵣ = get(d, :rtol, get(d, :reltol, defs[4]))
47-
M = get(d, :maxiters, get(d, :maxevals, get(d, :maxsteps, defs[5])))
49+
maxiters = get(d, :maxiters, get(d, :maxevals, get(d, :maxsteps, defs[5])))
4850
strict = get(d, :strict, defs[6])
4951

50-
iszero(δₐ) && iszero(δᵣ) && iszero(ϵₐ) && iszero(ϵᵣ) && return ExactOptions(M, strict)
51-
iszero(δₐ) && iszero(δᵣ) && return XExactOptions(ϵₐ, ϵᵣ, M, strict)
52-
iszero(ϵₐ) && iszero(ϵᵣ) && return FExactOptions(δₐ, δᵣ, M, strict)
52+
iszero(δₐ) && iszero(δᵣ) && iszero(ϵₐ) && iszero(ϵᵣ) && return ExactOptions(maxiters, strict)
53+
iszero(δₐ) && iszero(δᵣ) && return XExactOptions(ϵₐ, ϵᵣ, maxiters, strict)
54+
iszero(ϵₐ) && iszero(ϵᵣ) && return FExactOptions(δₐ, δᵣ, maxiters, strict)
5355

54-
return UnivariateZeroOptions(δₐ, δᵣ, ϵₐ, ϵᵣ, M, strict)
56+
return UnivariateZeroOptions(δₐ, δᵣ, ϵₐ, ϵᵣ, maxiters, strict)
5557
end
5658

59+
function init_options(
60+
M::AbstractNonBracketingMethod,
61+
state::AbstractUnivariateZeroState{T,S};
62+
kwargs...,
63+
) where {T,S}
64+
65+
d = kwargs
66+
defs = default_tolerances(M, T, S)
67+
δₐ = get(d, :xatol, get(d, :xabstol, defs[1]))
68+
δᵣ = get(d, :xrtol, get(d, :xreltol, defs[2]))
69+
ϵₐ = get(d, :atol, get(d, :abstol, defs[3]))
70+
ϵᵣ = get(d, :rtol, get(d, :reltol, defs[4]))
71+
maxiters = get(d, :maxiters, get(d, :maxevals, get(d, :maxsteps, defs[5])))
72+
strict = get(d, :strict, defs[6])
73+
74+
return UnivariateZeroOptions(δₐ, δᵣ, ϵₐ, ϵᵣ, maxiters, strict)
75+
end
76+
77+
5778
## --------------------------------------------------
5879

5980
"""
@@ -309,6 +330,10 @@ function decide_convergence(
309330
#_is_f_approx_0(fxn1, xn1, options.abstol, options.reltol) && return xn1
310331
else
311332
if val == :x_converged
333+
# The XExact case isn't always spelled out in the type, so
334+
# we replicate a bit here
335+
δ, ϵ = options.abstol, options.reltol
336+
iszero(δ) && iszero(ϵ) && return xn1
312337
is_approx_zero_f(M, state, options, true) && return xn1
313338
elseif val == :not_converged
314339
# this is the case where runaway can happen

test/test_allocations.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import BenchmarkTools
44
@testset "solve: zero allocations" begin
55
fs = (sin, cos, x -> -sin(x))
66
x0 = (3, 4)
7+
x0′ = big.(x0)
78
Ms = (
89
Order0(),
910
Order1(),
@@ -23,9 +24,11 @@ import BenchmarkTools
2324
Ns = (Roots.Newton(), Roots.Halley(), Roots.Schroder())
2425
for M in Ms
2526
@test BenchmarkTools.@ballocated(solve(ZeroProblem($fs, $x0), $M)) == 0
27+
@inferred solve(ZeroProblem(fs, x0′), M)
2628
end
2729
for M in Ns
2830
@test BenchmarkTools.@ballocated(solve(ZeroProblem($fs, $x0), $M)) == 0
31+
@inferred solve(ZeroProblem(fs, x0′), M)
2932
end
3033

3134
# Allocations in Lith

0 commit comments

Comments
 (0)