Skip to content

Commit 3135e1e

Browse files
committed
STASH Distributions ext impl
1 parent 04e24b0 commit 3135e1e

25 files changed

+1720
-3
lines changed

Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ Compat = "3.35, 4"
4747
ConstructionBase = "1.3"
4848
DensityInterface = "0.4"
4949
Distributions = "0.25.1"
50-
Distributions = "0.25.111"
5150
FillArrays = "0.12, 0.13, 1"
5251
ForwardDiff = "0.8, 0.9, 0.10"
5352
FunctionChains = "0.1"

ext/MeasureBaseDistributionsExt.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
module MeasureBaseDistributionsExt
44

5-
using MeasureBase
6-
import Distributions
5+
include "distributions/distributions.jl"
76

87
end # module MeasureBaseDistributionsExt

ext/distributions/_bat_dist_transforms.jl

Lines changed: 477 additions & 0 deletions
Large diffs are not rendered by default.

ext/distributions/autodiff_utils.jl

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT).
2+
3+
@inline _adignore_call(f) = f()
4+
@inline _adignore_call_pullback(@nospecialize ΔΩ) = (NoTangent(), NoTangent())
5+
ChainRulesCore.rrule(::typeof(_adignore_call), f) = _adignore_call(f), _adignore_call_pullback
6+
7+
macro _adignore(expr)
8+
:(_adignore_call(() -> $(esc(expr))))
9+
end
10+
11+
12+
function _pushfront(v::AbstractVector, x)
13+
T = promote_type(eltype(v), typeof(x))
14+
r = similar(v, T, length(eachindex(v)) + 1)
15+
r[firstindex(r)] = x
16+
r[firstindex(r)+1:lastindex(r)] = v
17+
r
18+
end
19+
20+
function ChainRulesCore.rrule(::typeof(_pushfront), v::AbstractVector, x)
21+
result = _pushfront(v, x)
22+
function _pushfront_pullback(thunked_ΔΩ)
23+
ΔΩ = unthunk(thunked_ΔΩ)
24+
(NoTangent(), ΔΩ[firstindex(ΔΩ)+1:lastindex(ΔΩ)], ΔΩ[firstindex(ΔΩ)])
25+
end
26+
return result, _pushfront_pullback
27+
end
28+
29+
30+
function _pushback(v::AbstractVector, x)
31+
T = promote_type(eltype(v), typeof(x))
32+
r = similar(v, T, length(eachindex(v)) + 1)
33+
r[lastindex(r)] = x
34+
r[firstindex(r):lastindex(r)-1] = v
35+
r
36+
end
37+
38+
function ChainRulesCore.rrule(::typeof(_pushback), v::AbstractVector, x)
39+
result = _pushback(v, x)
40+
function _pushback_pullback(thunked_ΔΩ)
41+
ΔΩ = unthunk(thunked_ΔΩ)
42+
(NoTangent(), ΔΩ[firstindex(ΔΩ):lastindex(ΔΩ)-1], ΔΩ[lastindex(ΔΩ)])
43+
end
44+
return result, _pushback_pullback
45+
end
46+
47+
48+
_dropfront(v::AbstractVector) = v[firstindex(v)+1:lastindex(v)]
49+
50+
_dropback(v::AbstractVector) = v[firstindex(v):lastindex(v)-1]
51+
52+
53+
_rev_cumsum(xs::AbstractVector) = reverse(cumsum(reverse(xs)))
54+
55+
function ChainRulesCore.rrule(::typeof(_rev_cumsum), xs::AbstractVector)
56+
result = _rev_cumsum(xs)
57+
function _rev_cumsum_pullback(ΔΩ)
58+
∂xs = @thunk cumsum(unthunk(ΔΩ))
59+
(NoTangent(), ∂xs)
60+
end
61+
return result, _rev_cumsum_pullback
62+
end
63+
64+
65+
# Equivalent to `cumprod(xs)``:
66+
_exp_cumsum_log(xs::AbstractVector) = exp.(cumsum(log.(xs)))
67+
68+
function ChainRulesCore.rrule(::typeof(_exp_cumsum_log), xs::AbstractVector)
69+
result = _exp_cumsum_log(xs)
70+
function _exp_cumsum_log_pullback(ΔΩ)
71+
∂xs = inv.(xs) .* _rev_cumsum(exp.(cumsum(log.(xs))) .* unthunk(ΔΩ))
72+
(NoTangent(), ∂xs)
73+
end
74+
return result, _exp_cumsum_log_pullback
75+
end

