Skip to content

Commit 35070be

Browse files
committed
Add DistributionMeasure
1 parent 72de7aa commit 35070be

File tree

4 files changed

+111
-5
lines changed

4 files changed

+111
-5
lines changed

LICENSE

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
MIT License
22

3-
Copyright (c) 2022 Chad Scherrer <[email protected]> and contributors
3+
Copyright (c) 2022 Chad Scherrer <[email protected]>,
4+
Oliver Schulz <[email protected]> and contributors
45

56
Permission is hereby granted, free of charge, to any person obtaining a copy
67
of this software and associated documentation files (the "Software"), to deal

src/DistributionMeasures.jl

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,33 @@
11
module DistributionMeasures
22

3+
import Random
4+
using Random: AbstractRNG
5+
6+
import DensityInterface
7+
using DensityInterface: logdensityof
8+
9+
import MeasureBase
10+
using MeasureBase: AbstractMeasure, Lebesgue, Counting
11+
using MeasureBase: PowerMeasure
12+
13+
import Distributions
14+
using Distributions: Distribution, VariateForm, ValueSupport
15+
using Distributions: ArrayLikeVariate, Continuous, Discrete
16+
using Distributions: ReshapedDistribution
17+
18+
import Functors
19+
using Functors: fmap
20+
21+
using ArraysOfArrays: ArrayOfSimilarArrays, flatview
22+
23+
24+
include("utils.jl")
25+
include("distribution_measure.jl")
26+
27+
328
const MeasureLike = Union{AbstractMeasure,Distribution}
429

5-
struct DistributionMeasure{D<:Distribution} <: AbstractMeasure
6-
d::D
7-
end
30+
export MeasureLike, DistributionMeasure
31+
832

9-
end
33+
end # module

