Skip to content

Commit caf2c2c

Browse files
authored
Merge pull request #97 from cgeoga/enzyme
Add `EnzymeCore` weakdep and an extension with a custom rule for the Levin transformation
2 parents e48c6b4 + 27d257f commit caf2c2c

File tree

15 files changed

+1127
-117
lines changed

15 files changed

+1127
-117
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
matrix:
2222
version:
2323
- '1'
24-
- '1.8'
24+
- '1.9'
2525
- 'nightly'
2626
os:
2727
- ubuntu-latest

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MIT License
22

3-
Copyright (c) 2021-2022 Michael Helton, Oscar Smith, and contributors
3+
Copyright (c) 2021-2023 Michael Helton, Oscar Smith, Chris Geoga, and contributors
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,17 @@ version = "0.3.0-DEV"
66
SIMDMath = "5443be0b-e40a-4f70-a07e-dcd652efc383"
77

88
[compat]
9-
julia = "1.8"
109
SIMDMath = "0.2.5"
10+
julia = "1.9"
1111

1212
[extras]
1313
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1414

15+
[weakdeps]
16+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
17+
18+
[extensions]
19+
BesselsEnzymeCoreExt = "EnzymeCore"
20+
1521
[targets]
1622
test = ["Test"]

ext/BesselsEnzymeCoreExt.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
module BesselsEnzymeCoreExt
2+
3+
using Bessels, EnzymeCore
4+
using EnzymeCore.EnzymeRules
5+
using Bessels.Math
6+
7+
# A manual method that separately transforms the `val` and `dval`, because
8+
# sometimes the `val` can converge while the `dval` hasn't, so just using an
9+
# early return or something can give incorrect derivatives in edge cases.
10+
function EnzymeRules.forward(func::Const{typeof(levin_transform)},
11+
::Type{<:Duplicated},
12+
s::Duplicated,
13+
w::Duplicated)
14+
(sv, dv, N) = (s.val, s.dval, length(s.val))
15+
ls = levin_transform(sv, w.val)
16+
dls = levin_transform(dv, w.dval)
17+
Duplicated(ls, dls)
18+
end
19+
20+
# This custom rule works around a bug in Enzyme's forward-mode handling of `sinpi`.
21+
function EnzymeRules.forward(func::Const{typeof(sinpi)},
22+
::Type{<:Duplicated},
23+
x::Duplicated)
24+
(sp, cp) = sincospi(x.val)
25+
Duplicated(sp, pi*cp*x.dval)
26+
end
27+
28+
function EnzymeRules.forward(func::Const{typeof(sinpi)},
29+
::Type{<:Const},
30+
x::Const)
31+
sinpi(x.val)
32+
end
33+
34+
end

