From a9b1790008b6482748b38636add862af1d1488ab Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 17 Sep 2025 00:20:35 +0200 Subject: [PATCH] Add sufficient statistics and MLE for `Chi` and `Chisq` --- Project.toml | 2 +- docs/src/fit.md | 2 + src/univariate/continuous/chi.jl | 21 +++++++++- src/univariate/continuous/chisq.jl | 20 +++++++++- test/fit.jl | 62 ++++++++++++++++++++++++++++++ 5 files changed, 104 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 033930d7cc..5adbe274c7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Distributions" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" authors = ["JuliaStats"] -version = "0.25.120" +version = "0.25.121" [deps] AliasTables = "66dad0bd-aa9a-41b7-9441-69ab47430ed8" diff --git a/docs/src/fit.md b/docs/src/fit.md index b482f995e6..36a67799cd 100644 --- a/docs/src/fit.md +++ b/docs/src/fit.md @@ -43,6 +43,8 @@ The `fit_mle` method has been implemented for the following distributions: - [`Beta`](@ref) - [`Binomial`](@ref) - [`Categorical`](@ref) +- [`Chi`](@ref) +- [`Chisq`](@ref) - [`DiscreteUniform`](@ref) - [`Exponential`](@ref) - [`LogNormal`](@ref) diff --git a/src/univariate/continuous/chi.jl b/src/univariate/continuous/chi.jl index 8fc7310038..0cdbdfbc00 100644 --- a/src/univariate/continuous/chi.jl +++ b/src/univariate/continuous/chi.jl @@ -23,7 +23,7 @@ External links """ struct Chi{T<:Real} <: ContinuousUnivariateDistribution ν::T - Chi{T}(ν::T) where {T} = new{T}(ν) + Chi{T}(ν::Real) where {T<:Real} = new{T}(ν) end function Chi(ν::Real; check_args::Bool=true) @@ -119,3 +119,22 @@ end rand(rng::AbstractRNG, s::ChiSampler) = sqrt(rand(rng, s.s)) sampler(d::Chi) = ChiSampler(sampler(Chisq(d.ν))) + + +#### Fitting + +struct ChiStats{T<:Real} <: SufficientStats + # (Weighted) mean of log(x) + mlogx::T +end + +suffstats(::Type{<:Chi}, x::AbstractArray{<:Real}) = ChiStats(mean(log, x)) +function suffstats(::Type{<:Chi}, x::AbstractArray{<:Real}, w::AbstractArray{<:Real}) + if axes(x) != axes(w) + throw(DimensionMismatch("Inconsistent array dimensions: Axes of samples and sample weights must be equal.")) + end + mlogx = sum(Broadcast.instantiate(Broadcast.broadcasted(xlogy, w, x))) / sum(w) + return ChiStats(mlogx) +end + +fit_mle(::Type{T}, ss::ChiStats) where {T<:Chi} = T(2 * invdigamma(2 * ss.mlogx - logtwo)) diff --git a/src/univariate/continuous/chisq.jl b/src/univariate/continuous/chisq.jl index 8604742ce5..a7180c9200 100644 --- a/src/univariate/continuous/chisq.jl +++ b/src/univariate/continuous/chisq.jl @@ -22,7 +22,7 @@ External links """ struct Chisq{T<:Real} <: ContinuousUnivariateDistribution ν::T - Chisq{T}(ν::T) where {T} = new{T}(ν) + Chisq{T}(ν::Real) where {T<:Real} = new{T}(ν) end function Chisq(ν::Real; check_args::Bool=true) @@ -107,3 +107,21 @@ function sampler(d::Chisq) θ = oftype(α, 2) return sampler(Gamma{typeof(α)}(α, θ)) end + +#### Fitting + +struct ChisqStats{T<:Real} <: SufficientStats + # (Weighted) mean of log(x) + mlogx::T +end + +suffstats(::Type{<:Chisq}, x::AbstractArray{<:Real}) = ChisqStats(mean(log, x)) +function suffstats(::Type{<:Chisq}, x::AbstractArray{<:Real}, w::AbstractArray{<:Real}) + if axes(x) != axes(w) + throw(DimensionMismatch("Inconsistent array dimensions: Axes of samples and sample weights must be equal.")) + end + mlogx = sum(Broadcast.instantiate(Broadcast.broadcasted(xlogy, w, x))) / sum(w) + return ChisqStats(mlogx) +end + +fit_mle(::Type{T}, ss::ChisqStats) where {T<:Chisq} = T(2 * invdigamma(ss.mlogx - logtwo)) diff --git a/test/fit.jl b/test/fit.jl index 143f0b6ee4..317a60b669 100644 --- a/test/fit.jl +++ b/test/fit.jl @@ -6,6 +6,7 @@ using Distributions using OffsetArrays +using ForwardDiff using Test, Random, LinearAlgebra @@ -465,3 +466,64 @@ end end end + +@testset "Testing fit for Chi" begin + ν = 3.1 + for func in funcs, D in (Chi, Chi{Float64}, Chi{Float32}) + v = func[1](n0) + z = func[2](D(ν), n0) + for x in (z, OffsetArray(z, -n0 ÷ 2)), w in (v, OffsetArray(v, -n0 ÷ 2)) + ss = @inferred suffstats(D, x) + @test ss isa Distributions.ChiStats + @test ss.mlogx ≈ mean(log.(x)) + + d = @inferred fit(D, x) + @test d isa D + @test ForwardDiff.derivative(ν -> sum(logpdf.(Chi(ν), x)), dof(d)) ≈ 0 atol = (eps(partype(d)))^(2/3) + + if axes(x) == axes(w) + d = @inferred fit(D, x, w) + @test d isa D + @test ForwardDiff.derivative(ν -> dot(logpdf.(Chi(ν), x), w), dof(d)) ≈ 0 atol = (eps(partype(d)))^(2/3) + + ss = @inferred suffstats(D, x, w) + @test ss isa Distributions.ChiStats + @test ss.mlogx ≈ dot(w ./ sum(w), log.(x)) + else + @test_throws DimensionMismatch("Inconsistent array dimensions: Axes of samples and sample weights must be equal.") suffstats(D, x, w) + @test_throws DimensionMismatch("Inconsistent array dimensions: Axes of samples and sample weights must be equal.") fit(D, x, w) + end + end + end +end + + +@testset "Testing fit for Chisq" begin + ν = 4.3 + for func in funcs, D in (Chisq, Chisq{Float64}, Chisq{Float32}) + v = func[1](n0) + z = func[2](D(ν), n0) + for x in (z, OffsetArray(z, -n0 ÷ 2)), w in (v, OffsetArray(v, -n0 ÷ 2)) + ss = @inferred suffstats(D, x) + @test ss isa Distributions.ChisqStats + @test ss.mlogx ≈ mean(log.(x)) + + d = @inferred fit(D, x) + @test d isa D + @test ForwardDiff.derivative(ν -> sum(logpdf.(Chisq(ν), x)), dof(d)) ≈ 0 atol = (eps(partype(d)))^(2/3) + + if axes(x) == axes(w) + ss = @inferred suffstats(D, x, w) + @test ss isa Distributions.ChisqStats + @test ss.mlogx ≈ dot(w ./ sum(w), log.(x)) + + d = @inferred fit(D, x, w) + @test d isa D + @test ForwardDiff.derivative(ν -> dot(logpdf.(Chisq(ν), x), w), dof(d)) ≈ 0 atol = (eps(partype(d)))^(2/3) + else + @test_throws DimensionMismatch("Inconsistent array dimensions: Axes of samples and sample weights must be equal.") suffstats(D, x, w) + @test_throws DimensionMismatch("Inconsistent array dimensions: Axes of samples and sample weights must be equal.") fit(D, x, w) + end + end + end +end