Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
297 changes: 297 additions & 0 deletions ext/SpecialFunctionsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,4 +300,301 @@ function ChainRulesCore.rrule(::typeof(besselyx), ν::Number, x::Number)
return Ω, besselyx_pullback
end



## Incomplete beta derivatives via Boik & Robinson-Cox
#
# Reference
# R. J. Boik and J. F. Robinson-Cox (1999).
# "Derivatives of the incomplete beta function."
# Journal of Statistical Software, 3(1).
# URL: https://www.jstatsoft.org/article/view/v003i01
#
# The following implementation computes the regularized incomplete beta
# I_x(a,b) together with its partial derivatives with respect to a, b, and x
# using a continued-fraction representation of ₂F₁ and differentiating through it.
# This is an independent implementation adapted from https://github.com/arzwa/IncBetaDer.jl.

# Generic-typed version for high-precision evaluation.
#
# Computes the regularized incomplete beta function I_x(a, b) from a continued
# fraction and obtains the parameter derivatives by differentiating the
# three-term recurrence of the approximants A_n / B_n.
#
# Arguments (all floats share the type `T`):
#   a, b    shape parameters
#   x       evaluation point; `x == 0` and `x == 1` return immediately
#   maxapp  maximum number of approximants evaluated
#   minapp  first approximant index at which convergence is tested
#   ϵ       relative tolerance that I, ∂I/∂p and ∂I/∂q must all meet
#
# Returns the tuple `(I, ∂I/∂a, ∂I/∂b, ∂I/∂x)`. If the loop runs out of
# approximants, the last iterate is returned without any warning.
function _beta_inc_grad_boik(a::T, b::T, x::T,
        maxapp::Int=200, minapp::Int=3, ϵ::T=convert(T, 1e-12)) where {T<:AbstractFloat}
    oneT = one(T); zeroT = zero(T)
    if x == oneT
        return oneT, zeroT, zeroT, zeroT
    elseif x == zeroT
        return zeroT, zeroT, zeroT, zeroT
    end
    # ∂I/∂x is the beta density at x, evaluated in log space
    dx = exp((a - oneT) * log(x) + (b - oneT) * log1p(-x) - logbeta(a, b))
    # swap tails if necessary: I_x(a, b) = 1 - I_{1-x}(b, a)
    p = a; q = b; x₀ = x; swap = false
    if x > a / (a + b)
        x₀ = oneT - x
        p = b
        q = a
        swap = true
    end
    # The helpers below receive everything through their arguments and only assign
    # names that are not locals of this function: assigning to an outer local
    # (such as `x`) inside a closure rebinds it and makes Julia box the variable.
    Kfun(x::T, p::T, q::T) = exp(p * log(x) + (q - oneT) * log1p(-x) - log(p) - logbeta(p, q))
    ffun(x::T, p::T, q::T) = q * x / (p * (oneT - x))
    a1fun(p::T, q::T, f::T) = p * f * (q - oneT) / (q * (p + oneT))
    anfun(p::T, q::T, f::T, n::Int) = n == 1 ? a1fun(p, q, f) :
        p^2 * f^2 * (T(n) - oneT) * (p + q + T(n) - T(2)) * (p + T(n) - oneT) * (q - T(n)) /
        (q^2 * (p + T(2n) - T(3)) * (p + T(2n) - T(2))^2 * (p + T(2n) - oneT))
    function bnfun(p::T, q::T, f::T, n::Int)
        num = T(2) * (p * f + T(2) * q) * T(n)^2 +
              T(2) * (p * f + T(2) * q) * (p - oneT) * T(n) +
              p * q * (p - T(2) - p * f)
        den = q * (p + T(2n) - T(2)) * (p + T(2n))
        return num / den
    end
    dK_dp(x::T, p::T, q::T, K::T, ψpq::T, ψp::T) = K * (log(x) - inv(p) + ψpq - ψp)
    dK_dq(x::T, p::T, q::T, K::T, ψpq::T, ψq::T) = K * (log1p(-x) + ψpq - ψq)
    function dK_dpdq(x::T, p::T, q::T)
        ψ = digamma(p + q)
        Kf = Kfun(x, p, q)
        dKdp = dK_dp(x, p, q, Kf, ψ, digamma(p))
        dKdq = dK_dq(x, p, q, Kf, ψ, digamma(q))
        return dKdp, dKdq
    end
    # a_n derivatives with respect to p via the log-derivative; p*f does not depend
    # on p, and every factor that is inverted here is positive for p, q > 0, n ≥ 2
    da1_dp(p::T, q::T, f::T) = -a1fun(p, q, f) / (p + oneT)
    function dan_dp(p::T, q::T, f::T, n::Int)
        if n == 1
            return da1_dp(p, q, f)
        end
        an = anfun(p, q, f, n)
        dlog = inv(p + q + T(n) - T(2)) + inv(p + T(n) - oneT) - inv(p + T(2n) - T(3)) -
               T(2) * inv(p + T(2n) - T(2)) - inv(p + T(2n) - oneT)
        return an * dlog
    end
    # a_n derivatives with respect to q in closed form. p*f/q does not depend on q,
    # so a_n = c * (p + q + n - 2) * (q - n) with c free of q. Differentiating the
    # product directly avoids dividing by (q - n), which is zero (and turned the
    # result into NaN) whenever q is a positive integer and n == q.
    da1_dq(p::T, q::T, f::T) = p * f / (q * (p + oneT))
    function dan_dq(p::T, q::T, f::T, n::Int)
        if n == 1
            return da1_dq(p, q, f)
        end
        c = p^2 * f^2 * (T(n) - oneT) * (p + T(n) - oneT) /
            (q^2 * (p + T(2n) - T(3)) * (p + T(2n) - T(2))^2 * (p + T(2n) - oneT))
        # d/dq [(p + q + n - 2) * (q - n)] = p + 2q - 2
        return c * (p + T(2) * q - T(2))
    end
    # b_n derivatives via quotient rule, accounting for f_p=-f/p, f_q=f/q which cancel in N
    function dbn_dp(p::T, q::T, f::T, n::Int)
        g = p * f + T(2) * q
        A = T(2) * T(n)^2 + T(2) * (p - oneT) * T(n)
        N1 = g * A
        N2 = p * q * (p - T(2) - p * f)
        N = N1 + N2
        D = q * (p + T(2n) - T(2)) * (p + T(2n))
        dN1_dp = T(2) * T(n) * g
        dN2_dp = q * (T(2) * p - T(2)) - p * q * f
        dN_dp = dN1_dp + dN2_dp
        dD_dp = q * (T(2) * p + T(4) * T(n) - T(2))
        return (dN_dp * D - N * dD_dp) / (D^2)
    end
    function dbn_dq(p::T, q::T, f::T, n::Int)
        g = p * f + T(2) * q
        A = T(2) * T(n)^2 + T(2) * (p - oneT) * T(n)
        N1 = g * A
        N2 = p * q * (p - T(2) - p * f)
        N = N1 + N2
        D = q * (p + T(2n) - T(2)) * (p + T(2n))
        g_q = p * (f / q) + T(2)
        dN1_dq = g_q * A
        dN2_dq = p * (p - T(2) - p * f) - p^2 * f
        dN_dq = dN1_dq + dN2_dq
        dD_dq = (p + T(2n) - T(2)) * (p + T(2n))
        return (dN_dq * D - N * dD_dq) / (D^2)
    end
    # one step of the recurrence X_n = a_n * X_{n-2} + b_n * X_{n-1} for A and B
    function _nextapp(f::T, p::T, q::T, n::Int, App::T, Ap::T, Bpp::T, Bp::T)
        an = anfun(p, q, f, n)
        bn = bnfun(p, q, f, n)
        An = an * App + bn * Ap
        Bn = an * Bpp + bn * Bp
        return An, Bn, an, bn
    end
    # product rule applied to the same recurrence
    _dnextapp(an::T, bn::T, dan::T, dbn::T, Xpp::T, Xp::T, dXpp::T, dXp::T) =
        dan * Xpp + an * dXpp + dbn * Xp + bn * dXp

    # compute once
    K = Kfun(x₀, p, q)
    dK_dp_val, dK_dq_val = dK_dpdq(x₀, p, q)
    f = ffun(x₀, p, q)
    App = oneT; Ap = oneT; Bpp = zeroT; Bp = oneT
    dApp_dp = zeroT; dBpp_dp = zeroT; dAp_dp = zeroT; dBp_dp = zeroT
    dApp_dq = zeroT; dBpp_dq = zeroT; dAp_dq = zeroT; dBp_dq = zeroT
    dI_dp = T(NaN); dI_dq = T(NaN); Ixpq = T(NaN); Ixpqn = T(NaN)
    dI_dp_prev = T(NaN); dI_dq_prev = T(NaN)
    for n = 1:maxapp
        An, Bn, an, bn = _nextapp(f, p, q, n, App, Ap, Bpp, Bp)
        dan = dan_dp(p, q, f, n); dbn = dbn_dp(p, q, f, n)
        dAn_dp = _dnextapp(an, bn, dan, dbn, App, Ap, dApp_dp, dAp_dp)
        dBn_dp = _dnextapp(an, bn, dan, dbn, Bpp, Bp, dBpp_dp, dBp_dp)
        dan = dan_dq(p, q, f, n); dbn = dbn_dq(p, q, f, n)
        dAn_dq = _dnextapp(an, bn, dan, dbn, App, Ap, dApp_dq, dAp_dq)
        dBn_dq = _dnextapp(an, bn, dan, dbn, Bpp, Bp, dBpp_dq, dBp_dq)
        # normalize states to control growth/underflow (scale-invariant): the
        # recurrence is linear and only ratios of A and B enter the results
        s = maximum((abs(An), abs(Bn), abs(Ap), abs(Bp), abs(App), abs(Bpp)))
        if isfinite(s) && s > zeroT
            invs = inv(s)
            An *= invs; Bn *= invs
            Ap *= invs; Bp *= invs
            App *= invs; Bpp *= invs
            dAn_dp *= invs; dBn_dp *= invs
            dAn_dq *= invs; dBn_dq *= invs
            dAp_dp *= invs; dBp_dp *= invs
            dApp_dp *= invs; dBpp_dp *= invs
            dAp_dq *= invs; dBp_dq *= invs
            dApp_dq *= invs; dBpp_dq *= invs
        end
        Cn = An / Bn
        dI_dp = dK_dp_val * Cn + K * (inv(Bn) * dAn_dp - (An / (Bn^2)) * dBn_dp)
        dI_dq = dK_dq_val * Cn + K * (inv(Bn) * dAn_dq - (An / (Bn^2)) * dBn_dq)
        Ixpqn = K * Cn
        if n >= minapp
            # relative change of the value and of both derivatives between
            # consecutive approximants; eps(T) guards against a zero denominator
            denomI = max(abs(Ixpqn), abs(Ixpq), eps(T))
            denomp = max(abs(dI_dp), abs(dI_dp_prev), eps(T))
            denomq = max(abs(dI_dq), abs(dI_dq_prev), eps(T))
            rI = abs(Ixpqn - Ixpq) / denomI
            rp = abs(dI_dp - dI_dp_prev) / denomp
            rq = abs(dI_dq - dI_dq_prev) / denomq
            if max(rI, rp, rq) < ϵ
                break
            end
        end
        Ixpq = Ixpqn
        dI_dp_prev = dI_dp
        dI_dq_prev = dI_dq
        App = Ap; Bpp = Bp; Ap = An; Bp = Bn
        dApp_dp = dAp_dp; dApp_dq = dAp_dq; dBpp_dp = dBp_dp; dBpp_dq = dBp_dq
        dAp_dp = dAn_dp; dAp_dq = dAn_dq; dBp_dp = dBn_dp; dBp_dq = dBn_dq
    end
    if swap
        # after the swap p ≡ b and q ≡ a, and I_x(a, b) = 1 - I_{1-x}(b, a)
        return oneT - Ixpqn, -dI_dq, -dI_dp, dx
    else
        return Ixpqn, dI_dp, dI_dq, dx
    end
