Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/mixtures/mixturemodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,22 @@ rand(rng::AbstractRNG, s::MixtureSampler{Univariate}) =
# Draw a single value from a univariate mixture: first draw a component index from
# `d.prior`, then sample from the component with that index.
function rand(rng::AbstractRNG, d::MixtureModel{Univariate})
    k = rand(rng, d.prior)
    return rand(rng, component(d, k))
end

# Draw `n` values from a univariate mixture in bulk.
#
# The number of draws assigned to each component is sampled jointly from a
# `Multinomial(n, probs(d.prior))`; each component with a positive count is then sampled
# in one batched call and written into its own contiguous slice of the output. The final
# shuffle removes the grouping by component that this layout introduces.
function rand(rng::AbstractRNG, d::MixtureModel{Univariate}, n::Int)
    per_component = rand(rng, Multinomial(n, probs(d.prior)))
    out = Vector{eltype(d)}(undef, n)
    start = 1
    for k in eachindex(per_component)
        nk = per_component[k]
        # Components that received no draws are skipped entirely (no RNG use).
        nk > 0 || continue
        stop = start + nk - 1
        out[start:stop] .= rand(rng, component(d, k), nk)
        start = stop + 1
    end
    return shuffle!(rng, out)
end

# Multivariate mixture sampler, vector case: draw a component index from `s.psampler`
# and fill `x` in place using the matching entry of `s.csamplers`.
function _rand!(
    rng::AbstractRNG, s::MixtureSampler{Multivariate}, x::AbstractVector{<:Real}
)
    @inbounds begin
        k = rand(rng, s.psampler)
        filled = rand!(rng, s.csamplers[k], x)
    end
    return filled
end
Expand Down
53 changes: 53 additions & 0 deletions src/truncate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,17 @@ end

## random number generation


"""
    rand(rng::AbstractRNG, d::Truncated)

Generate a single random sample from a truncated distribution.

The sampling strategy depends on the probability mass of the truncated region (`tp`):
- If `tp > 0.25`, rejection sampling is used. This is efficient when the truncated region covers a large portion of the original distribution.
- If `sqrt(eps) < tp <= 0.25`, inverse transform sampling is used. This is more efficient for smaller truncated regions.
- If `tp` is very small (`<= sqrt(eps)`), a numerically stable version of inverse transform sampling is used which performs calculations in log-space to maintain precision.
"""
function rand(rng::AbstractRNG, d::Truncated)
d0 = d.untruncated
tp = d.tp
Expand All @@ -233,6 +244,48 @@ function rand(rng::AbstractRNG, d::Truncated)
end
end


"""
    rand(rng::AbstractRNG, d::Truncated, n::Int)

Generate `n` random samples from a truncated distribution.

The implementation draws batches from the untruncated distribution `d0` with `rand(rng, d0, n_batch)` and keeps only the samples that fall within the truncation interval. The batch size is estimated adaptively to reduce the number of iterations.

See [`rand(rng::AbstractRNG, d::Truncated)`](@ref), which handles the case where the probability mass of the truncated region is small.

!!! warning
    This method can be inefficient if the probability mass of the truncated region is very small.
"""
function rand(rng::AbstractRNG, d::Truncated, n::Int)
    # Batched rejection sampler: draw from the untruncated distribution in bulk and keep
    # only the draws that land inside the closed truncation interval. Returns a
    # `Vector{eltype(d)}` of length `n`.
    n == 0 && return eltype(d)[]
    d0 = d.untruncated
    tp = d.tp
    lower = d.lower
    upper = d.upper
    samples = Vector{eltype(d)}(undef, n)

    # With (almost) no mass in the truncated region, rejection is hopeless: the batch
    # size `n / tp` becomes astronomically large, and once `tp` underflows to zero it is
    # `Inf`, which makes `ceil(Int, ...)` throw an `InexactError`. Delegate to the
    # single-sample method, which is documented to handle this regime without rejection.
    if tp <= sqrt(eps(typeof(float(tp))))
        for i in eachindex(samples)
            samples[i] = rand(rng, d)
        end
        return samples
    end

    n_collected = 0
    while n_collected < n
        n_remaining = n - n_collected
        # The number of draws needed to obtain `n_remaining` acceptances, each with
        # success probability `tp`, follows a negative binomial law with mean
        # `n_remaining / tp` and standard deviation `sqrt(n_remaining * (1 - tp)) / tp`.
        # Drawing mean + 3 standard deviations makes a further round unlikely.
        n_expected = n_remaining / tp
        # `max` guards against `tp` exceeding one by a rounding error.
        σ_draws = sqrt(n_remaining * max(1 - tp, zero(tp))) / tp
        n_batch = ceil(Int, n_expected + 3σ_draws)
        for s in rand(rng, d0, n_batch)
            if _in_closed_interval(s, lower, upper)
                n_collected += 1
                samples[n_collected] = s
                # Surplus accepted draws of the final batch are discarded.
                n_collected == n && break
            end
        end
    end
    return samples
end

## show

function show(io::IO, d::Truncated)
Expand Down
Loading