src/distribution_measure.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""
2+
struct DistributionMeasure <: AbstractMeasure
3+
4+
Wraps a `Distributions.Distribution` as a `MeasureBase.AbstractMeasure`.
5+
6+
Avoid calling `DistributionMeasure(d::Distribution)` directly. Instead, use
7+
`AbstractMeasure(d::Distribution)` to allow for specialized `Distribution`
8+
to `AbstractMeasure` conversions.
9+
10+
Use `convert(Distribution, m::DistributionMeasure)` or
11+
`Distribution(m::DistributionMeasure)` to convert back to a `Distribution`.
12+
"""
13+
struct DistributionMeasure{F<:VariateForm,S<:ValueSupport,D<:Distribution{F,S}} <: AbstractMeasure
14+
d::D
15+
end
16+
17+
@inline MeasureBase.AbstractMeasure(d::Distribution) = DistributionMeasure(d)
18+
19+
@inline Base.convert(::Type{AbstractMeasure}, d::Distribution) = DistributionMeasure(d)
20+
21+
@inline Distributions.Distribution(m::DistributionMeasure) = m.distribution
22+
@inline Distributions.Distribution{F}(m::DistributionMeasure{F}) where {F<:VariateForm} = Distribution(m)
23+
@inline Distributions.Distribution{F,S}(m::DistributionMeasure{F,S}) where {F<:VariateForm,S<:ValueSupport} = Distribution(m)
24+
25+
@inline Base.convert(::Type{Distribution}, m::DistributionMeasure) = Distribution(m)
26+
@inline Base.convert(::Type{Distribution{F}}, m::DistributionMeasure{F}) where {F<:VariateForm} = Distribution(m)
27+
@inline Base.convert(::Type{Distribution{F,S}}, m::DistributionMeasure{F,S}) where {F<:VariateForm,S<:ValueSupport} = Distribution(m)
28+
29+
@inline DensityInterface.densityof(m::DistributionMeasure) = DensityInterface.densityof(m.d)
30+
@inline DensityInterface.densityof(m::DistributionMeasure, x) = DensityInterface.densityof(m.d, x)
31+
@inline DensityInterface.logdensityof(m::DistributionMeasure) = DensityInterface.logdensityof(m.d)
32+
@inline DensityInterface.logdensityof(m::DistributionMeasure, x) = DensityInterface.logdensityof(m.d, x)
33+
34+
@inline MeasureBase.logdensity_def(m::DistributionMeasure, x) = DensityInterface.logdensityof(m.d, x)
35+
@inline MeasureBase.unsafe_logdensityof(m::DistributionMeasure, x) = DensityInterface.logdensityof(m.d, x)
36+
37+
@inline MeasureBase.insupport(m::DistributionMeasure, x) = Distributions.insupport(m.x)
38+
39+
@inline MeasureBase.basemeasure(m::DistributionMeasure{<:ArrayLikeVariate{0},<:Continuous}) = Lebesgue()
40+
@inline MeasureBase.basemeasure(m::DistributionMeasure{<:ArrayLikeVariate,<:Continuous}) = Lebesgue()^size(m.d)
41+
@inline MeasureBase.basemeasure(m::DistributionMeasure{<:ArrayLikeVariate{0},<:Discrete}) = Counting()
42+
@inline MeasureBase.basemeasure(m::DistributionMeasure{<:ArrayLikeVariate,<:Discrete}) = Counting()^size(m.d)
43+
44+
@inline MeasureBase.rootmeasure(m::DistributionMeasure) = MeasureBase.basemeasure(m)
45+
46+
47+
Base.rand(rng::AbstractRNG, ::Type{T}, m::DistributionMeasure) where {T<:Real} = _convert_numtype(T, rand(m.d))
48+
49+
function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution{<:ArrayLikeVariate{0}}, sz::Dims) where {T<:Real}
50+
_convert_numtype(T, reshape(rand(d, prod(sz)), sz...))
51+
end
52+
53+
function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution{<:ArrayLikeVariate{1}}, sz::Dims) where {T<:Real}
54+
_convert_numtype(T, reshape(rand(d, prod(sz)), size(d)..., sz...))
55+
end
56+
57+
function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::ReshapedDistribution{N,<:Any,<:Distribution{<:ArrayLikeVariate{1}}}, sz::Dims) where {T<:Real,N}
58+
_convert_numtype(T, reshape(rand(d.dist, prod(sz)), d.dims..., sz...))
59+
end
60+
61+
function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution, sz::Dims) where {T<:Real,N}
62+
flatview(ArrayOfSimilarArrays(_convert_numtype(T, rand(d, sz))))
63+
end
64+
65+
function Base.rand(rng::AbstractRNG, ::Type{T}, m::PowerMeasure{<:DistributionMeasure{<:ArrayLikeVariate{0}}, NTuple{N,Base.OneTo{Int}}}) where {T<:Real,N}
66+
_flat_powrand(rng, T, m.parent.d, map(length, m.axes))
67+
end
68+
69+
function Base.rand(rng::AbstractRNG, ::Type{T}, m::PowerMeasure{<:DistributionMeasure{<:ArrayLikeVariate{M}}, NTuple{N,Base.OneTo{Int}}}) where {T<:Real,M,N}
70+
flat_data = _flat_powrand(rng, T, m.parent.d, map(length, m.axes))
71+
ArrayOfSimilarArrays{T,M,N}(flat_data)
72+
end
73+
74+
75+
@inline MeasureBase.paramnames(m::DistributionMeasure) = propertynames(m.d)
76+
@inline MeasureBase.params(m::DistributionMeasure) = NamedTuple{MeasureBase.paramnames(m.d)}(Distributions.params(m.d))

src/utils.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
@inline _convert_numtype(::Type{T}, x::T) where {T<:Real} = x
2+
@inline _convert_numtype(::Type{T}, x::AbstractArray{T}) where {T<:Real} = x
3+
@inline _convert_numtype(::Type{T}, x::U) where {T<:Real,U<:Real} = T(X)
4+
_convert_numtype(::Type{T}, x::AbstractArray{U}) where {T<:Real,U<:Real} = T.(x)
5+
_convert_numtype(::Type{T}, x) where {T<:Real} = fmap(elem -> _convert_numtype(T, elem), x)

0 commit comments

Comments
 (0)