|
| 1 | +""" |
| 2 | + struct DistributionMeasure <: AbstractMeasure |
| 3 | +
|
| 4 | +Wraps a `Distributions.Distribution` as a `MeasureBase.AbstractMeasure`. |
| 5 | +
|
| 6 | +Avoid calling `DistributionMeasure(d::Distribution)` directly. Instead, use |
| 7 | +`AbstractMeasure(d::Distribution)` to allow for specialized `Distribution` |
| 8 | +to `AbstractMeasure` conversions. |
| 9 | +
|
| 10 | +Use `convert(Distribution, m::DistributionMeasure)` or |
| 11 | +`Distribution(m::DistributionMeasure)` to convert back to a `Distribution`. |
| 12 | +""" |
| 13 | +struct DistributionMeasure{F<:VariateForm,S<:ValueSupport,D<:Distribution{F,S}} <: AbstractMeasure |
| 14 | + d::D |
| 15 | +end |
| 16 | + |
| 17 | +@inline MeasureBase.AbstractMeasure(d::Distribution) = DistributionMeasure(d) |
| 18 | + |
| 19 | +@inline Base.convert(::Type{AbstractMeasure}, d::Distribution) = DistributionMeasure(d) |
| 20 | + |
| 21 | +@inline Distributions.Distribution(m::DistributionMeasure) = m.distribution |
| 22 | +@inline Distributions.Distribution{F}(m::DistributionMeasure{F}) where {F<:VariateForm} = Distribution(m) |
| 23 | +@inline Distributions.Distribution{F,S}(m::DistributionMeasure{F,S}) where {F<:VariateForm,S<:ValueSupport} = Distribution(m) |
| 24 | + |
| 25 | +@inline Base.convert(::Type{Distribution}, m::DistributionMeasure) = Distribution(m) |
| 26 | +@inline Base.convert(::Type{Distribution{F}}, m::DistributionMeasure{F}) where {F<:VariateForm} = Distribution(m) |
| 27 | +@inline Base.convert(::Type{Distribution{F,S}}, m::DistributionMeasure{F,S}) where {F<:VariateForm,S<:ValueSupport} = Distribution(m) |
| 28 | + |
| 29 | +@inline DensityInterface.densityof(m::DistributionMeasure) = DensityInterface.densityof(m.d) |
| 30 | +@inline DensityInterface.densityof(m::DistributionMeasure, x) = DensityInterface.densityof(m.d, x) |
| 31 | +@inline DensityInterface.logdensityof(m::DistributionMeasure) = DensityInterface.logdensityof(m.d) |
| 32 | +@inline DensityInterface.logdensityof(m::DistributionMeasure, x) = DensityInterface.logdensityof(m.d, x) |
| 33 | + |
| 34 | +@inline MeasureBase.logdensity_def(m::DistributionMeasure, x) = DensityInterface.logdensityof(m.d, x) |
| 35 | +@inline MeasureBase.unsafe_logdensityof(m::DistributionMeasure, x) = DensityInterface.logdensityof(m.d, x) |
| 36 | + |
| 37 | +@inline MeasureBase.insupport(m::DistributionMeasure, x) = Distributions.insupport(m.x) |
| 38 | + |
| 39 | +@inline MeasureBase.basemeasure(m::DistributionMeasure{<:ArrayLikeVariate{0},<:Continuous}) = Lebesgue() |
| 40 | +@inline MeasureBase.basemeasure(m::DistributionMeasure{<:ArrayLikeVariate,<:Continuous}) = Lebesgue()^size(m.d) |
| 41 | +@inline MeasureBase.basemeasure(m::DistributionMeasure{<:ArrayLikeVariate{0},<:Discrete}) = Counting() |
| 42 | +@inline MeasureBase.basemeasure(m::DistributionMeasure{<:ArrayLikeVariate,<:Discrete}) = Counting()^size(m.d) |
| 43 | + |
| 44 | +@inline MeasureBase.rootmeasure(m::DistributionMeasure) = MeasureBase.basemeasure(m) |
| 45 | + |
| 46 | + |
| 47 | +Base.rand(rng::AbstractRNG, ::Type{T}, m::DistributionMeasure) where {T<:Real} = _convert_numtype(T, rand(m.d)) |
| 48 | + |
| 49 | +function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution{<:ArrayLikeVariate{0}}, sz::Dims) where {T<:Real} |
| 50 | + _convert_numtype(T, reshape(rand(d, prod(sz)), sz...)) |
| 51 | +end |
| 52 | + |
| 53 | +function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution{<:ArrayLikeVariate{1}}, sz::Dims) where {T<:Real} |
| 54 | + _convert_numtype(T, reshape(rand(d, prod(sz)), size(d)..., sz...)) |
| 55 | +end |
| 56 | + |
| 57 | +function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::ReshapedDistribution{N,<:Any,<:Distribution{<:ArrayLikeVariate{1}}}, sz::Dims) where {T<:Real,N} |
| 58 | + _convert_numtype(T, reshape(rand(d.dist, prod(sz)), d.dims..., sz...)) |
| 59 | +end |
| 60 | + |
| 61 | +function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution, sz::Dims) where {T<:Real,N} |
| 62 | + flatview(ArrayOfSimilarArrays(_convert_numtype(T, rand(d, sz)))) |
| 63 | +end |
| 64 | + |
| 65 | +function Base.rand(rng::AbstractRNG, ::Type{T}, m::PowerMeasure{<:DistributionMeasure{<:ArrayLikeVariate{0}}, NTuple{N,Base.OneTo{Int}}}) where {T<:Real,N} |
| 66 | + _flat_powrand(rng, T, m.parent.d, map(length, m.axes)) |
| 67 | +end |
| 68 | + |
| 69 | +function Base.rand(rng::AbstractRNG, ::Type{T}, m::PowerMeasure{<:DistributionMeasure{<:ArrayLikeVariate{M}}, NTuple{N,Base.OneTo{Int}}}) where {T<:Real,M,N} |
| 70 | + flat_data = _flat_powrand(rng, T, m.parent.d, map(length, m.axes)) |
| 71 | + ArrayOfSimilarArrays{T,M,N}(flat_data) |
| 72 | +end |
| 73 | + |
| 74 | + |
| 75 | +@inline MeasureBase.paramnames(m::DistributionMeasure) = propertynames(m.d) |
| 76 | +@inline MeasureBase.params(m::DistributionMeasure) = NamedTuple{MeasureBase.paramnames(m.d)}(Distributions.params(m.d)) |
0 commit comments