diff --git a/ext/SpecialFunctionsChainRulesCoreExt.jl b/ext/SpecialFunctionsChainRulesCoreExt.jl
index 169ee237..9db2ef26 100644
--- a/ext/SpecialFunctionsChainRulesCoreExt.jl
+++ b/ext/SpecialFunctionsChainRulesCoreExt.jl
@@ -300,4 +300,432 @@ function ChainRulesCore.rrule(::typeof(besselyx), ν::Number, x::Number)
     return Ω, besselyx_pullback
 end
 
+
+
+## Incomplete beta derivatives via Boik & Robinson-Cox
+#
+# Reference
+#   R. J. Boik and J. F. Robinson-Cox (1999).
+#   "Derivatives of the incomplete beta function."
+#   Journal of Statistical Software, 3(1).
+#   URL: https://www.jstatsoft.org/article/view/v003i01
+#
+# The following implementation computes the regularized incomplete beta
+# I_x(a,b) together with its partial derivatives with respect to a, b, and x
+# using a continued-fraction representation of ₂F₁ and differentiating through it.
+# This implementation is adapted from https://github.com/arzwa/IncBetaDer.jl.
+
+# Generic-typed helpers used by the continued-fraction evaluation of I_x(a,b)
+# and its partial derivatives. These implement the scalar prefactor K(x;p,q),
+# the auxiliary variable f, the continued-fraction coefficients a_n, b_n, and
+# their partial derivatives w.r.t. p (≡ a) and q (≡ b). See Boik & Robinson-Cox (1999).
+
+function _Kfun(x::T, p::T, q::T) where {T<:AbstractFloat}
+    # K(x;p,q) = x^p (1-x)^{q-1} / (p * B(p,q)) computed in log-space for stability
+    return exp(p * log(x) + (q - 1) * log1p(-x) - log(p) - logbeta(p, q))
+end
+
+function _ffun(x::T, p::T, q::T) where {T<:AbstractFloat}
+    # f = q x / (p (1-x)) — convenience variable appearing in CF coefficients
+    return q * x / (p * (1 - x))
+end
+
+function _a1fun(p::T, q::T, f::T) where {T<:AbstractFloat}
+    # a₁ coefficient of the continued fraction for ₂F₁ representation
+    return p * f * (q - 1) / (q * (p + 1))
+end
+
+function _anfun(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
+    # a_n coefficient (n ≥ 1) of the continued fraction for ₂F₁ in terms of p=a, q=b, f.
+    # For n=1, falls back to a₁; for n≥2 uses the closed-form product from the Gauss CF.
+    n == 1 && return _a1fun(p, q, f)
+    return p^2 * f^2 * (n - 1) * (p + q + n - 2) * (p + n - 1) * (q - n) /
+        (q^2 * (p + 2*n - 3) * (p + 2*n - 2)^2 * (p + 2*n - 1))
+end
+
+function _bnfun(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
+    # b_n coefficient (n ≥ 1) of the continued fraction. Derived for the same CF.
+    num = 2 * (p * f + 2 * q) * n^2 + 2 * (p * f + 2 * q) * (p - 1) * n +
+        p * q * (p - 2 - p * f)
+    den = q * (p + 2*n - 2) * (p + 2*n)
+    return num / den
+end
+
+function _dK_dp(x::T, p::T, q::T, K::T, ψpq::T, ψp::T) where {T<:AbstractFloat}
+    # ∂K/∂p using digamma identities: d/dp log B(p,q) = ψ(p) - ψ(p+q)
+    return K * (log(x) - inv(p) + ψpq - ψp)
+end
+
+function _dK_dq(x::T, p::T, q::T, K::T, ψpq::T, ψq::T) where {T<:AbstractFloat}
+    # ∂K/∂q by the same pattern: d/dq log B(p,q) = ψ(q) - ψ(p+q)
+    return K * (log1p(-x) + ψpq - ψq)
+end
+
+function _dK_dpdq(x::T, p::T, q::T) where {T<:AbstractFloat}
+    # Convenience: compute (∂K/∂p, ∂K/∂q) together with shared ψ(p+q)
+    ψ = digamma(p + q)
+    Kf = _Kfun(x, p, q)
+    dKdp = _dK_dp(x, p, q, Kf, ψ, digamma(p))
+    dKdq = _dK_dq(x, p, q, Kf, ψ, digamma(q))
+    return dKdp, dKdq
+end
+
+function _da1_dp(p::T, q::T, f::T) where {T<:AbstractFloat}
+    # ∂a₁/∂p from the closed form of a₁ (p*f = q*x/(1-x) does not depend on p)
+    return -_a1fun(p, q, f) / (p + 1)
+end
+
+function _dan_dp(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
+    # ∂a_n/∂p via log-derivative: d a_n = a_n * d log a_n; for n=1, uses ∂a₁/∂p
+    if n == 1
+        return _da1_dp(p, q, f)
+    end
+    an = _anfun(p, q, f, n)
+    dlog = inv(p + q + n - 2) + inv(p + n - 1) - inv(p + 2*n - 3) -
+        2 * inv(p + 2*n - 2) - inv(p + 2*n - 1)
+    return an * dlog
+end
+
+function _da1_dq(p::T, q::T, f::T) where {T<:AbstractFloat}
+    # ∂a₁/∂q. With p*f/q = x/(1-x) independent of q, a₁ = (p*f/q) * (q - 1) / (p + 1)
+    # is linear in q, so the derivative is the slope (p*f/q) / (p + 1). The
+    # algebraically equal a₁ / (q - 1) is 0/0 = NaN at q == 1, which made every
+    # derivative NaN whenever b == 1 (or a == 1 after the tail swap).
+    return p * f / (q * (p + 1))
+end
+
+function _dan_dq(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
+    # ∂a_n/∂q avoiding the removable singularity at q ≈ n for integer q.
+    # For n=1, defer to the specific a₁ derivative.
+    if n == 1
+        return _da1_dq(p, q, f)
+    end
+    # Use the simplified closed-form of a_n that eliminates explicit q^2 via f:
+    #   a_n = (x/(1-x))^2 * (n-1) * (p+n-1) * (p+q+n-2) * (q-n) / D(p,n)
+    # where D(p,n) = (p+2n-3)*(p+2n-2)^2*(p+2n-1) and (x/(1-x)) = p*f/q.
+    # Differentiate only the q-dependent factor G(q) = (p+q+n-2)*(q-n):
+    #   dG/dq = (q-n) + (p+q+n-2) = p + 2q - 2.
+    #
+    # This is equivalent to
+    #   _anfun(p,q,f,n) * (inv(p+q+n-2) + inv(q-n))
+    # but stays finite at q == n.
+    pfq = (p * f) / q
+    C = (pfq * pfq) * (n - 1) * (p + n - 1) /
+        ((p + 2*n - 3) * (p + 2*n - 2)^2 * (p + 2*n - 1))
+    return C * (p + 2*q - 2)
+end
+
+function _dbn_dp(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
+    # ∂b_n/∂p via quotient rule on b_n = N/D.
+    # p*f = q*x/(1-x) does not depend on p, so g below is constant in p.
+    g = p * f + 2 * q
+    A = 2 * n^2 + 2 * (p - 1) * n
+    N1 = g * A
+    N2 = p * q * (p - 2 - p * f)
+    N = N1 + N2
+    D = q * (p + 2*n - 2) * (p + 2*n)
+    dN1_dp = 2 * n * g
+    dN2_dp = q * (2 * p - 2) - p * q * f
+    dN_dp = dN1_dp + dN2_dp
+    dD_dp = q * (2 * p + 4 * n - 2)
+    return (dN_dp * D - N * dD_dp) / (D^2)
+end
+
+function _dbn_dq(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat}
+    # ∂b_n/∂q similarly via quotient rule; here ∂(p*f)/∂q = p*f/q.
+    g = p * f + 2 * q
+    A = 2 * n^2 + 2 * (p - 1) * n
+    N1 = g * A
+    N2 = p * q * (p - 2 - p * f)
+    N = N1 + N2
+    D = q * (p + 2*n - 2) * (p + 2*n)
+    g_q = p * (f / q) + 2
+    dN1_dq = g_q * A
+    dN2_dq = p * (p - 2 - p * f) - p^2 * f
+    dN_dq = dN1_dq + dN2_dq
+    dD_dq = (p + 2*n - 2) * (p + 2*n)
+    return (dN_dq * D - N * dD_dq) / (D^2)
+end
+
+function _nextapp(f::T, p::T, q::T, n::Int, App::T, Ap::T, Bpp::T, Bp::T) where {T<:AbstractFloat}
+    # One step of the continuant recurrences:
+    #   A_n = a_n A_{n-2} + b_n A_{n-1}
+    #   B_n = a_n B_{n-2} + b_n B_{n-1}
+    an = _anfun(p, q, f, n)
+    bn = _bnfun(p, q, f, n)
+    An = an * App + bn * Ap
+    Bn = an * Bpp + bn * Bp
+    return An, Bn, an, bn
+end
+
+function _dnextapp(an::T, bn::T, dan::T, dbn::T, Xpp::T, Xp::T, dXpp::T, dXp::T) where {T<:AbstractFloat}
+    # Derivative propagation for the same recurrences (X∈{A,B})
+    return dan * Xpp + an * dXpp + dbn * Xp + bn * dXp
+end
+
+function _beta_inc_grad(a::T, b::T, x::T; maxapp::Int=200, minapp::Int=3, err::T=eps(T)*T(1e4)) where {T<:AbstractFloat}
+    # Compute I_x(a,b) and partial derivatives (∂I/∂a, ∂I/∂b, ∂I/∂x)
+    # using a differentiated continued fraction with convergence control.
+    # Returns the tuple (I, ∂I/∂a, ∂I/∂b, ∂I/∂x).
+    oneT = one(T)
+    zeroT = zero(T)
+
+    # 1) Boundary cases for x
+    x == oneT && return oneT, zeroT, zeroT, zeroT
+    x == zeroT && return zeroT, zeroT, zeroT, zeroT
+
+    # 2) Clamp iteration/tolerance parameters to robust defaults. The tolerance is
+    #    capped at 1e-14 but never drops below a few ulps of T: for Float32 a 1e-14
+    #    relative change can only be met by exact equality of successive iterates.
+    ϵ = min(err, max(T(1e-14), 4 * eps(T)))
+    maxapp = max(1000, maxapp)
+    minapp = max(5, minapp)
+    tiny = sqrt(eps(T))  # floor for |B_n| when forming A_n / B_n
+
+    # 3) Non-boundary path: precompute ∂I/∂x at original (a,b,x) via stable log form
+    dx = exp((a - oneT) * log(x) + (b - oneT) * log1p(-x) - logbeta(a, b))
+
+    # 4) Optional tail-swap for symmetry and improved CF convergence:
+    #    if x > a/(a+b), evaluate at (p,q,x₀) = (b,a,1-x) and swap back at the end.
+    p = a
+    q = b
+    x₀ = x
+    swap = false
+    if x > a / (a + b)
+        x₀ = oneT - x
+        p = b
+        q = a
+        swap = true
+    end
+
+    # 5) Initialize CF state and derivatives
+    K = _Kfun(x₀, p, q)
+    dK_dp_val, dK_dq_val = _dK_dpdq(x₀, p, q)
+    f = _ffun(x₀, p, q)
+    App = oneT
+    Ap = oneT
+    Bpp = zeroT
+    Bp = oneT
+    dApp_dp = zeroT
+    dBpp_dp = zeroT
+    dAp_dp = zeroT
+    dBp_dp = zeroT
+    dApp_dq = zeroT
+    dBpp_dq = zeroT
+    dAp_dq = zeroT
+    dBp_dq = zeroT
+    dI_dp = T(NaN)
+    dI_dq = T(NaN)
+    Ixpq = T(NaN)
+    Ixpqn = T(NaN)
+    dI_dp_prev = T(NaN)
+    dI_dq_prev = T(NaN)
+
+    # 6) Main CF loop (n from 1): update continuants, scale, form current approximant Cn=A_n/B_n
+    #    and its derivatives to update I and ∂I/∂(p,q). Stop on relative convergence of all.
+    for n in 1:maxapp
+        # Update continuants.
+        An, Bn, an, bn = _nextapp(f, p, q, n, App, Ap, Bpp, Bp)
+        dan = _dan_dp(p, q, f, n)
+        dbn = _dbn_dp(p, q, f, n)
+        dAn_dp = _dnextapp(an, bn, dan, dbn, App, Ap, dApp_dp, dAp_dp)
+        dBn_dp = _dnextapp(an, bn, dan, dbn, Bpp, Bp, dBpp_dp, dBp_dp)
+        dan = _dan_dq(p, q, f, n)
+        dbn = _dbn_dq(p, q, f, n)
+        dAn_dq = _dnextapp(an, bn, dan, dbn, App, Ap, dApp_dq, dAp_dq)
+        dBn_dq = _dnextapp(an, bn, dan, dbn, Bpp, Bp, dBpp_dq, dBp_dq)
+
+        # Normalize states to control growth/underflow (scale-invariant transform)
+        s = maximum((abs(An), abs(Bn), abs(Ap), abs(Bp), abs(App), abs(Bpp)))
+        if isfinite(s) && s > zeroT
+            invs = inv(s)
+            An *= invs
+            Bn *= invs
+            Ap *= invs
+            Bp *= invs
+            App *= invs
+            Bpp *= invs
+            dAn_dp *= invs
+            dBn_dp *= invs
+            dAn_dq *= invs
+            dBn_dq *= invs
+            dAp_dp *= invs
+            dBp_dp *= invs
+            dApp_dp *= invs
+            dBpp_dp *= invs
+            dAp_dq *= invs
+            dBp_dq *= invs
+            dApp_dq *= invs
+            dBpp_dq *= invs
+        end
+
+        # Form current approximant Cn=A_n/B_n and its derivatives.
+        # Guard against tiny/zero Bn to avoid NaNs/Inf in divisions.
+        absBn = abs(Bn)
+        sgnBn = ifelse(Bn >= zeroT, oneT, -oneT)
+        invBn = absBn > tiny && isfinite(absBn) ? inv(Bn) : inv(sgnBn * tiny)
+        Cn = An * invBn
+        invBn2 = invBn * invBn
+        dI_dp = dK_dp_val * Cn + K * (invBn * dAn_dp - (An * invBn2) * dBn_dp)
+        dI_dq = dK_dq_val * Cn + K * (invBn * dAn_dq - (An * invBn2) * dBn_dq)
+        Ixpqn = K * Cn
+
+        # Decide convergence:
+        if n >= minapp
+            # Relative convergence for I, ∂I/∂p, ∂I/∂q (guards against tiny denominators)
+            denomI = max(abs(Ixpqn), abs(Ixpq), eps(T))
+            denomp = max(abs(dI_dp), abs(dI_dp_prev), eps(T))
+            denomq = max(abs(dI_dq), abs(dI_dq_prev), eps(T))
+            rI = abs(Ixpqn - Ixpq) / denomI
+            rp = abs(dI_dp - dI_dp_prev) / denomp
+            rq = abs(dI_dq - dI_dq_prev) / denomq
+            if max(rI, rp, rq) < ϵ
+                break
+            end
+        end
+        Ixpq = Ixpqn
+        dI_dp_prev = dI_dp
+        dI_dq_prev = dI_dq
+
+        # Shift CF state for next iteration
+        App = Ap
+        Bpp = Bp
+        Ap = An
+        Bp = Bn
+        dApp_dp = dAp_dp
+        dApp_dq = dAp_dq
+        dBpp_dp = dBp_dp
+        dBpp_dq = dBp_dq
+        dAp_dp = dAn_dp
+        dAp_dq = dAn_dq
+        dBp_dp = dBn_dp
+        dBp_dq = dBn_dq
+    end
+
+    # 7) Undo tail-swap if applied; ∂I/∂x is the pdf at original (a,b,x)
+    if swap
+        return oneT - Ixpqn, -dI_dq, -dI_dp, dx
+    else
+        return Ixpqn, dI_dp, dI_dq, dx
+    end
+end
+
+# Convert a forward-mode tangent to `T`. Non-real tangents (e.g. `NoTangent`,
+# `ZeroTangent`) carry no perturbation and map to zero.
+_beta_tangent(::Type{T}, Δ::Real) where {T<:AbstractFloat} = T(Δ)
+_beta_tangent(::Type{T}, Δ) where {T<:AbstractFloat} = zero(T)
+
+# Collapse the cotangent (Δu, Δv) of a complementary output pair (u, v = 1 - u)
+# into a single cotangent for u. Zero cotangents pass through unchanged.
+_beta_pair_cotangent(Δ::ChainRulesCore.AbstractZero) = Δ
+function _beta_pair_cotangent(Δ)
+    Δu, Δv = Δ
+    return Δu - Δv
+end
+
+# Partial derivatives (∂x/∂a, ∂x/∂b, ∂x/∂p) of x = beta_inc_inv(a, b, p), obtained by
+# implicit differentiation of I_x(a,b) = p at the solved x.
+function _beta_inc_inv_partials(a::T, b::T, x::T) where {T<:AbstractFloat}
+    _, dIa, dIb, _ = _beta_inc_grad(a, b, x)
+    # ∂I/∂x at solved x via stable log-space expression
+    dIx = exp(muladd(a - one(T), log(x), muladd(b - one(T), log1p(-x), -logbeta(a, b))))
+    inv_dIx = inv(dIx)
+    return -dIa * inv_dIx, -dIb * inv_dIx, inv_dIx
+end
+
+# Incomplete beta: beta_inc(a,b,x) -> (p, q) with q=1-p
+function ChainRulesCore.frule((_, Δa, Δb, Δx), ::typeof(beta_inc), a::Number, b::Number, x::Number)
+    # primal
+    p, q = beta_inc(a, b, x)
+    # derivatives
+    T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x)))
+    _, dIa, dIb, dIx = _beta_inc_grad(T(a), T(b), T(x))
+    Δp = dIa * _beta_tangent(T, Δa) + dIb * _beta_tangent(T, Δb) +
+        dIx * _beta_tangent(T, Δx)
+    Δq = -Δp
+    Tout = typeof((p, q))
+    return (p, q), ChainRulesCore.Tangent{Tout}(Δp, Δq)
+end
+
+function ChainRulesCore.rrule(::typeof(beta_inc), a::Number, b::Number, x::Number)
+    p, q = beta_inc(a, b, x)
+    Ta = ChainRulesCore.ProjectTo(a)
+    Tb = ChainRulesCore.ProjectTo(b)
+    Tx = ChainRulesCore.ProjectTo(x)
+    T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x)))
+    _, dIa, dIb, dIx = _beta_inc_grad(T(a), T(b), T(x))
+    function beta_inc_pullback(Δ)
+        s = _beta_pair_cotangent(ChainRulesCore.unthunk(Δ))  # because q = 1 - p
+        ā = Ta(s * dIa)
+        b̄ = Tb(s * dIb)
+        x̄ = Tx(s * dIx)
+        return ChainRulesCore.NoTangent(), ā, b̄, x̄
+    end
+    return (p, q), beta_inc_pullback
+end
+
+# 4-argument form beta_inc(a,b,x,y), meaningful only on the constraint y = 1 - x.
+# x and y are therefore not independent inputs: the sensitivity is attributed to x
+# alone and y is treated as its redundant complement. Assigning (+pdf, -pdf) to (x, y)
+# would double the derivative of beta_inc(a, b, x, 1 - x) relative to the 3-argument method.
+function ChainRulesCore.frule((_, Δa, Δb, Δx, Δy), ::typeof(beta_inc), a::Number, b::Number, x::Number, y::Number)
+    p, q = beta_inc(a, b, x, y)
+    T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x)), float(typeof(y)))
+    _, dIa, dIb, dIx = _beta_inc_grad(T(a), T(b), T(x))
+    Δp = dIa * _beta_tangent(T, Δa) + dIb * _beta_tangent(T, Δb) +
+        dIx * _beta_tangent(T, Δx)
+    Δq = -Δp
+    Tout = typeof((p, q))
+    return (p, q), ChainRulesCore.Tangent{Tout}(Δp, Δq)
+end
+
+function ChainRulesCore.rrule(::typeof(beta_inc), a::Number, b::Number, x::Number, y::Number)
+    p, q = beta_inc(a, b, x, y)
+    Ta = ChainRulesCore.ProjectTo(a)
+    Tb = ChainRulesCore.ProjectTo(b)
+    Tx = ChainRulesCore.ProjectTo(x)
+    Ty = ChainRulesCore.ProjectTo(y)
+    T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x)), float(typeof(y)))
+    _, dIa, dIb, dIx = _beta_inc_grad(T(a), T(b), T(x))
+    function beta_inc_pullback(Δ)
+        s = _beta_pair_cotangent(ChainRulesCore.unthunk(Δ))
+        ā = Ta(s * dIa)
+        b̄ = Tb(s * dIb)
+        x̄ = Tx(s * dIx)
+        # y is the redundant complement of x (see the frule above): no cotangent.
+        ȳ = Ty(zero(T))
+        return ChainRulesCore.NoTangent(), ā, b̄, x̄, ȳ
+    end
+    return (p, q), beta_inc_pullback
+end
+
+# Inverse incomplete beta: beta_inc_inv(a,b,p) -> (x, 1-x)
+function ChainRulesCore.frule((_, Δa, Δb, Δp), ::typeof(beta_inc_inv), a::Number, b::Number, p::Number)
+    x, y = beta_inc_inv(a, b, p)
+    T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(p)))
+    dx_da, dx_db, dx_dp = _beta_inc_inv_partials(T(a), T(b), T(x))
+    Δx = dx_da * _beta_tangent(T, Δa) + dx_db * _beta_tangent(T, Δb) +
+        dx_dp * _beta_tangent(T, Δp)
+    Δy = -Δx
+    Tout = typeof((x, y))
+    return (x, y), ChainRulesCore.Tangent{Tout}(Δx, Δy)
+end
+
+function ChainRulesCore.rrule(::typeof(beta_inc_inv), a::Number, b::Number, p::Number)
+    x, y = beta_inc_inv(a, b, p)
+    Ta = ChainRulesCore.ProjectTo(a)
+    Tb = ChainRulesCore.ProjectTo(b)
+    Tp = ChainRulesCore.ProjectTo(p)
+    T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(p)))
+    dx_da, dx_db, dx_dp = _beta_inc_inv_partials(T(a), T(b), T(x))
+    function beta_inc_inv_pullback(Δ)
+        s = _beta_pair_cotangent(ChainRulesCore.unthunk(Δ))  # because y = 1 - x
+        ā = Ta(s * dx_da)
+        b̄ = Tb(s * dx_db)
+        p̄ = Tp(s * dx_dp)
+        return ChainRulesCore.NoTangent(), ā, b̄, p̄
+    end
+    return (x, y), beta_inc_inv_pullback
+end
+
 end # module