end

# Keyword-argument front end that keeps the earlier name and call signature and
# forwards to `_beta_inc_grad_boik`.
#
# The caller's settings are only ever tightened, never relaxed:
#   * the iteration cap is at least 1000,
#   * at least 5 approximants are computed before convergence is tested,
#   * the relative tolerance is at most 1e-14.
#
# Returns whatever `_beta_inc_grad_boik` returns: `(I, ∂I/∂a, ∂I/∂b, ∂I/∂x)`.
function _ibeta_grad_splus(a::T, b::T, x::T; maxapp::Int=200, minapp::Int=3, err::T=eps(T)*T(1e4)) where {T<:AbstractFloat}
    iter_cap = max(1000, maxapp)
    iter_floor = max(5, minapp)
    rtol = min(err, T(1e-14))
    return _beta_inc_grad_boik(a, b, x, iter_cap, iter_floor, rtol)
end



# Incomplete beta: beta_inc(a,b,x) -> (p, q) with q=1-p
function ChainRulesCore.frule((_, Δa, Δb, Δx), ::typeof(beta_inc), a::Number, b::Number, x::Number)
    # Forward-mode rule for the three-argument `beta_inc`.
    # Returns the primal pair together with a `Tangent` holding (ṗ, -ṗ); the
    # second component is the negative of the first because q = 1 - p.
    lower, upper = beta_inc(a, b, x)

    # All derivative arithmetic happens in the common float type of the inputs.
    T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x)))
    _, grad_a, grad_b, grad_x = _ibeta_grad_splus(T(a), T(b), T(x))
    da = convert(T, grad_a)::T
    db = convert(T, grad_b)::T
    dxx = convert(T, grad_x)::T

    # Input tangents that are not `Real` (e.g. zero/no-tangent markers) count as zero.
    as_T(δ) = δ isa Real ? T(δ) : zero(T)
    tan_a = as_T(Δa)
    tan_b = as_T(Δb)
    tan_x = as_T(Δx)

    ṗ = da * tan_a + db * tan_b + dxx * tan_x
    q̇ = -ṗ
    return (lower, upper), ChainRulesCore.Tangent{typeof((lower, upper))}(ṗ, q̇)
