Skip to content

Fully implement DistributionMeasure and add var transforms #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 90 commits into from
Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
90 commits
Select commit Hold shift + click to select a range
16af94c
STASH vartransform
oschulz Jun 15, 2022
9225c1a
STASH require_insupport
oschulz Jun 15, 2022
6ea552b
STASH
oschulz Jun 15, 2022
21326f2
Require MeasureBase v0.10
oschulz Jun 15, 2022
fb82354
Add ChainRulesCore, ForwardDiff and ForwardDiffPullbacks
oschulz Jun 15, 2022
d365698
STASH
oschulz Jun 15, 2022
5db5f62
FIXUP deps
oschulz Jun 15, 2022
ea9a6d6
Add Statistics, StatsBase and StatsFuns to deps
oschulz Jun 15, 2022
0f37407
STASH
oschulz Jun 15, 2022
b537945
STASH
oschulz Jun 15, 2022
f8aafab
STASH
oschulz Jun 15, 2022
7627b05
Add IrrationalConstants and FillArrays to deps
oschulz Jun 15, 2022
806c201
Use Distributions pdf, logpdf, etc.
oschulz Jun 15, 2022
148e819
Implement MeasureBase interface for distributions
oschulz Jun 15, 2022
60c42df
STASH
oschulz Jun 15, 2022
8e94e01
Use Random.rand!
oschulz Jun 15, 2022
adf70b3
STASH
oschulz Jun 15, 2022
349fd32
FIXUP no using pdf, logpdf, etc.
oschulz Jun 15, 2022
77313fb
Use FillArrays Ones and Zeros
oschulz Jun 15, 2022
c43ece2
Add PDMats to deps
oschulz Jun 15, 2022
cef3ce8
STASH stddist
oschulz Jun 15, 2022
aef8533
STASH
oschulz Jun 15, 2022
3629814
Using LinearAlgebra.Diagonal
oschulz Jun 15, 2022
b8ad6b5
STASH
oschulz Jun 15, 2022
57e450a
STASH
oschulz Jun 15, 2022
4d2854d
STASH
oschulz Jun 15, 2022
91a40c6
STASH
oschulz Jun 15, 2022
29d8d2a
STASH
oschulz Jun 15, 2022
1d61c95
STASH
oschulz Jun 15, 2022
612de3c
Add dependencies
oschulz Jun 15, 2022
d639ad2
STASH
oschulz Jun 15, 2022
3f8e342
FIXUP deps
oschulz Jun 15, 2022
0539b0c
FIXUP deps
oschulz Jun 15, 2022
2c019fc
STASH
oschulz Jun 15, 2022
e76abe2
STASH
oschulz Jun 15, 2022
7762601
STASH
oschulz Jun 15, 2022
c58384e
FIXUP test ad utils
oschulz Jun 15, 2022
2d74e87
Use test Project.toml
oschulz Jun 15, 2022
9307bb2
STASH test vtrafo
oschulz Jun 15, 2022
9e7b9e4
FIXUP test Project.toml
oschulz Jun 15, 2022
aa42bdf
STASH
oschulz Jun 15, 2022
3e82bf8
STASH vartransform_def
oschulz Jun 15, 2022
79ce816
FIXUP deps
oschulz Jun 15, 2022
dd6ab04
STASH trafotests
oschulz Jun 15, 2022
0e15881
STASH
oschulz Jun 16, 2022
39035ec
STASH deps
oschulz Jun 16, 2022
0f40e4d
STASH
oschulz Jun 16, 2022
4f26564
Add Static to deps
oschulz Jun 16, 2022
11a763b
MeasureBase deps
oschulz Jun 16, 2022
19cc23f
Add StdNormal measure
oschulz Jun 16, 2022
3bc9040
Fixup StdNormal measure
oschulz Jun 16, 2022
0848251
vartransform for StdNormal
oschulz Jun 16, 2022
53ce9a5
FIXUP deps
oschulz Jun 16, 2022
ab5fc63
STASH transforms
oschulz Jun 16, 2022
6b339b3
Export StdNormal
oschulz Jun 16, 2022
7cea2d7
FIXUP deps
oschulz Jun 16, 2022
a876c50
Fix vartransform for StdNormal
oschulz Jun 16, 2022
3b0ed28
FIXUP getjacobian in tests
oschulz Jun 16, 2022
9395827
STASH vartransform_def for StdNormal
oschulz Jun 16, 2022
60ed6ce
STASH
oschulz Jun 17, 2022
a675872
STASH measure interface for dists
oschulz Jun 17, 2022
d6814fd
FIXUP
oschulz Jun 17, 2022
aa92f70
STASH
oschulz Jun 17, 2022
c5edd8b
STASH a lot
oschulz Jun 18, 2022
5495b9d
Test vartransform auto dim-sel
oschulz Jun 18, 2022
21b1bdf
STASH trafos
oschulz Jun 18, 2022
50d184d
STASH vartransform
oschulz Jun 18, 2022
6010987
STASH vartransform tests and dirichlet
oschulz Jun 18, 2022
5e72697
STASH vartrafo tests
oschulz Jun 18, 2022
fb45ddc
STASH transform tests
oschulz Jun 18, 2022
d2b7306
Fix transform var naming
oschulz Jun 18, 2022
8852754
Rename vartransform to transport_to
oschulz Jun 19, 2022
e4f3cdb
Rename vartransform_origin
oschulz Jun 19, 2022
ea77b3b
Require MeasureBase v0.11
oschulz Jun 19, 2022
3474573
Add MeasureBase to test deps
oschulz Jun 19, 2022
184bc00
Fix transport tests
oschulz Jun 19, 2022
8c48244
fix some typos
cscherrer Jun 20, 2022
40757e2
Fixes and more tests
oschulz Jun 20, 2022
2c3b691
Fixes and tests
oschulz Jun 20, 2022
f68e214
STASH tests
oschulz Jun 20, 2022
30c746c
STASH
oschulz Jun 20, 2022
81a6627
FIX Prefix scale
oschulz Jun 20, 2022
e130d25
Fix transport_def for StandardDist
oschulz Jun 20, 2022
2e433c3
Fixes
oschulz Jun 20, 2022
9153836
Fix tests
oschulz Jun 20, 2022
a673c26
Adapt to MeasureBase v0.12
oschulz Jun 20, 2022
58a791c
Allow for static-sized StandardDist
oschulz Jun 20, 2022
4b22c6b
Fix product transport
oschulz Jun 21, 2022
40077dd
Increase package version to v0.2.0
oschulz Jun 21, 2022
319f97d
Support RN integral notation for distributions
oschulz Jun 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,31 +1,46 @@
name = "DistributionMeasures"
uuid = "35643b39-bfd4-4670-843f-16596ca89bf3"
authors = ["Chad Scherrer <[email protected]> and contributors"]
version = "0.1.0"
version = "0.2.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ForwardDiffPullbacks = "450a3b6d-2448-4ee1-8e34-e4eb8713b605"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MeasureBase = "fa1605e6-acd5-459c-a1e6-7e635759db14"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
ArgCheck = "1, 2"
ArraysOfArrays = "0.5"
ChainRulesCore = "1"
ChangesOfVariables = "0.1"
DensityInterface = "0.4"
Distributions = "0.25"
FillArrays = "0.12, 0.13"
ForwardDiff = "0.9, 0.10"
ForwardDiffPullbacks = "0.2"
Functors = "0.2"
InverseFunctions = "0.1"
MeasureBase = "0.9"
IrrationalConstants = "0.1"
MeasureBase = "0.12"
PDMats = "0.11"
Static = "0.5, 0.6"
StatsBase = "0.32, 0.33"
StatsFuns = "0.9, 1"
julia = "1.6"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
60 changes: 51 additions & 9 deletions src/DistributionMeasures.jl
Original file line number Diff line number Diff line change
@@ -1,33 +1,75 @@
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT).