ext/distributions/dirac.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT).
2+
3+
MeasureBase.AbstractMeasure(obj::Distributions.Dirac) = MeasureBase.Dirac(obj.value)
4+
5+
function AsMeasure{D}(::D) where {D<:Distributions.Dirac}
6+
throw(ArgumentError("Don't wrap Distributions.Dirac into MeasureBase.AsMeasure, use asmeasure to convert instead."))
7+
end
8+
9+
10+
Distributions.Distribution(m::MeasureBase.Dirac{<:Real}) = Distribtions.Dirac(m.x)
11+
12+
function Distributions.Distribution(@nospecialize(m::MeasureBase.Dirac{T})) where T
13+
throw(ArgumentError("Can only convert MeasureBase.Dirac{<:Real} to Distributions.Dirac, but not MeasureBase.Dirac{<:$(nameof(T))}"))
14+
end

ext/distributions/dirichlet.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT).
2+
3+
const DirichletMeasure = AsMeasure{<:Dirichlet}
4+
5+
MeasureBase.getdof(m::DirichletMeasure) = length(m.obj) - 1
6+
7+
MeasureBase.transport_origin(m::DirichletMeasure) = StdUniform()^getdof(m)
8+
9+
10+
11+
function _dirichlet_beta_trafo::Real, β::Real, x::Real)
12+
R = float(promote_type(typeof(α), typeof(β), typeof(x)))
13+
convert(R, transport_def(Beta(α, β), StdUniform(), x))::R
14+
end
15+
16+
_a_times_one_minus_b(a::Real, b::Real) = a * (1 - b)
17+
18+
function MeasureBase.from_origin::Dirichlet, x)
19+
# See M. J. Betancourt, "Cruising The Simplex: Hamiltonian Monte Carlo and the Dirichlet Distribution",
20+
# https://arxiv.org/abs/1010.3436
21+
22+
# Sanity check (TODO - remove?):
23+
@_adignore @argcheck length(ν) == length(x) + 1
24+
25+
αs = _dropfront(_rev_cumsum.alpha))
26+
βs = _dropback.alpha)
27+
beta_v = fwddiff(_dirichlet_beta_trafo).(αs, βs, x)
28+
beta_v_cp = _exp_cumsum_log(_pushfront(beta_v, 1))
29+
beta_v_ext = _pushback(beta_v, 0)
30+
fwddiff(_a_times_one_minus_b).(beta_v_cp, beta_v_ext)
31+
end
32+
33+
# ToDo: MeasureBase.to_origin(ν::Dirichlet, y)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT).
2+
3+
const _AnyStdUniform = Union{StandardUniform, Uniform}
4+
const _AnyStdNormal = Union{StandardNormal, Normal}
5+
6+
const _AnyStdDistribution = Union{_AnyStdUniform, _AnyStdNormal}
7+
8+
_std_measure(::Type{<:_AnyStdUniform}) = StandardUniform
9+
_std_measure(::Type{<:_AnyStdNormal}) = StandardNormal
10+
11+
_std_measure(::Type{M}, ::StaticInt{1}) where {M<:_AnyStdDistribution} = M()
12+
_std_measure(::Type{M}, dof::Integer) where {M<:_AnyStdDistribution} = M(dof)
13+
_std_measure_for(::Type{M}, μ::Any) where {M<:_AnyStdDistribution} = _std_measure(_std_measure(M), getdof(μ))
14+
15+
MeasureBase.transport_to(::Type{NU}, μ) where {NU<:_AnyStdDistribution} = transport_to(_std_measure_for(NU, μ), μ)
16+
MeasureBase.transport_to(ν, ::Type{MU}) where {MU<:_AnyStdDistribution} = transport_to(ν, _std_measure_for(MU, ν))
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT).
2+
3+
4+
const DistributionMeasure{F<:VariateForm,S<:ValueSupport,D<:Distribution{F,S}} = AsMeasure{D}
5+
6+
@inline MeasureBase.AbstractMeasure(obj::Distribution) = AsMeasure{typeof(obj)}(obj)
7+
@inline Base.convert(::Type{AbstractMeasure}, obj::Distribution) = AbstractMeasure(obj)
8+
9+
@inline Distributions.Distribution(m::DistributionMeasure) = m.obj
10+
@inline Distributions.Distribution{F}(m::DistributionMeasure{F}) where {F<:VariateForm} = Distribution(m)
11+
@inline Distributions.Distribution{F,S}(m::DistributionMeasure{F,S}) where {F<:VariateForm,S<:ValueSupport} = Distribution(m)
12+
13+
@inline Base.convert(::Type{Distribution}, m::DistributionMeasure) = Distribution(m)
14+
@inline Base.convert(::Type{Distribution{F}}, m::DistributionMeasure{F}) where {F<:VariateForm} = Distribution(m)
15+
@inline Base.convert(::Type{Distribution{F,S}}, m::DistributionMeasure{F,S}) where {F<:VariateForm,S<:ValueSupport} = Distribution(m)
16+
17+
18+
Base.rand(rng::AbstractRNG, ::Type{T}, m::DistributionMeasure) where {T<:Real} = convert_realtype(T, rand(m.obj))
19+
20+
function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution{<:ArrayLikeVariate{0}}, sz::Dims) where {T<:Real}
21+
convert_realtype(T, reshape(rand(d, prod(sz)), sz...))
22+
end
23+
24+
function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution{<:ArrayLikeVariate{1}}, sz::Dims) where {T<:Real}
25+
convert_realtype(T, reshape(rand(rng, d, prod(sz)), size(d)..., sz...))
26+
end
27+
28+
function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::ReshapedDistribution{N,<:Any,<:Distribution{<:ArrayLikeVariate{1}}}, sz::Dims) where {T<:Real,N}
29+
convert_realtype(T, reshape(rand(rng, d.dist, prod(sz)), d.dims..., sz...))
30+
end
31+
32+
function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution, sz::Dims) where {T<:Real}
33+
flatview(ArrayOfSimilarArrays(convert_realtype(T, rand(rng, d, sz))))
34+
end
35+
36+
function Base.rand(rng::AbstractRNG, ::Type{T}, m::PowerMeasure{<:DistributionMeasure{<:ArrayLikeVariate{0}}, NTuple{N,Base.OneTo{Int}}}) where {T<:Real,N}
37+
_flat_powrand(rng, T, m.parent.obj, map(length, m.axes))
38+
end
39+
40+
function Base.rand(rng::AbstractRNG, ::Type{T}, m::PowerMeasure{<:DistributionMeasure{<:ArrayLikeVariate{M}}, NTuple{N,Base.OneTo{Int}}}) where {T<:Real,M,N}
41+
flat_data = _flat_powrand(rng, T, m.parent.obj, map(length, m.axes))
42+
ArrayOfSimilarArrays{T,M,N}(flat_data)
43+
end
44+
45+
46+
@inline DensityInterface.densityof(m::DistributionMeasure) = densityof(m.obj)
47+
@inline DensityInterface.logdensityof(m::DistributionMeasure) = logdensityof(m.obj)
48+
49+
@inline MeasureBase.logdensity_def(m::DistributionMeasure, x) = DensityInterface.logdensityof(m.obj, x)
50+
@inline MeasureBase.unsafe_logdensityof(m::DistributionMeasure, x) = DensityInterface.logdensityof(m.obj, x)
51+
@inline MeasureBase.insupport(m::DistributionMeasure, x) = Distributions.insupport(m.obj, x)
52+
53+
@inline MeasureBase.rootmeasure(m::DistributionMeasure{<:ArrayLikeVariate{0},<:Continuous}) = Lebesgue()
54+
@inline MeasureBase.rootmeasure(m::DistributionMeasure{<:ArrayLikeVariate,<:Continuous}) = Lebesgue()^size(m.obj)
55+
@inline MeasureBase.rootmeasure(m::DistributionMeasure{<:ArrayLikeVariate{0},<:Discrete}) = Counting()
56+
@inline MeasureBase.rootmeasure(m::DistributionMeasure{<:ArrayLikeVariate,<:Discrete}) = Counting()^size(m.obj)
57+
58+
@inline MeasureBase.basemeasure(m::DistributionMeasure) = rootmeasure(m)
59+
60+
@inline MeasureBase.mspace_elsize(m::DistributionMeasure{<:ArrayLikeVariate}) = size(m.obj)
61+
62+
@inline MeasureBase.getdof(m::DistributionMeasure{<:ArrayLikeVariate{0}}) = 1
63+
64+
@inline MeasureBase.paramnames(m::DistributionMeasure) = propertynames(m.obj)
65+
@inline MeasureBase.params(m::DistributionMeasure) = NamedTuple{propertynames(m.obj)}(Distributions.params(m.obj))
66+
67+
# @inline MeasureBase.testvalue(m::DistributionMeasure) = testvalue(basemeasure(d))
68+
69+
70+
@inline MeasureBase.basemeasure(d::Distributions.Poisson) = Counting(MeasureBase.BoundedInts(static(0), static(Inf)))
71+
@inline MeasureBase.basemeasure(d::Distributions.Product{<:Any,<:Distributions.Poisson}) = Counting(MeasureBase.BoundedInts(static(0), static(Inf)))^size(d)

