Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/MeasureTheory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ include("parameterized/binomial.jl")
include("parameterized/multinomial.jl")
include("parameterized/lkj-cholesky.jl")
include("parameterized/negativebinomial.jl")
include("parameterized/betabinomial.jl")
include("parameterized/gamma.jl")
include("parameterized/snedecorf.jl")
include("parameterized/inverse-gaussian.jl")
Expand Down
44 changes: 44 additions & 0 deletions src/parameterized/betabinomial.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Beta-Binomial distribution

export BetaBinomial
import Base
using SpecialFunctions

@parameterized BetaBinomial(n, α, β)

basemeasure(d::BetaBinomial) = CountingMeasure()

testvalue(::BetaBinomial) = 0

@kwstruct BetaBinomial(n, α, β)

function Base.rand(rng::AbstractRNG, ::Type, d::BetaBinomial{(:n, :α, :β)})
rand(rng, Dists.BetaBinomial(d.n, d.α, d.β))
end

function Base.rand(
rng::AbstractRNG,
::Type,
d::BetaBinomial{(:n, :α, :β),Tuple{I,A}},
) where {I<:Integer,A}
rand(rng, Dists.BetaBinomial(d.n, d.α, d.β))
end

@inline function insupport(d::BetaBinomial, x)
isinteger(x) && 0 ≤ x ≤ d.n
end

@inline function logdensity_def(d::BetaBinomial{(:n, :α, :β)}, y)
(n, α, β) = (d.n, d.α, d.β)
logbinom = -log1p(n) - logbeta(y + 1, n - y + 1)
lognum = logbeta(y + α, n - y + β)
logdenom = logbeta(α, β)
return logbinom + lognum - logdenom
end

asparams(::Type{<:BetaBinomial}, ::StaticSymbol{:α}) = asℝ₊
asparams(::Type{<:BetaBinomial}, ::StaticSymbol{:β}) = asℝ₊

function proxy(d::BetaBinomial{(:n, :α, :β),Tuple{I,A}}) where {I<:Integer,A}
Dists.BetaBinomial(d.n, d.α, d.β)
end
15 changes: 10 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ test_measures = Any[
Bernoulli(0.2)
Beta(2, 3)
Binomial(10, 0.3)
BetaBinomial(10, 2, 3)
Cauchy()
Dirichlet(ones(3))
Exponential()
Expand Down Expand Up @@ -277,7 +278,7 @@ end

@testset "Product of Diracs" begin
x = randn(3)
t = as(productmeasure(Dirac.(x)))
t = as(productmeasure(Dirac.(x)))
@test transform(t, []) == x
end

Expand All @@ -297,7 +298,7 @@ end

# chain = Chain(kernel, μ)

# dyniterate(iter::TimeLift, ::Nothing) = dyniterate(iter, 0=>nothing)
# dyniterate(iter::TimeLift, ::Nothing) = dyniterate(iter, 0=>nothing)
# tr1 = trace(TimeLift(chain), nothing, u -> u[1] > 15)
# tr2 = trace(TimeLift(rand(Random.GLOBAL_RNG, chain)), nothing, u -> u[1] > 15)
# collect(Iterators.take(chain, 10))
Expand Down Expand Up @@ -348,8 +349,8 @@ end
# NOTE: The `test_broken` below are mostly because of the change to `Affine`.
# For example, `Normal{(:μ,:σ)}` is now `Affine{(:μ,:σ), Normal{()}}`.
# The problem is not really with these measures, but with the tests
# themselves.
#
# themselves.
#
# We should instead probably be doing e.g.
# `D = typeof(Normal(μ=0.3, σ=4.1))`

Expand All @@ -371,6 +372,10 @@ end
@test repro(Beta, (:α, :β))
end

@testset "BetaBinomial" begin
@test repro(BetaBinomial, (:n, :α, :β), (n = 10,))
end

@testset "Cauchy" begin
@test_broken repro(Cauchy, (:μ, :σ))
end
Expand Down Expand Up @@ -652,7 +657,7 @@ end
end

x = rand(d)

@test logdensityof(d, x) isa Real
end

Expand Down