module DistributionMeasures

using LinearAlgebra: Diagonal, dot, cholesky

import Random
using Random: AbstractRNG
using Random: AbstractRNG, rand!

import DensityInterface
using DensityInterface: logdensityof

import MeasureBase
using MeasureBase: AbstractMeasure, Lebesgue, Counting
using MeasureBase: PowerMeasure
using MeasureBase: AbstractMeasure, Lebesgue, Counting, ℝ
using MeasureBase: StdMeasure, StdUniform, StdExponential, StdLogistic
using MeasureBase: PowerMeasure, WeightedMeasure
using MeasureBase: basemeasure, testvalue
using MeasureBase: getdof, checked_arg
using MeasureBase: transport_to, transport_def, transport_origin, from_origin, to_origin
using MeasureBase: NoTransformOrigin, NoTransport

import Distributions
using Distributions: Distribution, VariateForm, ValueSupport
using Distributions: ArrayLikeVariate, Continuous, Discrete
using Distributions: Distribution, VariateForm, ValueSupport, ContinuousDistribution
using Distributions: Univariate, Multivariate, ArrayLikeVariate, Continuous, Discrete
using Distributions: Uniform, Exponential, Logistic, Normal
using Distributions: MvNormal, Beta, Dirichlet
using Distributions: ReshapedDistribution

