Skip to content

Commit 03c5f17

Browse files
authored
Fully implement DistributionMeasure and add var transforms (#6)
1 parent 9b59ee1 commit 03c5f17

26 files changed

+1557
-47
lines changed

Project.toml

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,46 @@
11
name = "DistributionMeasures"
22
uuid = "35643b39-bfd4-4670-843f-16596ca89bf3"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.1.0"
4+
version = "0.2.0"
55

66
[deps]
7+
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
78
ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018"
9+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
810
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
911
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
1012
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
13+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
14+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
15+
ForwardDiffPullbacks = "450a3b6d-2448-4ee1-8e34-e4eb8713b605"
1116
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1217
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
18+
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
1319
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1420
MeasureBase = "fa1605e6-acd5-459c-a1e6-7e635759db14"
21+
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
1522
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
23+
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
24+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
25+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
26+
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1627

1728
[compat]
29+
ArgCheck = "1, 2"
1830
ArraysOfArrays = "0.5"
31+
ChainRulesCore = "1"
1932
ChangesOfVariables = "0.1"
2033
DensityInterface = "0.4"
2134
Distributions = "0.25"
35+
FillArrays = "0.12, 0.13"
36+
ForwardDiff = "0.9, 0.10"
37+
ForwardDiffPullbacks = "0.2"
2238
Functors = "0.2"
2339
InverseFunctions = "0.1"
24-
MeasureBase = "0.9"
40+
IrrationalConstants = "0.1"
41+
MeasureBase = "0.12"
42+
PDMats = "0.11"
43+
Static = "0.5, 0.6"
44+
StatsBase = "0.32, 0.33"
45+
StatsFuns = "0.9, 1"
2546
julia = "1.6"
26-
27-
[extras]
28-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
29-
30-
[targets]
31-
test = ["Test"]

src/DistributionMeasures.jl

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,75 @@
1+
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT).
2+
13
module DistributionMeasures
24

5+
using LinearAlgebra: Diagonal, dot, cholesky
6+
37
import Random
4-
using Random: AbstractRNG
8+
using Random: AbstractRNG, rand!
59

610
import DensityInterface
711
using DensityInterface: logdensityof
812

913
import MeasureBase
10-
using MeasureBase: AbstractMeasure, Lebesgue, Counting
11-
using MeasureBase: PowerMeasure
14+
using MeasureBase: AbstractMeasure, Lebesgue, Counting, ℝ
15+
using MeasureBase: StdMeasure, StdUniform, StdExponential, StdLogistic
16+
using MeasureBase: PowerMeasure, WeightedMeasure
17+
using MeasureBase: basemeasure, testvalue
18+
using MeasureBase: getdof, checked_arg
19+
using MeasureBase: transport_to, transport_def, transport_origin, from_origin, to_origin
20+
using MeasureBase: NoTransformOrigin, NoTransport
1221

1322
import Distributions
14-
using Distributions: Distribution, VariateForm, ValueSupport
15-
using Distributions: ArrayLikeVariate, Continuous, Discrete
23+
using Distributions: Distribution, VariateForm, ValueSupport, ContinuousDistribution
24+
using Distributions: Univariate, Multivariate, ArrayLikeVariate, Continuous, Discrete
25+
using Distributions: Uniform, Exponential, Logistic, Normal
26+
using Distributions: MvNormal, Beta, Dirichlet
1627
using Distributions: ReshapedDistribution
1728

29+
import Statistics
30+
import StatsBase
31+
import StatsFuns
32+
import PDMats
33+
34+
using IrrationalConstants: log2π, invsqrt2π
35+
36+
using Static: True, False, StaticInt, static
37+
using FillArrays: Fill, Ones, Zeros
38+
39+
import ChainRulesCore
40+
using ChainRulesCore: ZeroTangent, NoTangent, unthunk, @thunk
41+
42+
import ForwardDiff
43+
using ForwardDiffPullbacks: fwddiff
44+
1845
import Functors
1946
using Functors: fmap
2047

48+
using ArgCheck: @argcheck
49+
2150
using ArraysOfArrays: ArrayOfSimilarArrays, flatview
2251

52+
const MeasureLike = Union{AbstractMeasure,Distribution}
53+
export MeasureLike
2354

