Skip to content

Commit 03b070f

Browse files
author
Chris Geoga
committed
Demo of convergence check with custom Enzyme method.
1 parent 9257c5f commit 03b070f

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

ext/BesselsEnzymeCoreExt.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module BesselsEnzymeCoreExt
33
using Bessels, EnzymeCore
44
using EnzymeCore.EnzymeRules
55
using Bessels.Math
6+
import Bessels.Math: check_convergence
67

78
# A manual method that separately transforms the `val` and `dval`, because
89
# sometimes the `val` can converge while the `dval` hasn't, so just using an
@@ -17,18 +18,29 @@ module BesselsEnzymeCoreExt
1718
Duplicated(ls, dls)
1819
end
1920

20-
# This is fixing a straight bug in Enzyme.
21+
function EnzymeRules.forward(func::Const{typeof(check_convergence)},
22+
::Type{Const{Bool}},
23+
t::Duplicated{T}) where{T}
24+
(t.val <= eps(T)) && (t.dval <= eps(T))
25+
end
26+
27+
# This will be fixed upstream: see #861 for Enzyme.jl whenever the next
28+
# release occurs.
2129
function EnzymeRules.forward(func::Const{typeof(sinpi)},
2230
::Type{<:Duplicated},
2331
x::Duplicated)
2432
(sp, cp) = sincospi(x.val)
2533
Duplicated(sp, pi*cp*x.dval)
2634
end
2735

36+
# #861 will probably also mean this can be deleted at the next release of
37+
# Enzyme.jl.
2838
function EnzymeRules.forward(func::Const{typeof(sinpi)},
2939
::Type{<:Const},
3040
x::Const)
3141
sinpi(x.val)
3242
end
3343

44+
45+
3446
end

src/BesselFunctions/besseli.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,7 @@ function besselix_large_args(v, x::ComplexOrReal{T}) where T
602602
for i in 1:MaxIter
603603
t *= -invx * ((4*v^2 - (2i - 1)^2) / i)
604604
s += t
605-
abs(t) <= eps(T) && break
605+
Math.check_convergence(t) && break
606606
end
607607
return s / sqrt(2 ** x))
608608
end

src/Math/Math.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,10 @@ end
154154
)
155155
end
156156

157-
# TODO (cg 2023/05/16 18:09): dispute this cutoff.
158157
isnearint(x) = abs(x-round(x)) < 1e-5
159158

159+
@inline check_convergence(term::T) where T = abs(term) <= eps(T)
160+
161+
@inline check_convergence(ser::T, term::T) where T = (abs(term) <= eps(T)*ser)
162+
160163
end

0 commit comments

Comments
 (0)