import Statistics
import StatsBase
import StatsFuns
import PDMats

using IrrationalConstants: log2π, invsqrt2π

using Static: True, False, StaticInt, static
using FillArrays: Fill, Ones, Zeros

import ChainRulesCore
using ChainRulesCore: ZeroTangent, NoTangent, unthunk, @thunk

import ForwardDiff
using ForwardDiffPullbacks: fwddiff

import Functors
using Functors: fmap

using ArgCheck: @argcheck

using ArraysOfArrays: ArrayOfSimilarArrays, flatview

const MeasureLike = Union{AbstractMeasure,Distribution}
export MeasureLike

include("utils.jl")
include("autodiff_utils.jl")
include("measure_interface.jl")
include("stdnormal_measure.jl")
include("standard_dist.jl")
include("standard_uniform.jl")
include("standard_normal.jl")
include("distribution_measure.jl")
include("dist_vartransform.jl")
include("univariate.jl")
include("standardmv.jl")
include("product.jl")
include("reshaped.jl")
include("dirichlet.jl")


const MeasureLike = Union{AbstractMeasure,Distribution}

export MeasureLike, DistributionMeasure
export StdNormal
export DistributionMeasure
export StandardDist


end # module
75 changes: 75 additions & 0 deletions src/autodiff_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT).

@inline _adignore_call(f) = f()
@inline _adignore_call_pullback(@nospecialize ΔΩ) = (NoTangent(), NoTangent())
ChainRulesCore.rrule(::typeof(_adignore_call), f) = _adignore_call(f), _adignore_call_pullback

macro _adignore(expr)
:(_adignore_call(() -> $(esc(expr))))
end


function _pushfront(v::AbstractVector, x)
T = promote_type(eltype(v), typeof(x))
r = similar(v, T, length(eachindex(v)) + 1)
r[firstindex(r)] = x
r[firstindex(r)+1:lastindex(r)] = v
r
end

function ChainRulesCore.rrule(::typeof(_pushfront), v::AbstractVector, x)
result = _pushfront(v, x)
function _pushfront_pullback(thunked_ΔΩ)
ΔΩ = unthunk(thunked_ΔΩ)
(NoTangent(), ΔΩ[firstindex(ΔΩ)+1:lastindex(ΔΩ)], ΔΩ[firstindex(ΔΩ)])
end
return result, _pushfront_pullback
end


function _pushback(v::AbstractVector, x)
T = promote_type(eltype(v), typeof(x))
r = similar(v, T, length(eachindex(v)) + 1)
r[lastindex(r)] = x
r[firstindex(r):lastindex(r)-1] = v
r
end

function ChainRulesCore.rrule(::typeof(_pushback), v::AbstractVector, x)
result = _pushback(v, x)
function _pushback_pullback(thunked_ΔΩ)
ΔΩ = unthunk(thunked_ΔΩ)
(NoTangent(), ΔΩ[firstindex(ΔΩ):lastindex(ΔΩ)-1], ΔΩ[lastindex(ΔΩ)])
end
return result, _pushback_pullback
end


_dropfront(v::AbstractVector) = v[firstindex(v)+1:lastindex(v)]

_dropback(v::AbstractVector) = v[firstindex(v):lastindex(v)-1]


_rev_cumsum(xs::AbstractVector) = reverse(cumsum(reverse(xs)))