diff --git a/test/chainrules.jl b/test/chainrules.jl
index 1754d591..1b012976 100644
--- a/test/chainrules.jl
+++ b/test/chainrules.jl
@@ -176,4 +176,169 @@
         _, x̄ = back(1f0)
         @test x̄ isa Float32
     end
+
+    @testset "beta_inc and beta_inc_inv" begin
+        @testset "beta_inc and beta_inc_inv minimal (no-FD identities)" begin
+            a = 1.2
+            b = 2.3
+            x = 0.4
+            # Direct derivative checks without FD: ∂I/∂x equals beta pdf
+            pdf = x^(a - 1) * (1 - x)^(b - 1) / beta(a, b)
+            _, Δx = frule((NoTangent(), 0.0, 0.0, 1.0), beta_inc, a, b, x)
+            @test isapprox(Δx[1], pdf; rtol=1e-12, atol=1e-12)
+
+            # Symmetry check: ∂I/∂a(a,b,x) = -∂I/∂b(b,a,1-x)
+            _, Δa = frule((NoTangent(), 1.0, 0.0, 0.0), beta_inc, a, b, x)
+            _, Δb_sw = frule((NoTangent(), 0.0, 1.0, 0.0), beta_inc, b, a, 1 - x)
+            @test isapprox(Δa[1], -Δb_sw[1]; rtol=1e-10, atol=1e-12)
+
+            # Composition identity f(g(p)) = p: forward-mode differential equals 1 for dp, 0 for da,db
+            p = first(beta_inc(a, b, x))
+            x_inv, _ = beta_inc_inv(a, b, p)
+            # Check primal composition
+            p_roundtrip = first(beta_inc(a, b, x_inv))
+            @test isapprox(p_roundtrip, p; rtol=1e-12, atol=1e-12)
+            # Forward through g then f: dp
+            _, Δx_inv_dp = frule((NoTangent(), 0.0, 0.0, 1.0), beta_inc_inv, a, b, p)
+            _, Δp_from_dp = frule((NoTangent(), 0.0, 0.0, Δx_inv_dp[1]), beta_inc, a, b, x_inv)
+            @test isapprox(Δp_from_dp[1], 1.0; rtol=1e-9, atol=1e-12)
+            # Forward da
+            _, Δx_inv_da = frule((NoTangent(), 1.0, 0.0, 0.0), beta_inc_inv, a, b, p)
+            _, Δp_from_da = frule((NoTangent(), 1.0, 0.0, Δx_inv_da[1]), beta_inc, a, b, x_inv)
+            @test isapprox(Δp_from_da[1], 0.0; rtol=1e-9, atol=1e-12)
+            # Forward db
+            _, Δx_inv_db = frule((NoTangent(), 0.0, 1.0, 0.0), beta_inc_inv, a, b, p)
+            _, Δp_from_db = frule((NoTangent(), 0.0, 1.0, Δx_inv_db[1]), beta_inc, a, b, x_inv)
+            @test isapprox(Δp_from_db[1], 0.0; rtol=1e-9, atol=1e-12)
+
+            # Reverse-mode chain for composition: pullback through f then g
+            # Pullback of f at (a,b,x_inv)
+            _, pb_f = rrule(beta_inc, a, b, x_inv)
+            _, āf, b̄f, x̄f = pb_f((1.0, 0.0))
+            # Pullback of g at (a,b,p) with cotangent x̄f for x
+            _, pb_g = rrule(beta_inc_inv, a, b, p)
+            _, āg, b̄g, p̄g = pb_g((x̄f, 0.0))
+            ā_total = āf + āg
+            b̄_total = b̄f + b̄g
+            p̄_total = p̄g
+            @test isapprox(ā_total, 0.0; rtol=1e-10, atol=1e-12)
+            @test isapprox(b̄_total, 0.0; rtol=1e-10, atol=1e-12)
+            @test isapprox(p̄_total, 1.0; rtol=1e-9, atol=1e-12)
+        end
+
+        @testset "beta_inc with a unit shape parameter" begin
+            # Closed forms: I_x(a,1) = x^a and I_x(1,b) = 1 - (1-x)^b. Both cases reach
+            # q == 1 in the continued fraction, where ∂a₁/∂q must stay finite.
+            a = 2.0
+            b = 2.0
+            x = 0.3
+            _, Δ = frule((NoTangent(), 1.0, 0.0, 0.0), beta_inc, a, 1.0, x)
+            @test isapprox(Δ[1], x^a * log(x); rtol=1e-10)
+            _, Δ = frule((NoTangent(), 0.0, 0.0, 1.0), beta_inc, a, 1.0, x)
+            @test isapprox(Δ[1], a * x^(a - 1); rtol=1e-12)
+            _, Δ = frule((NoTangent(), 0.0, 1.0, 0.0), beta_inc, 1.0, b, 1 - x)
+            @test isapprox(Δ[1], -x^b * log(x); rtol=1e-10)
+            _, pb = rrule(beta_inc, a, 1.0, x)
+            @test all(isfinite, pb((1.0, 0.0))[2:end])
+        end
+
+        @testset "incomplete beta: basic test_frule/test_rrule" begin
+            # Use an expanded set of interior points (avoid endpoints for FD) to exercise many branches:
+            # Rationale for x values:
+            #  - Include values around 0.1, 0.3, 0.5, 0.7, 0.9 to trigger different code paths.
+            #  - Include 0.14 and 0.28 to straddle the bx ≤ 0.7 power-series threshold for b ≈ 5 and 2.5.
+            #  - Include values near 0.5 (0.49, 0.51) to probe near-symmetry and tail swaps.
+            #  - Include additional midpoints to increase chance that x ≈ a/(a+b) for some (a,b), which makes λ ≈ 0
+            #    in the large-parameter regime (key for choosing symmetric asymptotics when min(a,b) > 100).
+            #  - Add a few more around 0.6–0.8 to exercise continued fraction vs. asymptotics for large (a,b).
+            test_points = (
+                0.05, 0.08, 0.10, 0.12, 0.14, 0.18, 0.20, 0.22, 0.26,
+                0.28, 0.30, 0.32, 0.35, 0.38, 0.40, 0.42, 0.45,
+                0.49, 0.50, 0.51, 0.55, 0.58, 0.60, 0.62, 0.65,
+                0.68, 0.70, 0.72, 0.76, 0.80, 0.85, 0.90
+            )
+            # Rationale for a,b values:
+            #  - <1: 0.4, 0.6 to stress small-parameter power series branches.
+            #  - Near 1: 0.9, 1.1 to test branch boundaries and continuity across a≈1, b≈1.
+            #  - Moderate: 2.5, 5.0 where multiple algorithm choices engage based on x and bx.
+            #  - Large (≥15, ≥40) to drive large-parameter regimes: 16.0, 45.0.
+            #  - Very large (≫100): 100.5, 150.0 to ensure symmetric vs asymmetric asymptotics are exercised when λ
+            #    is small/large, and continued fractions are robust for large shapes.
+            ab = (0.4, 0.6, 0.9, 1.1, 2.5, 5.0, 16.0, 45.0, 100.5, 150.0)
+
+            # 3-argument beta_inc(a,b,x)
+            for a in ab, b in ab, x in test_points
+                0.0 < x < 1.0 || continue
+                test_frule(beta_inc, a, b, x)
+                test_rrule(beta_inc, a, b, x)
+            end
+
+            # Inverse beta: beta_inc_inv(a,b,p)
+            for a in ab, b in ab, p in test_points
+                0.0 < p < 1.0 || continue
+                test_frule(beta_inc_inv, a, b, p)
+                test_rrule(beta_inc_inv, a, b, p)
+            end
+
+            # Float32 promotion sanity (lightweight)
+            a32 = 1.5f0; b32 = 2.25f0; x32 = 0.3f0
+            # Finite-difference checks for Float32 are noisier; use looser tolerances
+            test_frule(beta_inc, a32, b32, x32; rtol=5e-4, atol=1e-6)
+            test_rrule(beta_inc, a32, b32, x32; rtol=5e-4, atol=1e-6)
+            p32 = first(beta_inc(a32, b32, x32))
+            test_frule(beta_inc_inv, a32, b32, p32; rtol=5e-4, atol=1e-6)
+            test_rrule(beta_inc_inv, a32, b32, p32; rtol=5e-4, atol=1e-6)
+        end
+
+        @testset "4-arg beta_inc identities (y = 1 - x)" begin
+            # Exercise more regimes while keeping y = 1 - x constraint.
+            # Same rationale as above for x and (a,b) coverage.
+            test_points = (
+                0.05, 0.10, 0.12, 0.14, 0.20, 0.28, 0.35, 0.40, 0.49, 0.50,
+                0.51, 0.60, 0.65, 0.70, 0.72, 0.80, 0.90
+            )
+            ab = (0.4, 0.6, 0.9, 1.1, 2.5, 5.0, 16.0, 45.0, 100.5, 150.0)
+
+            for a in ab, b in ab, x in test_points
+                0.0 < x < 1.0 || continue
+                y = 1 - x
+                # Primal consistency: 4-arg matches 3-arg when y = 1 - x
+                p3, q3 = beta_inc(a, b, x)
+                p4, q4 = beta_inc(a, b, x, y)
+                @test isapprox(p4, p3; rtol=1e-12, atol=1e-12)
+                @test isapprox(q4, q3; rtol=1e-12, atol=1e-12)
+
+                # Analytical pdf
+                pdf = x^(a - 1) * (1 - x)^(b - 1) / beta(a, b)
+
+                # Constrained x-variation: dx = 1, dy = -1 moves along y = 1 - x, so
+                # dp must equal the pdf (same as the 3-arg method) and dq = -dp.
+                _, Δxy = frule((NoTangent(), 0.0, 0.0, 1.0, -1.0), beta_inc, a, b, x, y)
+                @test isapprox(Δxy[1], pdf; rtol=1e-11, atol=1e-12)
+                @test isapprox(Δxy[2], -Δxy[1]; rtol=1e-11, atol=1e-12)
+
+                # Parameter derivatives should match 3-arg ones
+                _, Δa3 = frule((NoTangent(), 1.0, 0.0, 0.0), beta_inc, a, b, x)
+                _, Δb3 = frule((NoTangent(), 0.0, 1.0, 0.0), beta_inc, a, b, x)
+                _, Δa4 = frule((NoTangent(), 1.0, 0.0, 0.0, 0.0), beta_inc, a, b, x, y)
+                _, Δb4 = frule((NoTangent(), 0.0, 1.0, 0.0, 0.0), beta_inc, a, b, x, y)
+                @test isapprox(Δa4[1], Δa3[1]; rtol=1e-11, atol=1e-12)
+                @test isapprox(Δb4[1], Δb3[1]; rtol=1e-11, atol=1e-12)
+
+                # Reverse-mode: compare pullbacks for 3-arg vs 4-arg
+                _, pb3 = rrule(beta_inc, a, b, x)
+                _, ā3, b̄3, x̄3 = pb3((1.0, 0.0))
+                _, pb4 = rrule(beta_inc, a, b, x, y)
+                _, ā4, b̄4, x̄4, ȳ4 = pb4((1.0, 0.0))
+                @test isapprox(ā4, ā3; rtol=1e-11, atol=1e-12)
+                @test isapprox(b̄4, b̄3; rtol=1e-11, atol=1e-12)
+                # x carries the whole sensitivity; its complement y receives none
+                @test isapprox(x̄4, x̄3; rtol=1e-11, atol=1e-12)
+                @test iszero(ȳ4)
+                # Effective pullback along the constraint y = 1 - x equals x̄3
+                x̄_eff = x̄4 - ȳ4
+                @test isapprox(x̄_eff, x̄3; rtol=1e-11, atol=1e-12)
+            end
+        end
+    end
 end