src/BesselFunctions/besselk.jl

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,7 @@ besselk_power_series(v, x::Float32) = Float32(besselk_power_series(v, Float64(x)
499499
besselk_power_series(v, x::ComplexF32) = ComplexF32(besselk_power_series(v, ComplexF64(x)))
500500

501501
function besselk_power_series(v, x::ComplexOrReal{T}) where T
502+
Math.isnearint(v) && return besselk_power_series_int(v, x)
502503
MaxIter = 1000
503504
S = eltype(x)
504505
v, x = S(v), S(x)
@@ -512,7 +513,8 @@ function besselk_power_series(v, x::ComplexOrReal{T}) where T
512513
# use the reflection identity to calculate gamma(-v)
513514
# use relation gamma(v)*v = gamma(v+1) to avoid two gamma calls
514515
gam_v = gamma(v)
515-
gam_nv = π / (sinpi(-abs(v)) * gam_v * v)
516+
#gam_nv = π / (sin(-pi*abs(v)) * gam_v * v) # not using sinpi here to avoid Enzyme bug
517+
gam_nv = π / (sinpi(-abs(v)) * gam_v * v) # sinpi is used here; the EnzymeCore extension supplies a custom forward rule for it
516518
gam_1mv = -gam_nv * v
517519
gam_1mnv = gam_v * v
518520

@@ -578,15 +580,16 @@ end
578580
@generated function besselkx_levin(v, x::T, ::Val{N}) where {T <: FloatTypes, N}
579581
:(
580582
begin
581-
s_0 = zero(T)
583+
s = zero(T)
582584
t = one(T)
583585
@nexprs $N i -> begin
584-
s_{i} = s_{i-1} + t
585-
t *= (4*v^2 - (2i - 1)^2) / (8 * x * i)
586-
w_{i} = 1 / t
587-
end
588-
sequence = @ntuple $N i -> s_{i}
589-
weights = @ntuple $N i -> w_{i}
586+
s += t
587+
t *= (4*v^2 - (2i - 1)^2) / (8 * x * i)
588+
s_{i} = s
589+
w_{i} = t
590+
end
591+
sequence = @ntuple $N i -> s_{i}
592+
weights = @ntuple $N i -> w_{i}
590593
return levin_transform(sequence, weights) * sqrt(π / 2x)
591594
end
592595
)
@@ -614,3 +617,66 @@ end
614617
end
615618
)
616619
end
620+
621+
# This is a version of Temme's proposed f_0 (1975 JCP, see reference above) that
622+
# swaps in a bunch of local expansions for functions that are well-behaved but
623+
# whose standard forms can't be naively evaluated by a computer at the origin.
624+
@inline function f0_local_expansion_v0(v, x)
625+
l2dx = log(2/x)
626+
mu = v*l2dx
627+
vv = v*v
628+
sp = evalpoly(vv, (1.0, 1.6449340668482264, 1.8940656589944918, 1.9711021825948702))
629+
g1 = evalpoly(vv, (-0.5772156649015329, 0.04200263503409518, 0.042197734555544306))
630+
g2 = evalpoly(vv, (1.0, -0.6558780715202539, 0.16653861138229145))
631+
sh = evalpoly(mu*mu, (1.0, 0.16666666666666666, 0.008333333333333333, 0.0001984126984126984, 2.7557319223985893e-6))
632+
sp*(g1*cosh(mu) + g2*sh*l2dx)
633+
end
634+
635+
# This function assumes |v|<1e-5!
636+
function besselk_power_series_temme_basal(v::V, x::X) where{V,X}
637+
max_iter = 50
638+
T = promote_type(V,X)
639+
z = x/2
640+
zz = z*z
641+
fk = f0_local_expansion_v0(v,x)
642+
zv = z^v
643+
znv = inv(zv)
644+
gam_1_c = (1.0, -0.5772156649015329, 0.9890559953279725, -0.23263776388631713)
645+
gam_1pv = evalpoly(v, gam_1_c)
646+
gam_1nv = evalpoly(-v, gam_1_c)
647+
(pk, qk, _ck, factk, vv) = (znv*gam_1pv/2, zv*gam_1nv/2, one(T), one(T), v*v)
648+
(out_v, out_vp1) = (zero(T), zero(T))
649+
for k in 1:max_iter
650+
# add to the series:
651+
ck = _ck/factk
652+
term_v = ck*fk
653+
term_vp1 = ck*(pk - (k-1)*fk)
654+
out_v += term_v
655+
out_vp1 += term_vp1
656+
# check for convergence:
657+
((abs(term_v) < eps(T)) && (abs(term_vp1) < eps(T))) && break
658+
# otherwise, increment new quantities:
659+
fk = (k*fk + pk + qk)/(k^2 - vv)
660+
pk /= (k-v)
661+
qk /= (k+v)
662+
_ck *= zz
663+
factk *= k
664+
end
665+
(out_v, out_vp1/z)
666+
end
667+
668+
function besselk_power_series_int(v, x::Float64)
669+
v = abs(v)
670+
(_v, flv) = modf(v)
671+
if _v > 1/2
672+
(_v, flv) = (_v-one(_v), flv+1)
673+
end
674+
(kv, kvp1) = besselk_power_series_temme_basal(_v, x)
675+
twodx = 2/x
676+
for _ in 1:flv
677+
_v += 1
678+
(kv, kvp1) = (kvp1, muladd(twodx*_v, kvp1, kv))
679+
end
680+
kv
681+
end
682+

src/GammaFunctions/gamma.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,5 @@ function gamma(n::Integer)
113113
n > 20 && return gamma(float(n))
114114
@inbounds return Float64(factorial(n-1))
115115
end
116+
117+
gamma_near_1(x) = evalpoly(x-one(x), (1.0, -0.5772156649015329, 0.9890559953279725, -0.23263776388631713))

src/Math/Math.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,12 @@ end
131131
#@inline levin_scale(B::T, n, k) where T = -(B + n) * (B + n + k)^(k - one(T)) / (B + n + k + one(T))^k
132132
@inline levin_scale(B::T, n, k) where T = -(B + n + k) * (B + n + k - 1) / ((B + n + 2k) * (B + n + 2k - 1))
133133

134-
@inline @generated function levin_transform(s::NTuple{N, T}, w::NTuple{N, T}) where {N, T <: FloatTypes}
134+
@inline @generated function levin_transform(s::NTuple{N, T},
135+
w::NTuple{N, T}) where {N, T <: FloatTypes}
135136
len = N - 1
136137
:(
137138
begin
138-
@nexprs $N i -> a_{i} = Vec{2, T}((s[i] * w[i], w[i]))
139+
@nexprs $N i -> a_{i} = iszero(w[i]) ? (return s[i]) : Vec{2, T}((s[i] / w[i], 1 / w[i]))
139140
@nexprs $len k -> (@nexprs ($len-k) i -> a_{i} = fmadd(a_{i}, levin_scale(one(T), i, k-1), a_{i+1}))
140141
return (a_1[1] / a_1[2])
141142
end
@@ -153,4 +154,7 @@ end
153154
)
154155
end
155156

157+
# TODO (cg 2023/05/16 18:09): dispute this cutoff.
158+
isnearint(x) = abs(x-round(x)) < 1e-5
159+
156160
end

0 commit comments

Comments
 (0)