Skip to content

Commit 0e687a1

Browse files
author
Chris Geoga
committed
Checkpoint commit with small tweaks. Still chasing down a small issue
that makes the AD just a bit less accurate than it should be.
1 parent 3c30394 commit 0e687a1

File tree

5 files changed

+60
-8
lines changed

5 files changed

+60
-8
lines changed

ext/BesselsEnzymeCoreExt.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,17 @@ module BesselsEnzymeCoreExt
3535
Duplicated(ls, dls)
3636
end
3737

38+
# This is fixing a straight bug in Enzyme.
39+
function EnzymeRules.forward(func::Const{typeof(sinpi)},
40+
::Type{<:Duplicated},
41+
x::Duplicated)
42+
Duplicated(sinpi(x.val), pi*cospi(x.val))
43+
end
44+
45+
function EnzymeRules.forward(func::Const{typeof(sinpi)},
46+
::Type{<:Const},
47+
x::Const)
48+
sinpi(x.val)
49+
end
50+
3851
end

src/BesselFunctions/besselk.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -654,16 +654,17 @@ function besselk_power_series_temme_basal(v::V, x::Float64) where{V}
654654
end
655655

656656
function besselk_power_series_int(v, x::Float64)
657-
v < zero(v) && return besselk_power_series_int(-v, x)
658-
flv = Int(floor(v))
659-
_v = v - flv
657+
v = abs(v)
658+
(_v, flv) = modf(v)
659+
if _v > 1/2
660+
(_v, flv) = (_v-one(_v), flv+1)
661+
end
660662
(kv, kvp1) = besselk_power_series_temme_basal(_v, x)
661-
abs(v) < 1/2 && return kv
662663
twodx = 2/x
663-
for _ in 1:(flv-1)
664+
for _ in 1:flv
664665
_v += 1
665666
(kv, kvp1) = (kvp1, muladd(twodx*_v, kvp1, kv))
666667
end
667-
kvp1
668+
kv
668669
end
669670

src/Math/Math.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,6 @@ end
155155
end
156156

157157
# TODO (cg 2023/05/16 18:09): dispute this cutoff.
158-
isnearint(x) = abs(x-round(x)) < 1e-7
158+
isnearint(x) = abs(x-round(x)) < 1e-5
159159

160160
end

test/besselk_enzyme_test.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
using EnzymeCore, Enzyme
2+
import Bessels.BesselFunctions: besselkx_levin
3+
import Bessels.BesselFunctions: besselk_power_series
4+
5+
dbesselkx_dv(v, x) = autodiff(Forward, _v->besselkx_levin(_v, x, Val(30)),
6+
Duplicated, Duplicated(v, 1.0))[2]
7+
8+
dbesselkx_dx(v, x) = autodiff(Forward, _x->besselkx_levin(v, _x, Val(30)),
9+
Duplicated, Duplicated(x, 1.0))[2]
10+
11+
#=
12+
dbesselk_ps_dv(v, x) = autodiff(Forward, _v->besselk_power_series(_v, x),
13+
Duplicated, Duplicated(v, 1.0))[2]
14+
15+
dbesselk_ps_dx(v, x) = autodiff(Forward, _x->besselk_power_series(v, _x),
16+
Duplicated, Duplicated(x, 1.0))[2]
17+
=#
18+
19+
20+
for line in eachline("data/besselk/enzyme/besselkx_levin_enzyme_tests.csv")
21+
(v, x, dv, dx) = parse.(Float64, split(line))
22+
test_dv = dbesselkx_dv(v, x)
23+
test_dx = dbesselkx_dx(v, x)
24+
@test isapprox(dv, test_dv, rtol=5e-14)
25+
@test isapprox(dx, test_dx, rtol=5e-14)
26+
end
27+
28+
#=
29+
for line in eachline("data/besselk/enzyme/besselk_power_series_enzyme_tests.csv")
30+
(v, x, dv, dx) = parse.(Float64, split(line))
31+
test_dv = dbesselk_ps_dv(v, x)
32+
test_dx = dbesselk_ps_dx(v, x)
33+
#@test isapprox(dv, test_dv, rtol=5e-14)
34+
#@test isapprox(dx, test_dx, rtol=5e-14)
35+
36+
end
37+
=#
38+

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@ import SpecialFunctions
1010
@time @testset "gamma" begin include("gamma_test.jl") end
1111
@time @testset "airy" begin include("airy_test.jl") end
1212
@time @testset "sphericalbessel" begin include("sphericalbessel_test.jl") end
13-
@time @testset "enzyme autodiff" begin include("enzyme_test.jl") end
13+
@time @testset "besselk enzyme autodiff" begin include("besselk_enzyme_test.jl") end

0 commit comments

Comments
 (0)