Skip to content

Commit 75d01ee

Browse files
authored
Merge pull request #105 from cgeoga/convg_check
`EnzymeRule` for checking convergence in series-type code
2 parents f6d2701 + c8750f3 commit 75d01ee

File tree

9 files changed

+690
-13
lines changed

9 files changed

+690
-13
lines changed

ext/BesselsEnzymeCoreExt.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module BesselsEnzymeCoreExt
33
using Bessels, EnzymeCore
44
using EnzymeCore.EnzymeRules
55
using Bessels.Math
6+
import Bessels.Math: check_convergence
67

78
# A manual method that separately transforms the `val` and `dval`, because
89
# sometimes the `val` can converge while the `dval` hasn't, so just using an
@@ -17,18 +18,36 @@ module BesselsEnzymeCoreExt
1718
Duplicated(ls, dls)
1819
end
1920

20-
# This is fixing a straight bug in Enzyme.
21+
function EnzymeRules.forward(func::Const{typeof(check_convergence)},
22+
::Type{Const{Bool}},
23+
t::Duplicated{T}) where{T}
24+
check_convergence(t.val) && check_convergence(t.dval)
25+
end
26+
27+
function EnzymeRules.forward(func::Const{typeof(check_convergence)},
28+
::Type{Const{Bool}},
29+
t::Duplicated{T},
30+
s::Duplicated{T}) where{T}
31+
check_convergence(t.val, s.val) && check_convergence(t.dval, s.val)
32+
end
33+
34+
# This will be fixed upstream: see #861 for Enzyme.jl whenever the next
35+
# release occurs.
2136
function EnzymeRules.forward(func::Const{typeof(sinpi)},
2237
::Type{<:Duplicated},
2338
x::Duplicated)
2439
(sp, cp) = sincospi(x.val)
2540
Duplicated(sp, pi*cp*x.dval)
2641
end
2742

43+
# #861 will probably also mean this can be deleted at the next release of
44+
# Enzyme.jl.
2845
function EnzymeRules.forward(func::Const{typeof(sinpi)},
2946
::Type{<:Const},
3047
x::Const)
3148
sinpi(x.val)
3249
end
3350

51+
52+
3453
end

src/BesselFunctions/besseli.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,6 @@ _besseli(nu::Union{Int16, Float16}, x::Union{Int16, Float16}) = Float16(_besseli
414414
_besseli(nu::AbstractRange, x::T) where T = besseli!(zeros(T, length(nu)), nu, x)
415415

416416
function _besseli(nu::T, x::T) where T <: Union{Float32, Float64}
417-
isinteger(nu) && return _besseli(Int(nu), x)
418417
~isfinite(x) && return x
419418
abs_nu = abs(nu)
420419
abs_x = abs(x)
@@ -602,7 +601,7 @@ function besselix_large_args(v, x::ComplexOrReal{T}) where T
602601
for i in 1:MaxIter
603602
t *= -invx * ((4*v^2 - (2i - 1)^2) / i)
604603
s += t
605-
abs(t) <= eps(T) && break
604+
Math.check_convergence(t) && break
606605
end
607606
return s / sqrt(2 ** x))
608607
end
@@ -624,7 +623,7 @@ function besseli_power_series(v, x::ComplexOrReal{T}) where T
624623
xx = x * x * T(0.25)
625624
for i in 0:MaxIter
626625
s += t
627-
abs(t) < eps(T) * abs(s) && break
626+
Math.check_convergence(t, s) && break
628627
t *= xx / ((v + i + 1) * (i + 1))
629628
end
630629
return s * ((x/2)^v / gamma(v + 1))

src/BesselFunctions/besselk.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ function besselkx_large_args(v, x::ComplexOrReal{T}) where T
458458
for i in 1:MaxIter
459459
t *= invx * ((4*v^2 - (2i - 1)^2) / i)
460460
s += t
461-
abs(t) <= eps(T) && break
461+
Math.check_convergence(t) && break
462462
end
463463
return s * sqrt/ (2 * x))
464464
end
@@ -552,7 +552,7 @@ function besselk_power_series(v, x::ComplexOrReal{T}) where T
552552
s2 += t2
553553
t1 *= x^2 / (4k * (k - v))
554554
t2 *= x^2 / (4k * (k + v))
555-
abs(t1) < eps(T) && break
555+
Math.check_convergence(t1) && break
556556
end
557557