end

function ChainRulesCore.rrule(::typeof(beta_inc), a::Number, b::Number, x::Number)
    # Reverse-mode rule for the three-argument `beta_inc`.
    # The partial derivatives are evaluated once here and captured by the pullback.
    lower, upper = beta_inc(a, b, x)
    project_a = ChainRulesCore.ProjectTo(a)
    project_b = ChainRulesCore.ProjectTo(b)
    project_x = ChainRulesCore.ProjectTo(x)

    T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x)))
    _, grad_a, grad_b, grad_x = _ibeta_grad_splus(T(a), T(b), T(x))
    da = convert(T, grad_a)::T
    db = convert(T, grad_b)::T
    dxx = convert(T, grad_x)::T

    function beta_inc_pullback(Δ)
        Δlower, Δupper = Δ
        # q = 1 - p, so the cotangent of q enters with the opposite sign
        w = Δlower - Δupper
        return (
            ChainRulesCore.NoTangent(),
            project_a(w * da),
            project_b(w * db),
            project_x(w * dxx),
        )
    end
    return (lower, upper), beta_inc_pullback
end
# Forward-mode rule for the four-argument `beta_inc(a, b, x, y)`, where `y` is the
# complement of `x` (y = 1 - x).
#
# The derivatives are evaluated at `x` only; `y` is not passed to the gradient
# routine. Because `x` and `y` describe the same quantity, the x-sensitivity
# ∂I/∂x is split evenly between them: ∂p/∂x = ∂I/∂x / 2 and ∂p/∂y = -∂I/∂x / 2.
# With this split a perturbation that respects the constraint (Δy = -Δx) yields
# exactly ∂I/∂x * Δx. Assigning the full ±∂I/∂x to both arguments would count
# the same sensitivity twice once the chain rule is applied through y = 1 - x.
#
# Returns the primal pair `(p, q)` and a `Tangent` holding (Δp, -Δp).
function ChainRulesCore.frule((_, Δa, Δb, Δx, Δy), ::typeof(beta_inc), a::Number, b::Number, x::Number, y::Number)
    p, q = beta_inc(a, b, x, y)
    T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x)), float(typeof(y)))
    _, dIa_, dIb_, dIx_ = _ibeta_grad_splus(T(a), T(b), T(x))
    dIa::T = dIa_; dIb::T = dIb_; dIx::T = dIx_
    # tangents that are not `Real` (e.g. zero/no-tangent markers) count as zero
    ΔaT::T = Δa isa Real ? T(Δa) : zero(T)
    ΔbT::T = Δb isa Real ? T(Δb) : zero(T)
    ΔxT::T = Δx isa Real ? T(Δx) : zero(T)
    ΔyT::T = Δy isa Real ? T(Δy) : zero(T)
    half_dIx::T = dIx / T(2)
    Δp = dIa * ΔaT + dIb * ΔbT + half_dIx * (ΔxT - ΔyT)
    Δq = -Δp
    Tout = typeof((p, q))
    return (p, q), ChainRulesCore.Tangent{Tout}(Δp, Δq)
