Skip to content

Commit 9f230e6

Browse files
authored
Add differentiation rules from ChainRules (#238)
* Add differentiation rules from ChainRules * Allow test failures on Julia nightly * Allow failures (correctly?) * Try to avoid spurious test failures by setting seed * Throw error instead of returning NaN * Fix test errors
1 parent e8a1e5c commit 9f230e6

File tree

6 files changed

+184
-2
lines changed

6 files changed

+184
-2
lines changed

.travis.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ julia:
55
- 1.3
66
- 1
77
- nightly
8+
matrix:
9+
allow_failures:
10+
- julia: nightly
811
notifications:
912
email: false
1013

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,18 @@ uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
33
version = "1.1"
44

55
[deps]
6+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
67
OpenSpecFun_jll = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
78

89
[compat]
10+
ChainRulesCore = "0.9"
911
OpenSpecFun_jll = "0.5.3"
1012
julia = "1.3"
1113

1214
[extras]
15+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
16+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1317
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1418

1519
[targets]
16-
test = ["Test"]
20+
test = ["ChainRulesTestUtils", "Random", "Test"]

src/SpecialFunctions.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module SpecialFunctions
22

3+
import ChainRulesCore
4+
35
using OpenSpecFun_jll
46

57
export
@@ -71,6 +73,7 @@ include("gamma.jl")
7173
include("gamma_inc.jl")
7274
include("betanc.jl")
7375
include("beta_inc.jl")
76+
include("chainrules.jl")
7477
include("deprecated.jl")
7578

7679
for f in (:digamma, :erf, :erfc, :erfcinv, :erfcx, :erfi, :erfinv, :logerfc, :logerfcx,

src/chainrules.jl

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
ChainRulesCore.@scalar_rule(airyai(x), airyaiprime(x))
2+
ChainRulesCore.@scalar_rule(airyaiprime(x), x * airyai(x))
3+
ChainRulesCore.@scalar_rule(airybi(x), airybiprime(x))
4+
ChainRulesCore.@scalar_rule(airybiprime(x), x * airybi(x))
5+
ChainRulesCore.@scalar_rule(besselj0(x), -besselj1(x))
6+
ChainRulesCore.@scalar_rule(
7+
besselj1(x),
8+
(besselj0(x) - besselj(2, x)) / 2,
9+
)
10+
ChainRulesCore.@scalar_rule(bessely0(x), -bessely1(x))
11+
ChainRulesCore.@scalar_rule(
12+
bessely1(x),
13+
(bessely0(x) - bessely(2, x)) / 2,
14+
)
15+
ChainRulesCore.@scalar_rule(dawson(x), 1 - (2 * x * Ω))
16+
ChainRulesCore.@scalar_rule(digamma(x), trigamma(x))
17+
ChainRulesCore.@scalar_rule(erf(x), (2 / sqrt(π)) * exp(-x * x))
18+
ChainRulesCore.@scalar_rule(erfc(x), -(2 / sqrt(π)) * exp(-x * x))
19+
ChainRulesCore.@scalar_rule(erfcinv(x), -(sqrt(π) / 2) * exp^2))
20+
ChainRulesCore.@scalar_rule(erfcx(x), (2 * x * Ω) - (2 / sqrt(π)))
21+
ChainRulesCore.@scalar_rule(erfi(x), (2 / sqrt(π)) * exp(x * x))
22+
ChainRulesCore.@scalar_rule(erfinv(x), (sqrt(π) / 2) * exp^2))
23+
ChainRulesCore.@scalar_rule(gamma(x), Ω * digamma(x))
24+
ChainRulesCore.@scalar_rule(
25+
invdigamma(x),
26+
inv(trigamma(invdigamma(x))),
27+
)
28+
ChainRulesCore.@scalar_rule(trigamma(x), polygamma(2, x))
29+
30+
# binary
31+
ChainRulesCore.@scalar_rule(
32+
besselj(ν, x),
33+
(
34+
ChainRulesCore.@thunk(error("not implemented")),
35+
(besselj- 1, x) - besselj+ 1, x)) / 2
36+
),
37+
)
38+
ChainRulesCore.@scalar_rule(
39+
besseli(ν, x),
40+
(
41+
ChainRulesCore.@thunk(error("not implemented")),
42+
(besseli- 1, x) + besseli+ 1, x)) / 2,
43+
),
44+
)
45+
ChainRulesCore.@scalar_rule(
46+
bessely(ν, x),
47+
(
48+
ChainRulesCore.@thunk(error("not implemented")),
49+
(bessely- 1, x) - bessely+ 1, x)) / 2,
50+
),
51+
)
52+
ChainRulesCore.@scalar_rule(
53+
besselk(ν, x),
54+
(
55+
ChainRulesCore.@thunk(error("not implemented")),
56+
-(besselk- 1, x) + besselk+ 1, x)) / 2,
57+
),
58+
)
59+
ChainRulesCore.@scalar_rule(
60+
hankelh1(ν, x),
61+
(
62+
ChainRulesCore.@thunk(error("not implemented")),
63+
(hankelh1- 1, x) - hankelh1+ 1, x)) / 2,
64+
),
65+
)
66+
ChainRulesCore.@scalar_rule(
67+
hankelh2(ν, x),
68+
(
69+
ChainRulesCore.@thunk(error("not implemented")),
70+
(hankelh2- 1, x) - hankelh2+ 1, x)) / 2,
71+
),
72+
)
73+
ChainRulesCore.@scalar_rule(
74+
polygamma(m, x),
75+
(
76+
ChainRulesCore.@thunk(error("not implemented")),
77+
polygamma(m + 1, x),
78+
),
79+
)
80+
# todo: setup for common expr
81+
ChainRulesCore.@scalar_rule(
82+
beta(a, b),
83+
*(digamma(a) - digamma(a + b)),
84+
Ω*(digamma(b) - digamma(a + b)),)
85+
)
86+
ChainRulesCore.@scalar_rule(
87+
logbeta(a, b),
88+
(digamma(a) - digamma(a + b),
89+
digamma(b) - digamma(a + b),)
90+
)
91+
92+
# actually is the absolute value of the logorithm of gamma paired with sign gamma
93+
ChainRulesCore.@scalar_rule(logabsgamma(x), digamma(x), ChainRulesCore.Zero())
94+
95+
ChainRulesCore.@scalar_rule(loggamma(x), digamma(x))