558558
xpv = (x/2)^v
@@ -590,8 +590,7 @@ function besselk_temme_series(v::T, x::T) where T <: Float64
590590
term_vp1 = ck * (pk - (k-1) * fk)
591591
out_v += term_v
592592
out_vp1 += term_vp1
593-
((abs(term_v) < eps(T)) && (abs(term_vp1) < eps(T))) && break
594-
593+
(Math.check_convergence(term_v) && Math.check_convergence(term_vp1)) && break
595594
fk = (k * fk + pk + qk) / (k^2 - v^2)
596595
pk /= k - v
597596
qk /= k + v

src/Math/Math.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,10 @@ end
154154
)
155155
end
156156

157-
# TODO (cg 2023/05/16 18:09): dispute this cutoff.
158157
isnearint(x) = abs(x-round(x)) < 1e-5
159158

159+
@inline check_convergence(term::T) where T = abs(term) <= eps(T)
160+
161+
@inline check_convergence(ser::T, term::T) where T = (abs(term) <= eps(T)*ser)
162+
160163
end

test/besseli_enzyme_test.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using EnzymeCore, Enzyme
2+
3+
dbesseli_dv(v, x) = autodiff(Forward, _v->besseli(_v, x),
4+
Duplicated, Duplicated(v, 1.0))[2]
5+
6+
dbesseli_dx(v, x) = autodiff(Forward, _x->besseli(v, _x),
7+
Duplicated, Duplicated(x, 1.0))[2]
8+
9+
10+
for line in eachline("data/besseli/enzyme/besseli_enzyme_tests.csv")
11+
(v, x, dv, dx) = parse.(Float64, split(line))
12+
test_dv = dbesseli_dv(v, x)
13+
test_dx = dbesseli_dx(v, x)
14+
# TODO (cg 2023/05/30 12:09): temporarily test at lower rtols in the v=0.001
15+
# case. The power series code tests for convergence scaled by the series
16+
# value itself, which costs a little bit of rtol. When x is about 20, we
17+
# just barely hit that edge regime where we switch to the large argument
18+
# expansion and that edge zone also costs a digit. Those are things to
19+
# discuss addressing in a different PR I think.
20+
if v == 0.001 && x < 21.0
21+
@test isapprox(dv, test_dv, rtol=1e-11)
22+
@test isapprox(dx, test_dx, rtol=1e-11)
23+
else
24+
@test isapprox(dv, test_dv, rtol=5e-14)
25+
@test isapprox(dx, test_dx, rtol=5e-14)
26+
end
27+
end

test/besseli_test.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,8 @@ end
131131
(v, x) = 12.0, 3.2
132132
@test besseli(v,x) 7.1455266650203694069897133431e-7
133133

134-
(v,x) = 13.0, -1.0
134+
# Negative arguments only allowed with v::Int!
135+
(v,x) = 13, -1.0
135136
@test besseli(v,x) -1.995631678207200756444e-14
136137

137138
(v,x) = 12.6, -3.0
@@ -146,8 +147,7 @@ end
146147
(v, x) = -12.3, 8.2
147148
@test besseli(v,x) 0.267079696793126091886043602895
148149

149-
(v, x) = -14.0, -9.9
150+
# Negative arguments only allowed with v::Int!
151+
(v, x) = -14, -9.9
150152
@test besseli(v,x) 0.2892290867115615816280234648
151153

152-
(v, x) = -14.6, -10.6
153-
#@test besseli(v,x) ≈ -0.157582642056898598750175404443 - 0.484989503203097528858271270828*im

test/data/besseli/enzyme/README.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
```julia
2+
using ArbNumerics, DelimitedFiles, SpecialFunctions
3+
4+
# Because you can't just use FiniteDifferences due to the "numerical noise".
5+
simplefd(f,x,h=ArbReal(1e-40)) = (f(x+h)-f(x))/h
6+
7+
ArbNumerics.setprecision(ArbReal, digits=100)
8+
arb_besseli(v,x) = ArbNumerics.besseli(ArbReal(v), ArbReal(x))
9+
10+
vgrid = range(1e-3, 15.0, length=20)
11+
xgrid = range(1e-3, 100.0, length=30)
12+
vx = vec(collect(Iterators.product(vgrid, xgrid)))
13+
14+
if !isinteractive()
15+
ref_values = map(vx) do vxj
16+
(v,x) = vxj
17+
dx = simplefd(_x->arb_besseli(v, _x), x)
18+
dv = simplefd(_v->arb_besseli(_v, x), v)
19+
Float64.((dv, dx))
20+
end
21+
22+
out_matrix = hcat(getindex.(vx, 1), # test v argument
23+
getindex.(vx, 2), # test x argument
24+
getindex.(ref_values, 1), # test d/dv value
25+
getindex.(ref_values, 2)) # test d/dx value
26+
27+
writedlm("besseli_enzyme_tests.csv", out_matrix)
28+
end
29+
```

0 commit comments

Comments
 (0)