end

# Reverse-mode rule for the four-argument `beta_inc(a, b, x, y)`.
#
# Uses the same even split of ∂I/∂x between `x` and `y` as the forward rule, so
# that accumulating the cotangents through y = 1 - x gives s * ∂I/∂x rather than
# twice that amount.
function ChainRulesCore.rrule(::typeof(beta_inc), a::Number, b::Number, x::Number, y::Number)
    p, q = beta_inc(a, b, x, y)
    Ta = ChainRulesCore.ProjectTo(a)
    Tb = ChainRulesCore.ProjectTo(b)
    Tx = ChainRulesCore.ProjectTo(x)
    Ty = ChainRulesCore.ProjectTo(y)
    T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x)), float(typeof(y)))
    _, dIa_, dIb_, dIx_ = _ibeta_grad_splus(T(a), T(b), T(x))
    dIa::T = dIa_; dIb::T = dIb_; dIx::T = dIx_
    half_dIx::T = dIx / T(2)
    function beta_inc_pullback(Δ)
        Δp, Δq = Δ
        s = Δp - Δq # because q = 1 - p
        ā = Ta(s * dIa)
        b̄ = Tb(s * dIb)
        x̄ = Tx(s * half_dIx)
        ȳ = Ty(-s * half_dIx)
        return ChainRulesCore.NoTangent(), ā, b̄, x̄, ȳ
    end
    return (p, q), beta_inc_pullback