test/chainrules.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
@testset "chainrules" begin
2+
Random.seed!(1)
3+
4+
@testset "general" begin
5+
for x in (1.0, -1.0, 0.0, 0.5, 10.0, -17.1, 1.5 + 0.7im)
6+
test_scalar(erf, x)
7+
test_scalar(erfc, x)
8+
test_scalar(erfi, x)
9+
10+
test_scalar(airyai, x)
11+
test_scalar(airyaiprime, x)
12+
test_scalar(airybi, x)
13+
test_scalar(airybiprime, x)
14+
15+
test_scalar(besselj0, x)
16+
test_scalar(besselj1, x)
17+
18+
test_scalar(erfcx, x)
19+
test_scalar(dawson, x)
20+
21+
if x isa Real
22+
test_scalar(invdigamma, x)
23+
end
24+
25+
if x isa Real && 0 < x < 1
26+
test_scalar(erfinv, x)
27+
test_scalar(erfcinv, x)
28+
end
29+
30+
if x isa Real && x > 0 || x isa Complex
31+
test_scalar(bessely0, x)
32+
test_scalar(bessely1, x)
33+
test_scalar(gamma, x)
34+
test_scalar(digamma, x)
35+
test_scalar(trigamma, x)
36+
end
37+
end
38+
39+
@testset "beta and logbeta" begin
40+
test_points = (1.5, 2.5, 10.5, 1.6 + 1.6im, 1.6 - 1.6im, 4.6 + 1.6im)
41+
for _x in test_points, _y in test_points
42+
# ensure all complex if any complex for FiniteDifferences
43+
x, y = promote(_x, _y)
44+
T = typeof(x)
45+
46+
Δx, x̄ = randn(T, 2)
47+
Δy, ȳ = randn(T, 2)
48+
Δz = randn(T)
49+
50+
frule_test(beta, (x, Δx), (y, Δy))
51+
rrule_test(beta, Δz, (x, x̄), (y, ȳ))
52+
53+
frule_test(logbeta, (x, Δx), (y, Δy))
54+
rrule_test(logbeta, Δz, (x, x̄), (y, ȳ))
55+
end
56+
end
57+
58+
@testset "log gamma and co" begin
59+
# It is important that we have negative numbers with both odd and even integer parts
60+
for x in (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6 + 1.6im, 1.6 - 1.6im, -4.6 + 1.6im)
61+
isreal(x) && x < 0 && continue
62+
test_scalar(loggamma, x)
63+
64+
isreal(x) || continue
65+
66+
Δx, x̄ = randn(2)
67+
Δz = (randn(), randn())
68+
69+
frule_test(logabsgamma, (x, Δx))
70+
rrule_test(logabsgamma, Δz, (x, x̄))
71+
end
72+
end
73+
end
74+
end

test/runtests.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# This file contains code that was formerly a part of Julia. License is MIT: http://julialang.org/license
22

33
using SpecialFunctions
4+
using ChainRulesTestUtils
5+
using Random
46
using Test
57
using Base.MathConstants: γ
68

@@ -28,7 +30,8 @@ tests = [
2830
"gamma_inc",
2931
"gamma",
3032
"sincosint",
31-
"other_tests"
33+
"other_tests",
34+
"chainrules"
3235
]
3336

3437
const testdir = dirname(@__FILE__)

0 commit comments

Comments
 (0)