|
11 | 11 | (c, -c, z), |
12 | 12 | ) |
13 | 13 |
|
| 14 | +# StatsFuns: https://github.com/JuliaStats/StatsFuns.jl/pull/106 |
| 15 | + |
14 | 16 | ## Beta ## |
15 | 17 |
|
16 | 18 | @scalar_rule( |
17 | 19 | betalogpdf(α::Real, β::Real, x::Number), |
18 | | - @setup(di = digamma(α + β)), |
| 20 | + @setup(z = digamma(α + β)), |
19 | 21 | ( |
20 | | - @thunk(log(x) - digamma(α) + di), |
21 | | - @thunk(log(1 - x) - digamma(β) + di), |
22 | | - @thunk((α - 1)/x + (1 - β)/(1 - x)), |
| 22 | + log(x) + z - digamma(α), |
| 23 | + log1p(-x) + z - digamma(β), |
| 24 | + (α - 1) / x + (1 - β) / (1 - x), |
23 | 25 | ), |
24 | 26 | ) |
25 | 27 |
|
26 | 28 | ## Gamma ## |
27 | 29 |
|
28 | 30 | @scalar_rule( |
29 | 31 | gammalogpdf(k::Real, θ::Real, x::Number), |
| 32 | + @setup( |
| 33 | + invθ = inv(θ), |
| 34 | + xoθ = invθ * x, |
| 35 | + z = xoθ - k, |
| 36 | + ), |
30 | 37 | ( |
31 | | - @thunk(-digamma(k) - log(θ) + log(x)), |
32 | | - @thunk(-k/θ + x/θ^2), |
33 | | - @thunk((k - 1)/x - 1/θ), |
| 38 | + log(xoθ) - digamma(k), |
| 39 | + invθ * z, |
| 40 | + - (1 + z) / x, |
34 | 41 | ), |
35 | 42 | ) |
36 | 43 |
|
37 | 44 | ## Chisq ## |
38 | 45 |
|
39 | 46 | @scalar_rule( |
40 | 47 | chisqlogpdf(k::Real, x::Number), |
41 | | - @setup(ko2 = k / 2), |
42 | | - (@thunk((-logtwo - digamma(ko2) + log(x)) / 2), @thunk((ko2 - 1)/x - one(ko2) / 2)), |
| 48 | + @setup(hk = k / 2), |
| 49 | + ( |
| 50 | + (log(x) - logtwo - digamma(hk)) / 2, |
| 51 | + (hk - 1) / x - one(hk) / 2, |
| 52 | + ), |
43 | 53 | ) |
44 | 54 |
|
45 | 55 | ## FDist ## |
46 | 56 |
|
47 | 57 | @scalar_rule( |
48 | | - fdistlogpdf(v1::Real, v2::Real, x::Number), |
| 58 | + fdistlogpdf(ν1::Real, ν2::Real, x::Number), |
49 | 59 | @setup( |
50 | | - temp1 = v1 * x + v2, |
51 | | - temp2 = log(temp1), |
52 | | - vsum = v1 + v2, |
53 | | - temp3 = vsum / temp1, |
54 | | - temp4 = digamma(vsum / 2), |
| 60 | + xν1 = x * ν1, |
| 61 | + temp1 = xν1 + ν2, |
| 62 | + a = (x - 1) / temp1, |
| 63 | + νsum = ν1 + ν2, |
| 64 | + di = digamma(νsum / 2), |
55 | 65 | ), |
56 | 66 | ( |
57 | | - @thunk((log(v1 * x) + 1 - temp2 - x * temp3 - digamma(v1 / 2) + temp4) / 2), |
58 | | - @thunk((log(v2) + 1 - temp2 - temp3 - digamma(v2 / 2) + temp4) / 2), |
59 | | - @thunk(v1 / 2 * (1 / x - temp3) - 1 / x), |
| 67 | + (-log1p(ν2 / xν1) - ν2 * a + di - digamma(ν1 / 2)) / 2, |
| 68 | + (-log1p(xν1 / ν2) + ν1 * a + di - digamma(ν2 / 2)) / 2, |
| 69 | + ((ν1 - 2) / x - ν1 * νsum / temp1) / 2, |
60 | 70 | ), |
61 | 71 | ) |
62 | 72 |
|
63 | 73 | ## TDist ## |
64 | 74 |
|
65 | 75 | @scalar_rule( |
66 | | - tdistlogpdf(v::Real, x::Number), |
| 76 | + tdistlogpdf(ν::Real, x::Number), |
| 77 | + @setup( |
| 78 | + νp1 = ν + 1, |
| 79 | + xsq = x^2, |
| 80 | + invν = inv(ν), |
| 81 | + a = xsq * invν, |
| 82 | + b = νp1 / (ν + xsq), |
| 83 | + ), |
67 | 84 | ( |
68 | | - @thunk((digamma((v + 1) / 2) - 1 / v - digamma(v / 2) - log(1 + x^2 / v) + x^2 * (v + 1) / v^2 / (1 + x^2 / v)) / 2), |
69 | | - @thunk(-x * (v + 1) / (v + x^2)), |
70 | | - ) |
| 85 | + (digamma(νp1 / 2) - digamma(ν / 2) + a * b - log1p(a) - invν) / 2, |
| 86 | + - x * b, |
| 87 | + ), |
71 | 88 | ) |
72 | 89 |
|
73 | 90 | ## Binomial ## |
74 | 91 |
|
75 | 92 | @scalar_rule( |
76 | | - binomlogpdf(n::Int, p::Real, x::Int), |
77 | | - (DoesNotExist(), x / p - (n - x) / (1 - p), DoesNotExist()), |
| 93 | + binomlogpdf(n::Real, p::Real, k::Real), |
| 94 | + @setup(z = digamma(n - k + 1)), |
| 95 | + ( |
| 96 | + digamma(n + 2) - z + log1p(-p) - 1 / (1 + n), |
| 97 | + (k / p - n) / (1 - p), |
| 98 | + z - digamma(k + 1) + logit(p), |
| 99 | + ), |
78 | 100 | ) |
79 | 101 |
|
80 | 102 | ## Poisson ## |
81 | 103 |
|
82 | 104 | @scalar_rule( |
83 | | - poislogpdf(v::Real, x::Int), |
84 | | - (x / v - 1, DoesNotExist()), |
| 105 | + poislogpdf(λ::Number, x::Number), |
| 106 | + ((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ) - 1, log(λ) - digamma(x + 1)), |
| 107 | +) |
| 108 | + |
| 109 | +## PoissonBinomial |
| 110 | + |
| 111 | +function ChainRulesCore.rrule( |
| 112 | + ::typeof(Distributions.poissonbinomial_pdf_fft), p::AbstractVector{<:Real} |
85 | 113 | ) |
| 114 | + y = Distributions.poissonbinomial_pdf_fft(p) |
| 115 | + A = poissonbinomial_partialderivatives(p) |
| 116 | + function poissonbinomial_pdf_fft_pullback(Δy) |
| 117 | + p̄ = InplaceableThunk( |
| 118 | + @thunk(A * Δy), |
| 119 | + Δ -> LinearAlgebra.mul!(Δ, A, Δy, true, true), |
| 120 | + ) |
| 121 | + return (NO_FIELDS, p̄) |
| 122 | + end |
| 123 | + return y, poissonbinomial_pdf_fft_pullback |
| 124 | +end |
| 125 | + |
| 126 | +if isdefined(Distributions, :poissonbinomial_pdf) |
| 127 | + function ChainRulesCore.rrule( |
| 128 | + ::typeof(Distributions.poissonbinomial_pdf), p::AbstractVector{<:Real} |
| 129 | + ) |
| 130 | + y = Distributions.poissonbinomial_pdf(p) |
| 131 | + A = poissonbinomial_partialderivatives(p) |
| 132 | + function poissonbinomial_pdf_pullback(Δy) |
| 133 | + p̄ = InplaceableThunk( |
| 134 | + @thunk(A * Δy), |
| 135 | + Δ -> LinearAlgebra.mul!(Δ, A, Δy, true, true), |
| 136 | + ) |
| 137 | + return (NO_FIELDS, p̄) |
| 138 | + end |
| 139 | + return y, poissonbinomial_pdf_pullback |
| 140 | + end |
| 141 | +end |
0 commit comments