Skip to content

Commit a48ba24

Browse files
authored
Merge pull request #308 from devmotion/dw/notimplemented
Use `ChainRulesCore.@not_implemented` and extend tests
2 parents feccbf4 + c1f7012 commit a48ba24

File tree

3 files changed

+71
-32
lines changed

3 files changed

+71
-32
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
88
OpenSpecFun_jll = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
99

1010
[compat]
11-
ChainRulesCore = "0.9"
12-
ChainRulesTestUtils = "0.6.3"
11+
ChainRulesCore = "0.9.40"
12+
ChainRulesTestUtils = "0.6.8"
1313
LogExpFunctions = "0.2"
1414
OpenSpecFun_jll = "0.5"
1515
julia = "1.3"

src/chainrules.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
const BESSEL_ORDER_INFO = """
2+
derivatives of Bessel functions with respect to the order are not implemented currently:
3+
https://github.com/JuliaMath/SpecialFunctions.jl/issues/160
4+
"""
5+
16
ChainRulesCore.@scalar_rule(airyai(x), airyaiprime(x))
27
ChainRulesCore.@scalar_rule(airyaiprime(x), x * airyai(x))
38
ChainRulesCore.@scalar_rule(airybi(x), airybiprime(x))
@@ -31,49 +36,49 @@ ChainRulesCore.@scalar_rule(trigamma(x), polygamma(2, x))
3136
ChainRulesCore.@scalar_rule(
3237
besselj(ν, x),
3338
(
34-
ChainRulesCore.@thunk(error("not implemented")),
39+
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
3540
(besselj- 1, x) - besselj+ 1, x)) / 2
3641
),
3742
)
3843
ChainRulesCore.@scalar_rule(
3944
besseli(ν, x),
4045
(
41-
ChainRulesCore.@thunk(error("not implemented")),
46+
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
4247
(besseli- 1, x) + besseli+ 1, x)) / 2,
4348
),
4449
)
4550
ChainRulesCore.@scalar_rule(
4651
bessely(ν, x),
4752
(
48-
ChainRulesCore.@thunk(error("not implemented")),
53+
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
4954
(bessely- 1, x) - bessely+ 1, x)) / 2,
5055
),
5156
)
5257
ChainRulesCore.@scalar_rule(
5358
besselk(ν, x),
5459
(
55-
ChainRulesCore.@thunk(error("not implemented")),
60+
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
5661
-(besselk- 1, x) + besselk+ 1, x)) / 2,
5762
),
5863
)
5964
ChainRulesCore.@scalar_rule(
6065
hankelh1(ν, x),
6166
(
62-
ChainRulesCore.@thunk(error("not implemented")),
67+
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
6368
(hankelh1- 1, x) - hankelh1+ 1, x)) / 2,
6469
),
6570
)
6671
ChainRulesCore.@scalar_rule(
6772
hankelh2(ν, x),
6873
(
69-
ChainRulesCore.@thunk(error("not implemented")),
74+
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
7075
(hankelh2- 1, x) - hankelh2+ 1, x)) / 2,
7176
),
7277
)
7378
ChainRulesCore.@scalar_rule(
7479
polygamma(m, x),
7580
(
76-
ChainRulesCore.@thunk(error("not implemented")),
81+
ChainRulesCore.DoesNotExist(),
7782
polygamma(m + 1, x),
7883
),
7984
)

test/chainrules.jl

Lines changed: 57 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
@testset "chainrules" begin
22
Random.seed!(1)
33

4-
@testset "general" begin
4+
@testset "general: single input" begin
55
for x in (1.0, -1.0, 0.0, 0.5, 10.0, -17.1, 1.5 + 0.7im)
66
test_scalar(erf, x)
77
test_scalar(erfc, x)
@@ -12,9 +12,6 @@
1212
test_scalar(airybi, x)
1313
test_scalar(airybiprime, x)
1414

15-
test_scalar(besselj0, x)
16-
test_scalar(besselj1, x)
17-
1815
test_scalar(erfcx, x)
1916
test_scalar(dawson, x)
2017