2455
include("utils.jl")
56+
include("autodiff_utils.jl")
57+
include("measure_interface.jl")
58+
include("stdnormal_measure.jl")
59+
include("standard_dist.jl")
60+
include("standard_uniform.jl")
61+
include("standard_normal.jl")
2562
include("distribution_measure.jl")
63+
include("dist_vartransform.jl")
64+
include("univariate.jl")
65+
include("standardmv.jl")
66+
include("product.jl")
67+
include("reshaped.jl")
68+
include("dirichlet.jl")
2669

27-
28-
const MeasureLike = Union{AbstractMeasure,Distribution}
29-
30-
export MeasureLike, DistributionMeasure
70+
export StdNormal
71+
export DistributionMeasure
72+
export StandardDist
3173

3274

3375
end # module

src/autodiff_utils.jl

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT).
2+
3+
@inline _adignore_call(f) = f()
4+
@inline _adignore_call_pullback(@nospecialize ΔΩ) = (NoTangent(), NoTangent())
5+
ChainRulesCore.rrule(::typeof(_adignore_call), f) = _adignore_call(f), _adignore_call_pullback
6+
7+
macro _adignore(expr)
8+
:(_adignore_call(() -> $(esc(expr))))
9+
end
10+
11+
12+
function _pushfront(v::AbstractVector, x)
13+
T = promote_type(eltype(v), typeof(x))
14+
r = similar(v, T, length(eachindex(v)) + 1)
15+
r[firstindex(r)] = x
16+
r[firstindex(r)+1:lastindex(r)] = v
17+
r
18+
end
19+
20+
function ChainRulesCore.rrule(::typeof(_pushfront), v::AbstractVector, x)
21+
result = _pushfront(v, x)
22+
function _pushfront_pullback(thunked_ΔΩ)
23+
ΔΩ = unthunk(thunked_ΔΩ)
24+
(NoTangent(), ΔΩ[firstindex(ΔΩ)+1:lastindex(ΔΩ)], ΔΩ[firstindex(ΔΩ)])
25+
end
26+
return result, _pushfront_pullback
27+
end
28+
29+
30+
function _pushback(v::AbstractVector, x)
31+
T = promote_type(eltype(v), typeof(x))
32+
r = similar(v, T, length(eachindex(v)) + 1)
33+
r[lastindex(r)] = x
34+
r[firstindex(r):lastindex(r)-1] = v
35+
r
36+
end
37+
38+
function ChainRulesCore.rrule(::typeof(_pushback), v::AbstractVector, x)
39+
result = _pushback(v, x)
40+
function _pushback_pullback(thunked_ΔΩ)
41+
ΔΩ = unthunk(thunked_ΔΩ)
42+
(NoTangent(), ΔΩ[firstindex(ΔΩ):lastindex(ΔΩ)-1], ΔΩ[lastindex(ΔΩ)])
43+
end
44+
return result, _pushback_pullback
45+
end
46+
47+
48+
_dropfront(v::AbstractVector) = v[firstindex(v)+1:lastindex(v)]
49+
50+
_dropback(v::AbstractVector) = v[firstindex(v):lastindex(v)-1]
51+
52+
53+
_rev_cumsum(xs::AbstractVector) = reverse(cumsum(reverse(xs)))
54+
55+
function ChainRulesCore.rrule(::typeof(_rev_cumsum), xs::AbstractVector)
56+
result = _rev_cumsum(xs)
57+
function _rev_cumsum_pullback(ΔΩ)
58+
∂xs = @thunk cumsum(unthunk(ΔΩ))
59+
(NoTangent(), ∂xs)
60+
end
61+
return result, _rev_cumsum_pullback
62+
end
63+
64+
65+
# Equivalent to `cumprod(xs)``:
66+
_exp_cumsum_log(xs::AbstractVector) = exp.(cumsum(log.(xs)))
67+
68+
function ChainRulesCore.rrule(::typeof(_exp_cumsum_log), xs::AbstractVector)
69+
result = _exp_cumsum_log(xs)
70+
function _exp_cumsum_log_pullback(ΔΩ)
71+
∂xs = inv.(xs) .* _rev_cumsum(exp.(cumsum(log.(xs))) .* unthunk(ΔΩ))
72+
(NoTangent(), ∂xs)
73+
end
74+
return result, _exp_cumsum_log_pullback
75+
end

