Skip to content

Commit fbd0ad4

Browse files
authored
fix issue with non type-stable f (#400)
* fix issue with non type-stable f * cleaner approach? * extension and 1.6
1 parent 51271a6 commit fbd0ad4

File tree

5 files changed

+23
-15
lines changed

5 files changed

+23
-15
lines changed

Project.toml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Roots"
22
uuid = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
3-
version = "2.0.19"
3+
version = "2.0.20"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -19,7 +19,6 @@ SymPyPythonCall = "bc8888f7-b21e-4b7c-a06a-5d9c9496438c"
1919
ChainRulesCore = "1"
2020
CommonSolve = "0.1, 0.2"
2121
ForwardDiff = "0.10"
22-
IntervalRootFinding = "0.5"
2322
SymPy = "1"
2423
SymPyPythonCall = "0.1,1"
2524
Setfield = "0.7, 0.8, 1"
@@ -37,7 +36,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3736
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3837
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3938
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
40-
IntervalRootFinding = "d2bf35a9-74e0-55ec-b149-d360ff49b807"
39+
#IntervalRootFinding = "d2bf35a9-74e0-55ec-b149-d360ff49b807"
4140
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
4241
Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
4342
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
@@ -49,4 +48,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
4948
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5049

5150
[targets]
52-
test = ["Aqua", "ChainRulesTestUtils", "JSON", "SpecialFunctions", "Statistics", "Test", "BenchmarkTools", "ForwardDiff", "Polynomials", "Unitful", "Zygote", "IntervalRootFinding"]
51+
test = ["Aqua", "ChainRulesTestUtils", "JSON", "SpecialFunctions", "Statistics", "Test", "BenchmarkTools", "ForwardDiff", "Polynomials", "Unitful", "Zygote"] #, "IntervalRootFinding"]

src/Bracketing/bisection.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ function init_state(
4444
if x₀ > x₁
4545
x₀, x₁, fx₀, fx₁ = x₁, x₀, fx₁, fx₀
4646
end
47-
4847
# handle interval if fa*fb ≥ 0 (explicit, but also not needed)
4948
(iszero(fx₀) || iszero(fx₁)) &&
5049
return UnivariateZeroState(promote(x₁, x₀)..., promote(fx₁, fx₀)...)
@@ -57,7 +56,6 @@ function init_state(
5756

5857
# handles case where a=-0.0, b=1.0 without error
5958
sign(a) * sign(b) < 0 && throw(ArgumentError("_middle error"))
60-
6159
UnivariateZeroState(promote(b, a)..., promote(fb, fa)...)
6260
end
6361

@@ -180,10 +178,10 @@ function solve!(
180178
val, stopped = :not_converged, false
181179
ctr = 1
182180
log_step(l, M, state; init=true)
183-
181+
T,S = TS(state)
184182
while !stopped
185-
a, b = state.xn0, state.xn1
186-
fa, fb = state.fxn0, state.fxn1
183+
a::T, b::T = state.xn0, state.xn1
184+
fa::S, fb::S = state.fxn0, state.fxn1
187185

188186
## assess_convergence
189187
if nextfloat(a) b

src/state.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ struct UnivariateZeroState{T,S} <: AbstractUnivariateZeroState{T,S}
66
fxn0::S
77
end
88

9+
TS(::AbstractUnivariateZeroState{T,S}) where {T,S} = T,S
10+
911
# simple helper to set main properties of a state object
1012
function _set(state, xf1)
1113
x, fx = xf1

test/test_extensions.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,20 @@ end
2020

2121
using ForwardDiff
2222
@testset "ForwardDiff" begin
23-
f(x, p) = x^2 - p
23+
f = (x, p) -> x^2 - p
2424
Z = ZeroProblem(f, (0, 1000))
25-
F(p) = solve(Z, Roots.Bisection(), p)
25+
F = p -> solve(Z, Roots.Bisection(), p)
2626
for p (3,5,7,11)
2727
@test F(p) sqrt(p)
2828
@test ForwardDiff.derivative(F, p) 1 / (2sqrt(p))
2929
end
3030

3131
# Hessian is *broken*
32-
f(x, p) = x^2 - sum(p.^2)
32+
f = (x, p) -> x^2 - sum(p.^2)
3333
Z = ZeroProblem(f, (0, 1000))
34-
F(p) = solve(Z, Roots.Bisection(), p)
34+
F = p -> solve(Z, Roots.Bisection(), p)
3535
Z = ZeroProblem(f, (0, 1000))
36-
F(p) = solve(Z, Roots.Bisection(), p)
36+
F = p -> solve(Z, Roots.Bisection(), p)
3737
hess(f, p) = ForwardDiff.jacobian(p -> ForwardDiff.gradient(F, p), p)
3838
for p ([1,2], [1,3], [1,4])
3939
@test F(p) sqrt(sum(p.^2))
@@ -42,9 +42,9 @@ using ForwardDiff
4242
n = sqrt(a^2 + b^2)^3
4343
@test hess(F, p) [b^2 -a*b; -a*b a^2]/n
4444
end
45-
4645
end
4746

47+
#=
4848
using IntervalRootFinding
4949
@testset "IntervalRootFinding" begin
5050
f(x) = sin(x + sin(x + sin(x)))
@@ -53,3 +53,4 @@ using IntervalRootFinding
5353
@test sort(out.zeros) ≈ sort([-pi,0,pi])
5454
@test isempty(out.unknown)
5555
end
56+
=#

test/test_find_zero.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,14 @@ struct Order3_Test <: Roots.AbstractSecantMethod end
118118
@test @inferred(find_zero(sin, SomeInterval(3, 4))) pi
119119
@test @inferred(find_zero(sin, range(3, stop=4, length=20))) pi
120120
end
121+
122+
# test issue when non type stalbe
123+
h(x) = x < 2000 ? -1000 : -1000 + 0.1 * (x - 2000)
124+
a, b, xᵅ = 0, 20_000, 12_000
125+
for M bracketing_meths
126+
@test find_zero(h, (a,b), M) xᵅ
127+
end
128+
121129
end
122130

123131
@testset "non simple zeros" begin

0 commit comments

Comments
 (0)