@@ -28,37 +25,74 @@
2825
end
2926

3027
if x isa Real && x > 0 || x isa Complex
31-
test_scalar(bessely0, x)
32-
test_scalar(bessely1, x)
3328
test_scalar(gamma, x)
3429
test_scalar(digamma, x)
3530
test_scalar(trigamma, x)
3631
end
3732
end
33+
end
34+
35+
@testset "Bessel functions" begin
36+
for x in (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6 + 1.6im, 1.6 - 1.6im, -4.6 + 1.6im)
37+
test_scalar(besselj0, x)
38+
test_scalar(besselj1, x)
39+
40+
isreal(x) && x < 0 && continue
41+
42+
test_scalar(bessely0, x)
43+
test_scalar(bessely1, x)
44+
45+
for nu in (-1.5, 2.2, 4.0)
46+
test_frule(besseli, nu, x)
47+
test_rrule(besseli, nu, x)
3848

39-
@testset "beta and logbeta" begin
40-
test_points = (1.5, 2.5, 10.5, 1.6 + 1.6im, 1.6 - 1.6im, 4.6 + 1.6im)
41-
for _x in test_points, _y in test_points
42-
# ensure all complex if any complex for FiniteDifferences
43-
x, y = promote(_x, _y)
44-
test_frule(beta, x, y)
45-
test_rrule(beta, x, y)
49+
test_frule(besselj, nu, x)
50+
test_rrule(besselj, nu, x)
4651

47-
test_frule(logbeta, x, y)
48-
test_rrule(logbeta, x, y)
52+
test_frule(besselk, nu, x)
53+
test_rrule(besselk, nu, x)
54+
55+
test_frule(bessely, nu, x)
56+
test_rrule(bessely, nu, x)
57+
58+
# use complex numbers in `rrule` for FiniteDifferences
59+
test_frule(hankelh1, nu, x)
60+
test_rrule(hankelh1, nu, complex(x))
61+
62+
# use complex numbers in `rrule` for FiniteDifferences
63+
test_frule(hankelh2, nu, x)
64+
test_rrule(hankelh2, nu, complex(x))
4965
end
5066
end
67+
end
5168

52-
@testset "log gamma and co" begin
53-
# It is important that we have negative numbers with both odd and even integer parts
54-
for x in (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6 + 1.6im, 1.6 - 1.6im, -4.6 + 1.6im)
55-
isreal(x) && x < 0 && continue
56-
test_scalar(loggamma, x)
69+
@testset "beta and logbeta" begin
70+
test_points = (1.5, 2.5, 10.5, 1.6 + 1.6im, 1.6 - 1.6im, 4.6 + 1.6im)
71+
for _x in test_points, _y in test_points
72+
# ensure all complex if any complex for FiniteDifferences
73+
x, y = promote(_x, _y)
74+
test_frule(beta, x, y)
75+
test_rrule(beta, x, y)
5776

58-
isreal(x) || continue
59-
test_frule(logabsgamma, x)
60-
test_rrule(logabsgamma, x; output_tangent=(randn(), randn()))
77+
test_frule(logbeta, x, y)
78+
test_rrule(logbeta, x, y)
79+
end
80+
end
81+
82+
@testset "log gamma and co" begin
83+
# It is important that we have negative numbers with both odd and even integer parts
84+
for x in (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6 + 1.6im, 1.6 - 1.6im, -4.6 + 1.6im)
85+
for m in (0, 1, 2, 3)
86+
test_frule(polygamma, m, x)
87+
test_rrule(polygamma, m, x)
6188
end
89+
90+
isreal(x) && x < 0 && continue
91+
test_scalar(loggamma, x)
92+
93+
isreal(x) || continue
94+
test_frule(logabsgamma, x)
95+
test_rrule(logabsgamma, x; output_tangent=(randn(), randn()))
6296
end
6397
end
6498
end

0 commit comments

Comments
 (0)