src/dirichlet.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT).
2+
3+
MeasureBase.getdof(d::Dirichlet) = length(d) - 1
4+
5+
MeasureBase.transport_origin::Dirichlet) = StdUniform()^getdof(ν)
6+
7+
8+
function _dirichlet_beta_trafo::Real, β::Real, x::Real)
9+
R = float(promote_type(typeof(α), typeof(β), typeof(x)))
10+
convert(R, transport_def(Beta(α, β), StdUniform(), x))::R
11+
end
12+
13+
_a_times_one_minus_b(a::Real, b::Real) = a * (1 - b)
14+
15+
function MeasureBase.from_origin::Dirichlet, x)
16+
# See M. J. Betancourt, "Cruising The Simplex: Hamiltonian Monte Carlo and the Dirichlet Distribution",
17+
# https://arxiv.org/abs/1010.3436
18+
19+
# Sanity check (TODO - remove?):
20+
@_adignore @argcheck length(ν) == length(x) + 1
21+
22+
αs = _dropfront(_rev_cumsum.alpha))
23+
βs = _dropback.alpha)
24+
beta_v = fwddiff(_dirichlet_beta_trafo).(αs, βs, x)
25+
beta_v_cp = _exp_cumsum_log(_pushfront(beta_v, 1))
26+
beta_v_ext = _pushback(beta_v, 0)
27+
fwddiff(_a_times_one_minus_b).(beta_v_cp, beta_v_ext)
28+
end
29+
30+
# ToDo: MeasureBase.to_origin(ν::Dirichlet, y)

src/dist_vartransform.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT).
2+
3+
const _AnyStdUniform = Union{StandardUniform, Uniform}
4+
const _AnyStdNormal = Union{StandardNormal, Normal}
5+
6+
const _AnyStdDistribution = Union{_AnyStdUniform, _AnyStdNormal}
7+
8+
_std_measure(::Type{<:_AnyStdUniform}) = StandardUniform
9+
_std_measure(::Type{<:_AnyStdNormal}) = StandardNormal
10+
11+
_std_measure(::Type{M}, ::StaticInt{1}) where {M<:_AnyStdDistribution} = M()
12+
_std_measure(::Type{M}, dof::Integer) where {M<:_AnyStdDistribution} = M(dof)
13+
_std_measure_for(::Type{M}, μ::Any) where {M<:_AnyStdDistribution} = _std_measure(_std_measure(M), getdof(μ))
14+
15+
MeasureBase.transport_to(::Type{NU}, μ) where {NU<:_AnyStdDistribution} = transport_to(_std_measure_for(NU, μ), μ)
16+
MeasureBase.transport_to(ν, ::Type{MU}) where {MU<:_AnyStdDistribution} = transport_to(ν, _std_measure_for(MU, ν))