ext/distributions/distributions.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT).
2+
3+
using LinearAlgebra: Diagonal, dot, cholesky
4+
5+
import Random
6+
using Random: AbstractRNG, rand!
7+
8+
import DensityInterface
9+
using DensityInterface: logdensityof
10+
11+
import MeasureBase
12+
using MeasureBase: AbstractMeasure, AsMeasure
13+
using MeasureBase: Lebesgue, Counting, ℝ
14+
using MeasureBase: StdMeasure, StdUniform, StdExponential, StdLogistic
15+
using MeasureBase: PowerMeasure, WeightedMeasure
16+
using MeasureBase: basemeasure, testvalue
17+
using MeasureBase: getdof, checked_arg
18+
using MeasureBase: transport_to, transport_def, transport_origin, from_origin, to_origin
19+
using MeasureBase: NoTransformOrigin, NoTransport
20+
21+
import Distributions
22+
using Distributions: Distribution, VariateForm, ValueSupport, ContinuousDistribution
23+
using Distributions: Univariate, Multivariate, ArrayLikeVariate, Continuous, Discrete
24+
using Distributions: Uniform, Exponential, Logistic, Normal
25+
using Distributions: MvNormal, Beta, Dirichlet
26+
using Distributions: ReshapedDistribution
27+
28+
import Statistics
29+
import StatsBase
30+
import StatsFuns
31+
import PDMats
32+
33+
using IrrationalConstants: log2π, invsqrt2π
34+
35+
using Static: True, False, StaticInt, static
36+
using FillArrays: Fill, Ones, Zeros
37+
38+
import ChainRulesCore
39+
using ChainRulesCore: ZeroTangent, NoTangent, unthunk, @thunk
40+
41+
import ForwardDiff
42+
using ForwardDiffPullbacks: fwddiff
43+
44+
import Functors
45+
using Functors: fmap
46+
47+
using ArgCheck: @argcheck
48+
49+
using ArraysOfArrays: ArrayOfSimilarArrays, flatview
50+
51+
include("utils.jl")
52+
include("autodiff_utils.jl")
53+
include("standard_dist.jl")
54+
include("standard_uniform.jl")
55+
include("standard_normal.jl")
56+
include("distribution_measure.jl")
57+
include("dist_vartransform.jl")
58+
include("univariate.jl")
59+
include("standardmv.jl")
60+
include("product.jl")
61+
include("reshaped.jl")
62+
include("dirichlet.jl")
63+
64+
export StdNormal
65+
export DistributionMeasure
66+
export StandardDist

ext/distributions/mixture.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT).
2+
3+
# ToDo:
4+
# AbstractMixtureModel: MixtureModel, UnivariateGMM

0 commit comments

Comments
 (0)