Skip to content
33 changes: 31 additions & 2 deletions src/multivariate/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,20 +154,49 @@ end

# sampling

function _rand_handle_overflow!(
rng::AbstractRNG,
d::Union{Dirichlet,DirichletCanon},
x::AbstractVector{<:Real}
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes the style consistent with the surrounding code:

Suggested change
function _rand_handle_overflow!(
rng::AbstractRNG,
d::Union{Dirichlet,DirichletCanon},
x::AbstractVector{<:Real}
)
function _rand_handle_overflow!(rng::AbstractRNG,
d::Union{Dirichlet,DirichletCanon},
x::AbstractVector{<:Real})

Σ = sum(x)
if Σ == 0.0
# Distribution behavior approaches categorical as Σα -> 0
α = d.alpha
iΣα = inv(d.alpha0)
if isinf(iΣα)
# Dirichlet with ALL deeply subnormal parameters
α .*= floatmax(eltype(α))
iΣα = inv(sum(α))
end
x[rand(rng, Categorical(iΣα .* α))] = 1
return x
end

iΣ = inv(Σ)
if isinf(iΣ)
# Σ is deep subnormal
x .*= floatmax(eltype(x))
iΣ = inv(sum(x))
end

lmul!(iΣ, x) # this returns x
end

function _rand!(rng::AbstractRNG,
d::Union{Dirichlet,DirichletCanon},
x::AbstractVector{<:Real})
for (i, αi) in zip(eachindex(x), d.alpha)
@inbounds x[i] = rand(rng, Gamma(αi))
end
lmul!(inv(sum(x)), x) # this returns x
_rand_handle_overflow!(rng, d, x)
end

function _rand!(rng::AbstractRNG,
d::Dirichlet{T,<:FillArrays.AbstractFill{T}},
x::AbstractVector{<:Real}) where {T<:Real}
rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x)
lmul!(inv(sum(x)), x) # this returns x
_rand_handle_overflow!(rng, d, x)
end

#######################################
Expand Down
22 changes: 22 additions & 0 deletions test/multivariate/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,25 @@ end
end
end
end

@testset "Dirichlet rand Inf and NaN (#1702)" begin
for d in [
Dirichlet([8e-5, 1e-5, 2e-5]),
Dirichlet([8e-4, 1e-4, 2e-4]),
Dirichlet([4.5e-5, 8e-5]),
Dirichlet([6e-5, 2e-5, 3e-5, 4e-5, 5e-5]),
Dirichlet(FillArrays.Fill(1e-5, 5))
]
x = rand(d, 10^6)
@test mean(x, dims = 2) ≈ mean(d) atol=0.01
@test var(x, dims = 2) ≈ var(d) atol=0.01
end

for (d, μ) in [ # Subnormal params cause mean(d) to error
(Dirichlet([5e-321, 1e-321, 4e-321]), [.5, .1, .4]),
(Dirichlet([1e-321, 2e-321, 3e-321, 4e-321]), [.1, .2, .3, .4])
]
x = rand(d, 10^6)
@test mean(x, dims = 2) ≈ μ atol=0.01
end
end
Loading