Skip to content

Commit e59e092

Browse files
authored
Merge pull request #305 from devmotion/dw/chainrules_gamma
Add ChainRules definitions for `gamma(a, x)`, `loggamma(a, x)`, and `gamma_inc`
2 parents a48ba24 + 98bc7d0 commit e59e092

File tree

2 files changed

+50
-1
lines changed

2 files changed

+50
-1
lines changed

src/chainrules.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@ derivatives of Bessel functions with respect to the order are not implemented cu
33
https://github.com/JuliaMath/SpecialFunctions.jl/issues/160
44
"""
55

6+
const INCOMPLETE_GAMMA_INFO = """
7+
derivatives of the incomplete Gamma functions with respect to parameter `a` are not
8+
implemented currently:
9+
https://github.com/JuliaMath/SpecialFunctions.jl/issues/317
10+
"""
11+
612
ChainRulesCore.@scalar_rule(airyai(x), airyaiprime(x))
713
ChainRulesCore.@scalar_rule(airyaiprime(x), x * airyai(x))
814
ChainRulesCore.@scalar_rule(airybi(x), airybiprime(x))
@@ -26,6 +32,27 @@ ChainRulesCore.@scalar_rule(erfcx(x), (2 * x * Ω) - (2 / sqrt(π)))
2632
ChainRulesCore.@scalar_rule(erfi(x), (2 / sqrt(π)) * exp(x * x))
2733
ChainRulesCore.@scalar_rule(erfinv(x), (sqrt(π) / 2) * exp^2))
2834
ChainRulesCore.@scalar_rule(gamma(x), Ω * digamma(x))
35+
ChainRulesCore.@scalar_rule(
36+
gamma(a, x),
37+
(
38+
ChainRulesCore.@not_implemented(INCOMPLETE_GAMMA_INFO),
39+
- exp(-x) * x^(a - 1),
40+
),
41+
)
42+
ChainRulesCore.@scalar_rule(
43+
gamma_inc(a, x, IND),
44+
@setup(z = exp(-x) * x^(a - 1) / gamma(a)),
45+
(
46+
ChainRulesCore.@not_implemented(INCOMPLETE_GAMMA_INFO),
47+
z,
48+
ChainRulesCore.DoesNotExist(),
49+
),
50+
(
51+
ChainRulesCore.@not_implemented(INCOMPLETE_GAMMA_INFO),
52+
-z,
53+
ChainRulesCore.DoesNotExist(),
54+
),
55+
)
2956
ChainRulesCore.@scalar_rule(
3057
invdigamma(x),
3158
inv(trigamma(invdigamma(x))),
@@ -98,3 +125,10 @@ ChainRulesCore.@scalar_rule(
98125
ChainRulesCore.@scalar_rule(logabsgamma(x), digamma(x), ChainRulesCore.Zero())
99126

100127
ChainRulesCore.@scalar_rule(loggamma(x), digamma(x))
128+
ChainRulesCore.@scalar_rule(
129+
loggamma(a, x),
130+
(
131+
ChainRulesCore.@not_implemented(INCOMPLETE_GAMMA_INFO),
132+
-exp(- (x + Ω)) * x^(a - 1),
133+
)
134+
)

test/chainrules.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,33 @@
8181

8282
@testset "log gamma and co" begin
8383
# It is important that we have negative numbers with both odd and even integer parts
84-
for x in (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6 + 1.6im, 1.6 - 1.6im, -4.6 + 1.6im)
84+
test_points = (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6 + 1.6im, 1.6 - 1.6im, -4.6 + 1.6im)
85+
for x in test_points
8586
for m in (0, 1, 2, 3)
8687
test_frule(polygamma, m, x)
8788
test_rrule(polygamma, m, x)
8889
end
8990

9091
isreal(x) && x < 0 && continue
9192
test_scalar(loggamma, x)
93+
for a in test_points
94+
# ensure all complex if any complex for FiniteDifferences
95+
_a, _x = promote(a, x)
96+
test_frule(gamma, _a, _x; rtol=1e-8)
97+
test_rrule(gamma, _a, _x; rtol=1e-8)
98+
99+
test_frule(loggamma, _a, _x)
100+
test_rrule(loggamma, _a, _x)
101+
end
92102

93103
isreal(x) || continue
94104
test_frule(logabsgamma, x)
95105
test_rrule(logabsgamma, x; output_tangent=(randn(), randn()))
106+
for a in test_points
107+
isreal(a) && a > 0 || continue
108+
test_frule(gamma_inc, a, x, 0)
109+
test_rrule(gamma_inc, a, x, 0; output_tangent=(randn(), randn()))
110+
end
96111
end
97112
end
98113
end

0 commit comments

Comments
 (0)