function ChainRulesCore.rrule(::typeof(_rev_cumsum), xs::AbstractVector)
result = _rev_cumsum(xs)
function _rev_cumsum_pullback(ΔΩ)
∂xs = @thunk cumsum(unthunk(ΔΩ))
(NoTangent(), ∂xs)
end
return result, _rev_cumsum_pullback
end


# Equivalent to `cumprod(xs)``:
_exp_cumsum_log(xs::AbstractVector) = exp.(cumsum(log.(xs)))

function ChainRulesCore.rrule(::typeof(_exp_cumsum_log), xs::AbstractVector)
result = _exp_cumsum_log(xs)
function _exp_cumsum_log_pullback(ΔΩ)
∂xs = inv.(xs) .* _rev_cumsum(exp.(cumsum(log.(xs))) .* unthunk(ΔΩ))
(NoTangent(), ∂xs)
end
return result, _exp_cumsum_log_pullback
end
30 changes: 30 additions & 0 deletions src/dirichlet.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT).

MeasureBase.getdof(d::Dirichlet) = length(d) - 1

MeasureBase.transport_origin(ν::Dirichlet) = StdUniform()^getdof(ν)


function _dirichlet_beta_trafo(α::Real, β::Real, x::Real)
R = float(promote_type(typeof(α), typeof(β), typeof(x)))
convert(R, transport_def(Beta(α, β), StdUniform(), x))::R
end

_a_times_one_minus_b(a::Real, b::Real) = a * (1 - b)

function MeasureBase.from_origin(ν::Dirichlet, x)
# See M. J. Betancourt, "Cruising The Simplex: Hamiltonian Monte Carlo and the Dirichlet Distribution",
# https://arxiv.org/abs/1010.3436

# Sanity check (TODO - remove?):
@_adignore @argcheck length(ν) == length(x) + 1

αs = _dropfront(_rev_cumsum(ν.alpha))
βs = _dropback(ν.alpha)
beta_v = fwddiff(_dirichlet_beta_trafo).(αs, βs, x)
beta_v_cp = _exp_cumsum_log(_pushfront(beta_v, 1))
beta_v_ext = _pushback(beta_v, 0)
fwddiff(_a_times_one_minus_b).(beta_v_cp, beta_v_ext)
end

# ToDo: MeasureBase.to_origin(ν::Dirichlet, y)
16 changes: 16 additions & 0 deletions src/dist_vartransform.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT).

const _AnyStdUniform = Union{StandardUniform, Uniform}
const _AnyStdNormal = Union{StandardNormal, Normal}

const _AnyStdDistribution = Union{_AnyStdUniform, _AnyStdNormal}

_std_measure(::Type{<:_AnyStdUniform}) = StandardUniform
_std_measure(::Type{<:_AnyStdNormal}) = StandardNormal

_std_measure(::Type{M}, ::StaticInt{1}) where {M<:_AnyStdDistribution} = M()
_std_measure(::Type{M}, dof::Integer) where {M<:_AnyStdDistribution} = M(dof)
_std_measure_for(::Type{M}, μ::Any) where {M<:_AnyStdDistribution} = _std_measure(_std_measure(M), getdof(μ))

MeasureBase.transport_to(::Type{NU}, μ) where {NU<:_AnyStdDistribution} = transport_to(_std_measure_for(NU, μ), μ)
MeasureBase.transport_to(ν, ::Type{MU}) where {MU<:_AnyStdDistribution} = transport_to(ν, _std_measure_for(MU, ν))
47 changes: 23 additions & 24 deletions src/distribution_measure.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT).


