diff --git a/src/chainrules.jl b/src/chainrules.jl index 303ad87c..fa7b5dd3 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -16,7 +16,9 @@ https://github.com/JuliaMath/SpecialFunctions.jl/issues/321 """ ChainRulesCore.@scalar_rule(airyai(x), airyaiprime(x)) +ChainRulesCore.@scalar_rule(airyaix(x), airyaiprimex(x) + sqrt(x) * Ω) ChainRulesCore.@scalar_rule(airyaiprime(x), x * airyai(x)) +ChainRulesCore.@scalar_rule(airyaiprimex(x), x * airyaix(x) + sqrt(x) * Ω) ChainRulesCore.@scalar_rule(airybi(x), airybiprime(x)) ChainRulesCore.@scalar_rule(airybiprime(x), x * airybi(x)) ChainRulesCore.@scalar_rule(besselj0(x), -besselj1(x)) @@ -31,12 +33,18 @@ ChainRulesCore.@scalar_rule( ) ChainRulesCore.@scalar_rule(dawson(x), 1 - (2 * x * Ω)) ChainRulesCore.@scalar_rule(digamma(x), trigamma(x)) -ChainRulesCore.@scalar_rule(erf(x), (2 / sqrt(π)) * exp(-x * x)) -ChainRulesCore.@scalar_rule(erfc(x), -(2 / sqrt(π)) * exp(-x * x)) -ChainRulesCore.@scalar_rule(erfcinv(x), -(sqrt(π) / 2) * exp(Ω^2)) -ChainRulesCore.@scalar_rule(erfcx(x), (2 * x * Ω) - (2 / sqrt(π))) -ChainRulesCore.@scalar_rule(erfi(x), (2 / sqrt(π)) * exp(x * x)) -ChainRulesCore.@scalar_rule(erfinv(x), (sqrt(π) / 2) * exp(Ω^2)) + +# TODO: use `invsqrtπ` if it is added to IrrationalConstants +ChainRulesCore.@scalar_rule(erf(x), (2 * exp(-x^2)) / sqrtπ) +ChainRulesCore.@scalar_rule(erf(x, y), (- (2 * exp(-x^2)) / sqrtπ, (2 * exp(-y^2)) / sqrtπ)) +ChainRulesCore.@scalar_rule(erfc(x), - (2 * exp(-x^2)) / sqrtπ) +ChainRulesCore.@scalar_rule(logerfc(x), - (2 * exp(-x^2 - Ω)) / sqrtπ) +ChainRulesCore.@scalar_rule(erfcinv(x), - (sqrtπ * (exp(Ω^2) / 2))) +ChainRulesCore.@scalar_rule(erfcx(x), 2 * (x * Ω - inv(oftype(Ω, sqrtπ)))) +ChainRulesCore.@scalar_rule(logerfcx(x), 2 * (x - exp(-Ω) / sqrtπ)) +ChainRulesCore.@scalar_rule(erfi(x), (2 * exp(x^2)) / sqrtπ) +ChainRulesCore.@scalar_rule(erfinv(x), sqrtπ * (exp(Ω^2) / 2)) + ChainRulesCore.@scalar_rule(gamma(x), Ω * digamma(x)) ChainRulesCore.@scalar_rule( gamma(a, x), @@ -65,7 +73,7 @@ ChainRulesCore.@scalar_rule( ) ChainRulesCore.@scalar_rule(trigamma(x), polygamma(2, x)) -# binary +# Bessel functions ChainRulesCore.@scalar_rule( besselj(ν, x), ( @@ -94,6 +102,13 @@ ChainRulesCore.@scalar_rule( -(besselk(ν - 1, x) + besselk(ν + 1, x)) / 2, ), ) +ChainRulesCore.@scalar_rule( + besselkx(ν, x), + ( + ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO), + -(besselkx(ν - 1, x) + besselkx(ν + 1, x)) / 2 + Ω, + ), +) ChainRulesCore.@scalar_rule( hankelh1(ν, x), ( @@ -101,6 +116,13 @@ ChainRulesCore.@scalar_rule( (hankelh1(ν - 1, x) - hankelh1(ν + 1, x)) / 2, ), ) +ChainRulesCore.@scalar_rule( + hankelh1x(ν, x), + ( + ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO), + (hankelh1x(ν - 1, x) - hankelh1x(ν + 1, x)) / 2 - im * Ω, + ), +) ChainRulesCore.@scalar_rule( hankelh2(ν, x), ( @@ -108,6 +130,14 @@ ChainRulesCore.@scalar_rule( (hankelh2(ν - 1, x) - hankelh2(ν + 1, x)) / 2, ), ) +ChainRulesCore.@scalar_rule( + hankelh2x(ν, x), + ( + ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO), + (hankelh2x(ν - 1, x) - hankelh2x(ν + 1, x)) / 2 + im * Ω, + ), +) + ChainRulesCore.@scalar_rule( polygamma(m, x), ( @@ -161,5 +191,5 @@ ChainRulesCore.@scalar_rule( ) ) ChainRulesCore.@scalar_rule(expinti(x), exp(x) / x) -ChainRulesCore.@scalar_rule(sinint(x), sinc(x / π)) +ChainRulesCore.@scalar_rule(sinint(x), sinc(invπ * x)) ChainRulesCore.@scalar_rule(cosint(x), cos(x) / x) diff --git a/test/chainrules.jl b/test/chainrules.jl index 5c164b28..d4be8285 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -5,6 +5,7 @@ for x in (1.0, -1.0, 0.0, 0.5, 10.0, -17.1, 1.5 + 0.7im) test_scalar(erf, x) test_scalar(erfc, x) + test_scalar(erfcx, x) test_scalar(erfi, x) test_scalar(airyai, x) @@ -12,10 +13,12 @@ test_scalar(airybi, x) test_scalar(airybiprime, x) - test_scalar(erfcx, x) test_scalar(dawson, x) if x isa Real + test_scalar(logerfc, x) + test_scalar(logerfcx, x) + test_scalar(invdigamma, x) end @@ -28,6 +31,11 @@ test_scalar(gamma, x) test_scalar(digamma, x) test_scalar(trigamma, x) + + if x isa Real + test_scalar(airyaix, x) + test_scalar(airyaiprimex, x) + end end end end @@ -51,31 +59,38 @@ test_frule(besselk, nu, x) test_rrule(besselk, nu, x) + test_frule(besselkx, nu, x) + test_rrule(besselkx, nu, x) test_frule(bessely, nu, x) test_rrule(bessely, nu, x) - # use complex numbers in `rrule` for FiniteDifferences test_frule(hankelh1, nu, x) - test_rrule(hankelh1, nu, complex(x)) + test_rrule(hankelh1, nu, x) + test_frule(hankelh1x, nu, x) + test_rrule(hankelh1x, nu, x) - # use complex numbers in `rrule` for FiniteDifferences test_frule(hankelh2, nu, x) - test_rrule(hankelh2, nu, complex(x)) + test_rrule(hankelh2, nu, x) + test_frule(hankelh2x, nu, x) + test_rrule(hankelh2x, nu, x) end end end - @testset "beta and logbeta" begin + @testset "erf, beta, and logbeta" begin test_points = (1.5, 2.5, 10.5, 1.6 + 1.6im, 1.6 - 1.6im, 4.6 + 1.6im) - for _x in test_points, _y in test_points - # ensure all complex if any complex for FiniteDifferences - x, y = promote(_x, _y) + for x in test_points, y in test_points test_frule(beta, x, y) test_rrule(beta, x, y) test_frule(logbeta, x, y) test_rrule(logbeta, x, y) + + if x isa Real && y isa Real + test_frule(erf, x, y) + test_rrule(erf, x, y) + end end end @@ -91,13 +106,11 @@ isreal(x) && x < 0 && continue test_scalar(loggamma, x) for a in test_points - # ensure all complex if any complex for FiniteDifferences - _a, _x = promote(a, x) - test_frule(gamma, _a, _x; rtol=1e-8) - test_rrule(gamma, _a, _x; rtol=1e-8) + test_frule(gamma, a, x; rtol=1e-8) + test_rrule(gamma, a, x; rtol=1e-8) - test_frule(loggamma, _a, _x) - test_rrule(loggamma, _a, _x) + test_frule(loggamma, a, x) + test_rrule(loggamma, a, x) end isreal(x) || continue @@ -117,14 +130,11 @@ test_scalar(expintx, x) for nu in (-1.5, 2.2, 4.0) - # ensure all complex if any complex for FiniteDifferences - _x, _nu = promote(x, nu) + test_frule(expint, nu, x) + test_rrule(expint, nu, x) - test_frule(expint, _nu, _x) - test_rrule(expint, _nu, _x) - - test_frule(expintx, _nu, _x) - test_rrule(expintx, _nu, _x) + test_frule(expintx, nu, x) + test_rrule(expintx, nu, x) end isreal(x) || continue @@ -133,4 +143,23 @@ test_scalar(cosint, x) end end + + # https://github.com/JuliaMath/SpecialFunctions.jl/issues/307 + @testset "promotions" begin + # one argument + for f in (erf, erfc, logerfc, erfcinv, erfcx, logerfcx, erfi, erfinv, sinint) + _, ẏ = frule((NoTangent(), 1f0), f, 1f0) + @test ẏ isa Float32 + _, back = rrule(f, 1f0) + _, x̄ = back(1f0) + @test x̄ isa Float32 + end + + # two arguments + _, ẏ = frule((NoTangent(), 1f0, 1f0), erf, 1f0, 1f0) + @test ẏ isa Float32 + _, back = rrule(erf, 1f0, 1f0) + _, x̄ = back(1f0) + @test x̄ isa Float32 + end end