src/distribution_measure.jl

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT).
2+
3+
14
"""
25
struct DistributionMeasure <: AbstractMeasure
36
@@ -14,52 +17,52 @@ struct DistributionMeasure{F<:VariateForm,S<:ValueSupport,D<:Distribution{F,S}}
1417
d::D
1518
end
1619

20+
1721
@inline MeasureBase.AbstractMeasure(d::Distribution) = DistributionMeasure(d)
1822

1923
@inline Base.convert(::Type{AbstractMeasure}, d::Distribution) = DistributionMeasure(d)
2024

21-
@inline Distributions.Distribution(m::DistributionMeasure) = m.distribution
25+
@inline Distributions.Distribution(m::DistributionMeasure) = m.d
2226
@inline Distributions.Distribution{F}(m::DistributionMeasure{F}) where {F<:VariateForm} = Distribution(m)
2327
@inline Distributions.Distribution{F,S}(m::DistributionMeasure{F,S}) where {F<:VariateForm,S<:ValueSupport} = Distribution(m)
2428

2529
@inline Base.convert(::Type{Distribution}, m::DistributionMeasure) = Distribution(m)
2630
@inline Base.convert(::Type{Distribution{F}}, m::DistributionMeasure{F}) where {F<:VariateForm} = Distribution(m)
2731
@inline Base.convert(::Type{Distribution{F,S}}, m::DistributionMeasure{F,S}) where {F<:VariateForm,S<:ValueSupport} = Distribution(m)
2832

29-
@inline DensityInterface.densityof(m::DistributionMeasure) = DensityInterface.densityof(m.d)
30-
@inline DensityInterface.densityof(m::DistributionMeasure, x) = DensityInterface.densityof(m.d, x)
31-
@inline DensityInterface.logdensityof(m::DistributionMeasure) = DensityInterface.logdensityof(m.d)
32-
@inline DensityInterface.logdensityof(m::DistributionMeasure, x) = DensityInterface.logdensityof(m.d, x)
33-
34-
@inline MeasureBase.logdensity_def(m::DistributionMeasure, x) = DensityInterface.logdensityof(m.d, x)
35-
@inline MeasureBase.unsafe_logdensityof(m::DistributionMeasure, x) = DensityInterface.logdensityof(m.d, x)
3633

37-
@inline MeasureBase.insupport(m::DistributionMeasure, x) = Distributions.insupport(m.x)
34+
@inline DensityInterface.densityof::DistributionMeasure) = DensityInterface.densityof.d)
35+
@inline DensityInterface.densityof::DistributionMeasure, x) = DensityInterface.densityof.d, x)
36+
@inline DensityInterface.logdensityof::DistributionMeasure) = DensityInterface.logdensityof.d)
37+
@inline DensityInterface.logdensityof::DistributionMeasure, x) = DensityInterface.logdensityof.d, x)
3838

39-
@inline MeasureBase.basemeasure(m::DistributionMeasure{<:ArrayLikeVariate{0},<:Continuous}) = Lebesgue()
40-
@inline MeasureBase.basemeasure(m::DistributionMeasure{<:ArrayLikeVariate,<:Continuous}) = Lebesgue()^size(m.d)
41-
@inline MeasureBase.basemeasure(m::DistributionMeasure{<:ArrayLikeVariate{0},<:Discrete}) = Counting()
42-
@inline MeasureBase.basemeasure(m::DistributionMeasure{<:ArrayLikeVariate,<:Discrete}) = Counting()^size(m.d)
39+
@inline MeasureBase.logdensity_def::DistributionMeasure, x) = MeasureBase.logdensity_def.d, x)
40+
@inline MeasureBase.unsafe_logdensityof::DistributionMeasure, x) = MeasureBase.unsafe_logdensityof.d, x)
41+
@inline MeasureBase.insupport::DistributionMeasure, x) = MeasureBase.insupport.d, x)
42+
@inline MeasureBase.basemeasure::DistributionMeasure) = MeasureBase.basemeasure.d)
43+
@inline MeasureBase.paramnames::DistributionMeasure) = MeasureBase.paramnames.d)
44+
@inline MeasureBase.params::DistributionMeasure) = MeasureBase.params.d)
45+
@inline MeasureBase.transport_origin::DistributionMeasure) = ν.d
46+
@inline MeasureBase.to_origin(::DistributionMeasure, y) = y
47+
@inline MeasureBase.from_origin(::DistributionMeasure, x) = x
4348

44-
@inline MeasureBase.rootmeasure(m::DistributionMeasure) = MeasureBase.basemeasure(m)
4549

46-
47-
Base.rand(rng::AbstractRNG, ::Type{T}, m::DistributionMeasure) where {T<:Real} = _convert_numtype(T, rand(m.d))
50+
Base.rand(rng::AbstractRNG, ::Type{T}, m::DistributionMeasure) where {T<:Real} = convert_realtype(T, rand(m.d))
4851

4952
function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution{<:ArrayLikeVariate{0}}, sz::Dims) where {T<:Real}
50-
_convert_numtype(T, reshape(rand(d, prod(sz)), sz...))
53+
convert_realtype(T, reshape(rand(d, prod(sz)), sz...))
5154
end
5255

5356
function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution{<:ArrayLikeVariate{1}}, sz::Dims) where {T<:Real}
54-
_convert_numtype(T, reshape(rand(d, prod(sz)), size(d)..., sz...))
57+
convert_realtype(T, reshape(rand(d, prod(sz)), size(d)..., sz...))
5558
end
5659

5760
function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::ReshapedDistribution{N,<:Any,<:Distribution{<:ArrayLikeVariate{1}}}, sz::Dims) where {T<:Real,N}
58-
_convert_numtype(T, reshape(rand(d.dist, prod(sz)), d.dims..., sz...))
61+
convert_realtype(T, reshape(rand(d.dist, prod(sz)), d.dims..., sz...))
5962
end
6063

6164
function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution, sz::Dims) where {T<:Real,N}
62-
flatview(ArrayOfSimilarArrays(_convert_numtype(T, rand(d, sz))))
65+
flatview(ArrayOfSimilarArrays(convert_realtype(T, rand(d, sz))))
6366
end
6467

6568
function Base.rand(rng::AbstractRNG, ::Type{T}, m::PowerMeasure{<:DistributionMeasure{<:ArrayLikeVariate{0}}, NTuple{N,Base.OneTo{Int}}}) where {T<:Real,N}
@@ -70,7 +73,3 @@ function Base.rand(rng::AbstractRNG, ::Type{T}, m::PowerMeasure{<:DistributionMe
7073
flat_data = _flat_powrand(rng, T, m.parent.d, map(length, m.axes))
7174
ArrayOfSimilarArrays{T,M,N}(flat_data)
7275
end
73-
74-
75-
@inline MeasureBase.paramnames(m::DistributionMeasure) = propertynames(m.d)
76-
@inline MeasureBase.params(m::DistributionMeasure) = NamedTuple{MeasureBase.paramnames(m.d)}(Distributions.params(m.d))

src/measure_interface.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT).
2+
3+
@inline MeasureBase.logdensity_def(d::Distribution, x) = DensityInterface.logdensityof(d, x)
4+
@inline MeasureBase.unsafe_logdensityof(d::Distribution, x) = DensityInterface.logdensityof(d, x)
5+
6+
@inline MeasureBase.insupport(d::Distribution, x) = Distributions.insupport(d, x)
7+
8+
@inline MeasureBase.basemeasure(d::Distribution{<:ArrayLikeVariate{0},<:Continuous}) = Lebesgue()
9+
@inline MeasureBase.basemeasure(d::Distribution{<:ArrayLikeVariate,<:Continuous}) = Lebesgue()^size(d)
10+
@inline MeasureBase.basemeasure(d::Distribution{<:ArrayLikeVariate{0},<:Discrete}) = Counting()
11+
@inline MeasureBase.basemeasure(d::Distribution{<:ArrayLikeVariate,<:Discrete}) = Counting()^size(d)
12+
13+
@inline MeasureBase.paramnames(d::Distribution) = propertynames(d)
14+
@inline MeasureBase.params(d::Distribution) = NamedTuple{propertynames(d)}(Distributions.params(d))
15+
16+
@inline MeasureBase.testvalue(d::Distribution) = testvalue(basemeasure(d))
17+
18+
19+
@inline MeasureBase.basemeasure(d::Distributions.Poisson) = Counting(MeasureBase.BoundedInts(static(0), static(Inf)))
20+
@inline MeasureBase.basemeasure(d::Distributions.Product{<:Any,<:Distributions.Poisson}) = Counting(MeasureBase.BoundedInts(static(0), static(Inf)))^size(d)
21+
22+
23+
MeasureBase.(f, base::Distribution) = MeasureBase.(f, convert(AbstractMeasure, base))

src/product.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT).
2+
3+
const _StdPowMeasure1 = PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}
4+
const _UniformProductDist1x{D} = Distributions.Product{Continuous,D,<:AbstractVector{D}}
5+
6+
7+
MeasureBase.getdof(d::_UniformProductDist1x) = length(d)
8+
9+
10+
function _product_dist_trafo_impl(νs, μs, x)
11+
fwddiff(transport_def).(νs, μs, x)
12+
end
13+
14+
function MeasureBase.transport_def::_StdPowMeasure1, μ::_UniformProductDist1x, x)
15+
_product_dist_trafo_impl((ν.parent,), μ.v, x)
16+
end
17+
18+
function MeasureBase.transport_def::_UniformProductDist1x, μ::_StdPowMeasure1, x)
19+
_product_dist_trafo_impl.v, (μ.parent,), x)
20+
end

0 commit comments

Comments
 (0)