Skip to content

Commit 352ee45

Browse files
author
Chris Geoga
committed
Tests for besseli+Enzyme that cover the new convergence checker.
1 parent 77a860d commit 352ee45

File tree

5 files changed

+674
-11
lines changed

5 files changed

+674
-11
lines changed

ext/BesselsEnzymeCoreExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@ module BesselsEnzymeCoreExt
2121
function EnzymeRules.forward(func::Const{typeof(check_convergence)},
2222
::Type{Const{Bool}},
2323
t::Duplicated{T}) where{T}
24-
(t.val <= eps(T)) && (t.dval <= eps(T))
24+
check_convergence(t.val) && check_convergence(t.dval)
2525
end
2626

2727
function EnzymeRules.forward(func::Const{typeof(check_convergence)},
2828
::Type{Const{Bool}},
2929
t::Duplicated{T},
3030
s::Duplicated{T}) where{T}
31-
(t.val <= eps(T)*s.val) && (t.dval <= eps(T)*s.dval)
31+
check_convergence(t.val, s.val) && check_convergence(t.dval, s.val)
3232
end
3333

3434
# This will be fixed upstream: see #861 for Enzyme.jl whenever the next

test/besseli_enzyme_test.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using EnzymeCore, Enzyme
2+
3+
dbesseli_dv(v, x) = autodiff(Forward, _v->besseli(_v, x),
4+
Duplicated, Duplicated(v, 1.0))[2]
5+
6+
dbesseli_dx(v, x) = autodiff(Forward, _x->besseli(v, _x),
7+
Duplicated, Duplicated(x, 1.0))[2]
8+
9+
10+
for line in eachline("data/besseli/enzyme/besseli_enzyme_tests.csv")
11+
(v, x, dv, dx) = parse.(Float64, split(line))
12+
test_dv = dbesseli_dv(v, x)
13+
test_dx = dbesseli_dx(v, x)
14+
# TODO (cg 2023/05/30 12:09): temporarily test at lower rtols in the v=0.001
15+
# case. The power series code tests for convergence scaled by the series
16+
# value itself, which costs a little bit of rtol. When x is about 20, we
17+
# just barely hit that edge regime where we switch to the large argument
18+
# expansion and that edge zone also costs a digit. Those are things to
19+
# discuss addressing in a different PR I think.
20+
if v == 0.001 && x < 21.0
21+
@test isapprox(dv, test_dv, rtol=1e-11)
22+
@test isapprox(dx, test_dx, rtol=1e-11)
23+
else
24+
@test isapprox(dv, test_dv, rtol=5e-14)
25+
@test isapprox(dx, test_dx, rtol=5e-14)
26+
end
27+
end

test/data/besseli/enzyme/README.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
```julia
2+
using ArbNumerics, DelimitedFiles, SpecialFunctions, Enzyme, Bessels
3+
4+
# Because you can't just use FiniteDifferences due to the "numerical noise".
5+
simplefd(f,x,h=ArbReal(1e-40)) = (f(x+h)-f(x))/h
6+
7+
ArbNumerics.setprecision(ArbReal, digits=100)
8+
arb_besseli(v,x) = ArbNumerics.besseli(ArbReal(v), ArbReal(x))
9+
10+
# TODO (cg 2023/05/30 11:44): temporarily, we make vgrid[end]=15.1 instead of
11+
# 15.0 or some other integer because the integer-order `besseli` derivatives
12+
# need to be fixed with dispatch, like `besselk`.
13+
vgrid = range(1e-3, 15.1, length=20)
14+
xgrid = range(1e-3, 100.0, length=30)
15+
vx = vec(collect(Iterators.product(vgrid, xgrid)))
16+
17+
our_dx(v,x) = autodiff(Forward, _x->Bessels.besseli(v, _x), Duplicated, Duplicated(x, 1.0))[2]
18+
our_dv(v,x) = autodiff(Forward, _v->Bessels.besseli(_v, x), Duplicated, Duplicated(v, 1.0))[2]
19+
20+
if !isinteractive()
21+
ref_values = map(vx) do vxj
22+
(v,x) = vxj
23+
dx = simplefd(_x->arb_besseli(v, _x), x)
24+
dv = simplefd(_v->arb_besseli(_v, x), v)
25+
Float64.((dv, dx))
26+
end
27+
28+
out_matrix = hcat(getindex.(vx, 1), # test v argument
29+
getindex.(vx, 2), # test x argument
30+
getindex.(ref_values, 1), # test d/dv value
31+
getindex.(ref_values, 2)) # test d/dx value
32+
33+
writedlm("besseli_enzyme_tests.csv", out_matrix)
34+
end
35+
```

0 commit comments

Comments
 (0)