Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "SpecialFunctions"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "1.6.2"
version = "1.7.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
33 changes: 30 additions & 3 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ https://github.com/JuliaMath/SpecialFunctions.jl/issues/321
"""

ChainRulesCore.@scalar_rule(airyai(x), airyaiprime(x))
ChainRulesCore.@scalar_rule(airyaix(x), airyaiprimex(x) + sqrt(x) * Ω)
ChainRulesCore.@scalar_rule(airyaiprime(x), x * airyai(x))
ChainRulesCore.@scalar_rule(airyaiprimex(x), x * airyaix(x) + sqrt(x) * Ω)
ChainRulesCore.@scalar_rule(airybi(x), airybiprime(x))
ChainRulesCore.@scalar_rule(airybiprime(x), x * airybi(x))
ChainRulesCore.@scalar_rule(besselj0(x), -besselj1(x))
Expand All @@ -31,12 +33,15 @@ ChainRulesCore.@scalar_rule(
)
ChainRulesCore.@scalar_rule(dawson(x), 1 - (2 * x * Ω))
ChainRulesCore.@scalar_rule(digamma(x), trigamma(x))
ChainRulesCore.@scalar_rule(erf(x), (2 / sqrt(π)) * exp(-x * x))
ChainRulesCore.@scalar_rule(erfc(x), -(2 / sqrt(π)) * exp(-x * x))
ChainRulesCore.@scalar_rule(erf(x), (2 / sqrt(π)) * exp(-x^2))
ChainRulesCore.@scalar_rule(erfc(x), -(2 / sqrt(π)) * exp(-x^2))
ChainRulesCore.@scalar_rule(logerfc(x), -(2 / sqrt(π)) * exp(-x^2 - Ω))
ChainRulesCore.@scalar_rule(erfcinv(x), -(sqrt(π) / 2) * exp(Ω^2))
ChainRulesCore.@scalar_rule(erfcx(x), (2 * x * Ω) - (2 / sqrt(π)))
ChainRulesCore.@scalar_rule(erfi(x), (2 / sqrt(π)) * exp(x * x))
ChainRulesCore.@scalar_rule(logerfcx(x), 2 * x - (2 / sqrt(π)) * exp(-Ω))
ChainRulesCore.@scalar_rule(erfi(x), (2 / sqrt(π)) * exp(x^2))
ChainRulesCore.@scalar_rule(erfinv(x), (sqrt(π) / 2) * exp(Ω^2))

ChainRulesCore.@scalar_rule(gamma(x), Ω * digamma(x))
ChainRulesCore.@scalar_rule(
gamma(a, x),
Expand Down Expand Up @@ -66,6 +71,7 @@ ChainRulesCore.@scalar_rule(
ChainRulesCore.@scalar_rule(trigamma(x), polygamma(2, x))

# binary
ChainRulesCore.@scalar_rule(erf(x, y), (-(2 / sqrt(π)) * exp(-x^2), (2 / sqrt(π)) * exp(-y^2)))
ChainRulesCore.@scalar_rule(
besselj(ν, x),
(
Expand Down Expand Up @@ -94,20 +100,41 @@ ChainRulesCore.@scalar_rule(
-(besselk(ν - 1, x) + besselk(ν + 1, x)) / 2,
),
)
ChainRulesCore.@scalar_rule(
besselkx(ν, x),
(
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
-(besselkx(ν - 1, x) + besselkx(ν + 1, x)) / 2 + Ω,
),
)
ChainRulesCore.@scalar_rule(
hankelh1(ν, x),
(
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
(hankelh1(ν - 1, x) - hankelh1(ν + 1, x)) / 2,
),
)
ChainRulesCore.@scalar_rule(
hankelh1x(ν, x),
(
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
(hankelh1x(ν - 1, x) - hankelh1x(ν + 1, x)) / 2 - im * Ω,
),
)
ChainRulesCore.@scalar_rule(
hankelh2(ν, x),
(
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
(hankelh2(ν - 1, x) - hankelh2(ν + 1, x)) / 2,
),
)
ChainRulesCore.@scalar_rule(
hankelh2x(ν, x),
(
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
(hankelh2x(ν - 1, x) - hankelh2x(ν + 1, x)) / 2 + im * Ω,
),
)
ChainRulesCore.@scalar_rule(
polygamma(m, x),
(
Expand Down
54 changes: 32 additions & 22 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@
for x in (1.0, -1.0, 0.0, 0.5, 10.0, -17.1, 1.5 + 0.7im)
test_scalar(erf, x)
test_scalar(erfc, x)
test_scalar(erfcx, x)
test_scalar(erfi, x)

test_scalar(airyai, x)
test_scalar(airyaiprime, x)
test_scalar(airybi, x)
test_scalar(airybiprime, x)

test_scalar(erfcx, x)
test_scalar(dawson, x)

if x isa Real
test_scalar(logerfc, x)
test_scalar(logerfcx, x)

test_scalar(invdigamma, x)
end

Expand All @@ -28,6 +31,11 @@
test_scalar(gamma, x)
test_scalar(digamma, x)
test_scalar(trigamma, x)

if x isa Real
test_scalar(airyaix, x)
test_scalar(airyaiprimex, x)
end
end
end
end
Expand All @@ -51,31 +59,38 @@

test_frule(besselk, nu, x)
test_rrule(besselk, nu, x)
test_frule(besselkx, nu, x)
test_rrule(besselkx, nu, x)

test_frule(bessely, nu, x)
test_rrule(bessely, nu, x)

# use complex numbers in `rrule` for FiniteDifferences
test_frule(hankelh1, nu, x)
test_rrule(hankelh1, nu, complex(x))
test_rrule(hankelh1, nu, x)
test_frule(hankelh1x, nu, x)
test_rrule(hankelh1x, nu, x)

# use complex numbers in `rrule` for FiniteDifferences
test_frule(hankelh2, nu, x)
test_rrule(hankelh2, nu, complex(x))
test_rrule(hankelh2, nu, x)
test_frule(hankelh2x, nu, x)
test_rrule(hankelh2x, nu, x)
end
end
end

@testset "beta and logbeta" begin
@testset "erf, beta, and logbeta" begin
test_points = (1.5, 2.5, 10.5, 1.6 + 1.6im, 1.6 - 1.6im, 4.6 + 1.6im)
for _x in test_points, _y in test_points
# ensure all complex if any complex for FiniteDifferences
x, y = promote(_x, _y)
for x in test_points, y in test_points
test_frule(beta, x, y)
test_rrule(beta, x, y)

test_frule(logbeta, x, y)
test_rrule(logbeta, x, y)

if x isa Real && y isa Real
test_frule(erf, x, y)
test_rrule(erf, x, y)
end
end
end

Expand All @@ -91,13 +106,11 @@
isreal(x) && x < 0 && continue
test_scalar(loggamma, x)
for a in test_points
# ensure all complex if any complex for FiniteDifferences
_a, _x = promote(a, x)
test_frule(gamma, _a, _x; rtol=1e-8)
test_rrule(gamma, _a, _x; rtol=1e-8)
test_frule(gamma, a, x; rtol=1e-8)
test_rrule(gamma, a, x; rtol=1e-8)

test_frule(loggamma, _a, _x)
test_rrule(loggamma, _a, _x)
test_frule(loggamma, a, x)
test_rrule(loggamma, a, x)
end

isreal(x) || continue
Expand All @@ -117,14 +130,11 @@
test_scalar(expintx, x)

for nu in (-1.5, 2.2, 4.0)
# ensure all complex if any complex for FiniteDifferences
_x, _nu = promote(x, nu)

test_frule(expint, _nu, _x)
test_rrule(expint, _nu, _x)
test_frule(expint, nu, x)
test_rrule(expint, nu, x)

test_frule(expintx, _nu, _x)
test_rrule(expintx, _nu, _x)
test_frule(expintx, nu, x)
test_rrule(expintx, nu, x)
end

isreal(x) || continue
Expand Down