"""
struct DistributionMeasure <: AbstractMeasure

Expand All @@ -14,52 +17,52 @@ struct DistributionMeasure{F<:VariateForm,S<:ValueSupport,D<:Distribution{F,S}}
d::D
end


@inline MeasureBase.AbstractMeasure(d::Distribution) = DistributionMeasure(d)

@inline Base.convert(::Type{AbstractMeasure}, d::Distribution) = DistributionMeasure(d)

@inline Distributions.Distribution(m::DistributionMeasure) = m.distribution
@inline Distributions.Distribution(m::DistributionMeasure) = m.d
@inline Distributions.Distribution{F}(m::DistributionMeasure{F}) where {F<:VariateForm} = Distribution(m)
@inline Distributions.Distribution{F,S}(m::DistributionMeasure{F,S}) where {F<:VariateForm,S<:ValueSupport} = Distribution(m)

@inline Base.convert(::Type{Distribution}, m::DistributionMeasure) = Distribution(m)
@inline Base.convert(::Type{Distribution{F}}, m::DistributionMeasure{F}) where {F<:VariateForm} = Distribution(m)
@inline Base.convert(::Type{Distribution{F,S}}, m::DistributionMeasure{F,S}) where {F<:VariateForm,S<:ValueSupport} = Distribution(m)

@inline DensityInterface.densityof(m::DistributionMeasure) = DensityInterface.densityof(m.d)
@inline DensityInterface.densityof(m::DistributionMeasure, x) = DensityInterface.densityof(m.d, x)
@inline DensityInterface.logdensityof(m::DistributionMeasure) = DensityInterface.logdensityof(m.d)
@inline DensityInterface.logdensityof(m::DistributionMeasure, x) = DensityInterface.logdensityof(m.d, x)

@inline MeasureBase.logdensity_def(m::DistributionMeasure, x) = DensityInterface.logdensityof(m.d, x)
@inline MeasureBase.unsafe_logdensityof(m::DistributionMeasure, x) = DensityInterface.logdensityof(m.d, x)

@inline MeasureBase.insupport(m::DistributionMeasure, x) = Distributions.insupport(m.x)
@inline DensityInterface.densityof(μ::DistributionMeasure) = DensityInterface.densityof(μ.d)
@inline DensityInterface.densityof(μ::DistributionMeasure, x) = DensityInterface.densityof(μ.d, x)
@inline DensityInterface.logdensityof(μ::DistributionMeasure) = DensityInterface.logdensityof(μ.d)
@inline DensityInterface.logdensityof(μ::DistributionMeasure, x) = DensityInterface.logdensityof(μ.d, x)

@inline MeasureBase.basemeasure(m::DistributionMeasure{<:ArrayLikeVariate{0},<:Continuous}) = Lebesgue()
@inline MeasureBase.basemeasure(m::DistributionMeasure{<:ArrayLikeVariate,<:Continuous}) = Lebesgue()^size(m.d)
@inline MeasureBase.basemeasure(m::DistributionMeasure{<:ArrayLikeVariate{0},<:Discrete}) = Counting()
@inline MeasureBase.basemeasure(m::DistributionMeasure{<:ArrayLikeVariate,<:Discrete}) = Counting()^size(m.d)
@inline MeasureBase.logdensity_def(μ::DistributionMeasure, x) = MeasureBase.logdensity_def(μ.d, x)
@inline MeasureBase.unsafe_logdensityof(μ::DistributionMeasure, x) = MeasureBase.unsafe_logdensityof(μ.d, x)
@inline MeasureBase.insupport(μ::DistributionMeasure, x) = MeasureBase.insupport(μ.d, x)
@inline MeasureBase.basemeasure(μ::DistributionMeasure) = MeasureBase.basemeasure(μ.d)
@inline MeasureBase.paramnames(μ::DistributionMeasure) = MeasureBase.paramnames(μ.d)
@inline MeasureBase.params(μ::DistributionMeasure) = MeasureBase.params(μ.d)
@inline MeasureBase.transport_origin(ν::DistributionMeasure) = ν.d
@inline MeasureBase.to_origin(::DistributionMeasure, y) = y
@inline MeasureBase.from_origin(::DistributionMeasure, x) = x

@inline MeasureBase.rootmeasure(m::DistributionMeasure) = MeasureBase.basemeasure(m)


Base.rand(rng::AbstractRNG, ::Type{T}, m::DistributionMeasure) where {T<:Real} = _convert_numtype(T, rand(m.d))
Base.rand(rng::AbstractRNG, ::Type{T}, m::DistributionMeasure) where {T<:Real} = convert_realtype(T, rand(m.d))

function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution{<:ArrayLikeVariate{0}}, sz::Dims) where {T<:Real}
_convert_numtype(T, reshape(rand(d, prod(sz)), sz...))
convert_realtype(T, reshape(rand(d, prod(sz)), sz...))
end

function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution{<:ArrayLikeVariate{1}}, sz::Dims) where {T<:Real}
_convert_numtype(T, reshape(rand(d, prod(sz)), size(d)..., sz...))
convert_realtype(T, reshape(rand(d, prod(sz)), size(d)..., sz...))
end

function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::ReshapedDistribution{N,<:Any,<:Distribution{<:ArrayLikeVariate{1}}}, sz::Dims) where {T<:Real,N}
_convert_numtype(T, reshape(rand(d.dist, prod(sz)), d.dims..., sz...))
convert_realtype(T, reshape(rand(d.dist, prod(sz)), d.dims..., sz...))
end

function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution, sz::Dims) where {T<:Real,N}
flatview(ArrayOfSimilarArrays(_convert_numtype(T, rand(d, sz))))
flatview(ArrayOfSimilarArrays(convert_realtype(T, rand(d, sz))))
end

function Base.rand(rng::AbstractRNG, ::Type{T}, m::PowerMeasure{<:DistributionMeasure{<:ArrayLikeVariate{0}}, NTuple{N,Base.OneTo{Int}}}) where {T<:Real,N}
Expand All @@ -70,7 +73,3 @@ function Base.rand(rng::AbstractRNG, ::Type{T}, m::PowerMeasure{<:DistributionMe
flat_data = _flat_powrand(rng, T, m.parent.d, map(length, m.axes))
ArrayOfSimilarArrays{T,M,N}(flat_data)
end


@inline MeasureBase.paramnames(m::DistributionMeasure) = propertynames(m.d)
@inline MeasureBase.params(m::DistributionMeasure) = NamedTuple{MeasureBase.paramnames(m.d)}(Distributions.params(m.d))
23 changes: 23 additions & 0 deletions src/measure_interface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT).

@inline MeasureBase.logdensity_def(d::Distribution, x) = DensityInterface.logdensityof(d, x)
@inline MeasureBase.unsafe_logdensityof(d::Distribution, x) = DensityInterface.logdensityof(d, x)

@inline MeasureBase.insupport(d::Distribution, x) = Distributions.insupport(d, x)

@inline MeasureBase.basemeasure(d::Distribution{<:ArrayLikeVariate{0},<:Continuous}) = Lebesgue()
@inline MeasureBase.basemeasure(d::Distribution{<:ArrayLikeVariate,<:Continuous}) = Lebesgue()^size(d)
@inline MeasureBase.basemeasure(d::Distribution{<:ArrayLikeVariate{0},<:Discrete}) = Counting()
@inline MeasureBase.basemeasure(d::Distribution{<:ArrayLikeVariate,<:Discrete}) = Counting()^size(d)

@inline MeasureBase.paramnames(d::Distribution) = propertynames(d)
@inline MeasureBase.params(d::Distribution) = NamedTuple{propertynames(d)}(Distributions.params(d))

@inline MeasureBase.testvalue(d::Distribution) = testvalue(basemeasure(d))


@inline MeasureBase.basemeasure(d::Distributions.Poisson) = Counting(MeasureBase.BoundedInts(static(0), static(Inf)))
@inline MeasureBase.basemeasure(d::Distributions.Product{<:Any,<:Distributions.Poisson}) = Counting(MeasureBase.BoundedInts(static(0), static(Inf)))^size(d)


MeasureBase.∫(f, base::Distribution) = MeasureBase.∫(f, convert(AbstractMeasure, base))
20 changes: 20 additions & 0 deletions src/product.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT).

const _StdPowMeasure1 = PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}
const _UniformProductDist1x{D} = Distributions.Product{Continuous,D,<:AbstractVector{D}}


MeasureBase.getdof(d::_UniformProductDist1x) = length(d)


function _product_dist_trafo_impl(νs, μs, x)
fwddiff(transport_def).(νs, μs, x)
end

function MeasureBase.transport_def(ν::_StdPowMeasure1, μ::_UniformProductDist1x, x)
_product_dist_trafo_impl((ν.parent,), μ.v, x)
end

function MeasureBase.transport_def(ν::_UniformProductDist1x, μ::_StdPowMeasure1, x)
_product_dist_trafo_impl(ν.v, (μ.parent,), x)
end
Loading