end

# Inverse incomplete beta: beta_inc_inv(a,b,p) -> (x, 1-x)
function ChainRulesCore.frule((_, Δa, Δb, Δp), ::typeof(beta_inc_inv), a::Number, b::Number, p::Number)
    # Forward-mode rule for the inverse: x solves I_x(a, b) = p, so by implicit
    # differentiation ∂x/∂a = -I_a / I_x, ∂x/∂b = -I_b / I_x and ∂x/∂p = 1 / I_x.
    x, y = beta_inc_inv(a, b, p)
    T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(p)))
    aT = T(a)
    bT = T(b)
    xT = T(x)

    # Parameter derivatives of I at the solved x.
    _, grad_a, grad_b, _ = _ibeta_grad_splus(aT, bT, xT)
    I_a = convert(T, grad_a)::T
    I_b = convert(T, grad_b)::T

    # I_x at the solved x, assembled in log space before exponentiating.
    log_I_x = muladd(aT - one(T), log(xT), muladd(bT - one(T), log1p(-xT), -logbeta(aT, bT)))
    I_x = convert(T, exp(log_I_x))::T
    recip_I_x = convert(T, inv(I_x))::T

    x_a = convert(T, -I_a * recip_I_x)::T
    x_b = convert(T, -I_b * recip_I_x)::T
    x_p = recip_I_x

    # Input tangents that are not `Real` (e.g. zero/no-tangent markers) count as zero.
    as_T(δ) = δ isa Real ? T(δ) : zero(T)
    tan_a = as_T(Δa)
    tan_b = as_T(Δb)
    tan_p = as_T(Δp)

    ẋ = x_a * tan_a + x_b * tan_b + x_p * tan_p
    ẏ = -ẋ
    return (x, y), ChainRulesCore.Tangent{typeof((x, y))}(ẋ, ẏ)
end

function ChainRulesCore.rrule(::typeof(beta_inc_inv), a::Number, b::Number, p::Number)
    # Reverse-mode rule for the inverse: x solves I_x(a, b) = p, so by implicit
    # differentiation ∂x/∂a = -I_a / I_x, ∂x/∂b = -I_b / I_x and ∂x/∂p = 1 / I_x.
    # The three partials are evaluated once here and captured by the pullback.
    x, y = beta_inc_inv(a, b, p)
    project_a = ChainRulesCore.ProjectTo(a)
    project_b = ChainRulesCore.ProjectTo(b)
    project_p = ChainRulesCore.ProjectTo(p)

    T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(p)))
    aT = T(a)
    bT = T(b)
    xT = T(x)
    _, grad_a, grad_b, _ = _ibeta_grad_splus(aT, bT, xT)
    I_a = convert(T, grad_a)::T
    I_b = convert(T, grad_b)::T

    # I_x at the solved x, assembled in log space before exponentiating.
    log_I_x = muladd(aT - one(T), log(xT), muladd(bT - one(T), log1p(-xT), -logbeta(aT, bT)))
    I_x = convert(T, exp(log_I_x))::T
    recip_I_x = convert(T, inv(I_x))::T

    x_a = convert(T, -I_a * recip_I_x)::T
    x_b = convert(T, -I_b * recip_I_x)::T
    x_p = recip_I_x

    function beta_inc_inv_pullback(Δ)
        Δx, Δy = Δ
        # y = 1 - x, so the cotangent of y enters with the opposite sign
        w = Δx - Δy
        return (
            ChainRulesCore.NoTangent(),
            project_a(w * x_a),
            project_b(w * x_b),
            project_p(w * x_p),
        )
    end
    return (x, y), beta_inc_inv_pullback
end

end # module
Loading