diff --git a/Project.toml b/Project.toml index 6f01fc7..92edac9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,31 +1,46 @@ name = "DistributionMeasures" uuid = "35643b39-bfd4-4670-843f-16596ca89bf3" authors = ["Chad Scherrer and contributors"] -version = "0.1.0" +version = "0.2.0" [deps] +ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +ForwardDiffPullbacks = "450a3b6d-2448-4ee1-8e34-e4eb8713b605" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" +IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MeasureBase = "fa1605e6-acd5-459c-a1e6-7e635759db14" +PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [compat] +ArgCheck = "1, 2" ArraysOfArrays = "0.5" +ChainRulesCore = "1" ChangesOfVariables = "0.1" DensityInterface = "0.4" Distributions = "0.25" +FillArrays = "0.12, 0.13" +ForwardDiff = "0.9, 0.10" +ForwardDiffPullbacks = "0.2" Functors = "0.2" InverseFunctions = "0.1" -MeasureBase = "0.9" +IrrationalConstants = "0.1" +MeasureBase = "0.12" +PDMats = "0.11" +Static = "0.5, 0.6" +StatsBase = "0.32, 0.33" +StatsFuns = "0.9, 1" julia = "1.6" - -[extras] -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["Test"] diff --git a/src/DistributionMeasures.jl b/src/DistributionMeasures.jl index 87bffca..2c7582e 100644 --- a/src/DistributionMeasures.jl +++ b/src/DistributionMeasures.jl @@ -1,33 +1,75 @@ +# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). + module DistributionMeasures +using LinearAlgebra: Diagonal, dot, cholesky + import Random -using Random: AbstractRNG +using Random: AbstractRNG, rand! import DensityInterface using DensityInterface: logdensityof import MeasureBase -using MeasureBase: AbstractMeasure, Lebesgue, Counting -using MeasureBase: PowerMeasure +using MeasureBase: AbstractMeasure, Lebesgue, Counting, ℝ +using MeasureBase: StdMeasure, StdUniform, StdExponential, StdLogistic +using MeasureBase: PowerMeasure, WeightedMeasure +using MeasureBase: basemeasure, testvalue +using MeasureBase: getdof, checked_arg +using MeasureBase: transport_to, transport_def, transport_origin, from_origin, to_origin +using MeasureBase: NoTransformOrigin, NoTransport import Distributions -using Distributions: Distribution, VariateForm, ValueSupport -using Distributions: ArrayLikeVariate, Continuous, Discrete +using Distributions: Distribution, VariateForm, ValueSupport, ContinuousDistribution +using Distributions: Univariate, Multivariate, ArrayLikeVariate, Continuous, Discrete +using Distributions: Uniform, Exponential, Logistic, Normal +using Distributions: MvNormal, Beta, Dirichlet using Distributions: ReshapedDistribution +import Statistics +import StatsBase +import StatsFuns +import PDMats + +using IrrationalConstants: log2π, invsqrt2π + +using Static: True, False, StaticInt, static +using FillArrays: Fill, Ones, Zeros + +import ChainRulesCore +using ChainRulesCore: ZeroTangent, NoTangent, unthunk, @thunk + +import ForwardDiff +using ForwardDiffPullbacks: fwddiff + import Functors using Functors: fmap +using ArgCheck: @argcheck + using ArraysOfArrays: ArrayOfSimilarArrays, flatview +const MeasureLike = Union{AbstractMeasure,Distribution} +export MeasureLike include("utils.jl") +include("autodiff_utils.jl") +include("measure_interface.jl") +include("stdnormal_measure.jl") +include("standard_dist.jl") +include("standard_uniform.jl") +include("standard_normal.jl") include("distribution_measure.jl") +include("dist_vartransform.jl") +include("univariate.jl") +include("standardmv.jl") +include("product.jl") +include("reshaped.jl") +include("dirichlet.jl") - -const MeasureLike = Union{AbstractMeasure,Distribution} - -export MeasureLike, DistributionMeasure +export StdNormal +export DistributionMeasure +export StandardDist end # module diff --git a/src/autodiff_utils.jl b/src/autodiff_utils.jl new file mode 100644 index 0000000..44e752d --- /dev/null +++ b/src/autodiff_utils.jl @@ -0,0 +1,75 @@ +# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). + +@inline _adignore_call(f) = f() +@inline _adignore_call_pullback(@nospecialize ΔΩ) = (NoTangent(), NoTangent()) +ChainRulesCore.rrule(::typeof(_adignore_call), f) = _adignore_call(f), _adignore_call_pullback + +macro _adignore(expr) + :(_adignore_call(() -> $(esc(expr)))) +end + + +function _pushfront(v::AbstractVector, x) + T = promote_type(eltype(v), typeof(x)) + r = similar(v, T, length(eachindex(v)) + 1) + r[firstindex(r)] = x + r[firstindex(r)+1:lastindex(r)] = v + r +end + +function ChainRulesCore.rrule(::typeof(_pushfront), v::AbstractVector, x) + result = _pushfront(v, x) + function _pushfront_pullback(thunked_ΔΩ) + ΔΩ = unthunk(thunked_ΔΩ) + (NoTangent(), ΔΩ[firstindex(ΔΩ)+1:lastindex(ΔΩ)], ΔΩ[firstindex(ΔΩ)]) + end + return result, _pushfront_pullback +end + + +function _pushback(v::AbstractVector, x) + T = promote_type(eltype(v), typeof(x)) + r = similar(v, T, length(eachindex(v)) + 1) + r[lastindex(r)] = x + r[firstindex(r):lastindex(r)-1] = v + r +end + +function ChainRulesCore.rrule(::typeof(_pushback), v::AbstractVector, x) + result = _pushback(v, x) + function _pushback_pullback(thunked_ΔΩ) + ΔΩ = unthunk(thunked_ΔΩ) + (NoTangent(), ΔΩ[firstindex(ΔΩ):lastindex(ΔΩ)-1], ΔΩ[lastindex(ΔΩ)]) + end + return result, _pushback_pullback +end + + +_dropfront(v::AbstractVector) = v[firstindex(v)+1:lastindex(v)] + +_dropback(v::AbstractVector) = v[firstindex(v):lastindex(v)-1] + + +_rev_cumsum(xs::AbstractVector) = reverse(cumsum(reverse(xs))) + +function ChainRulesCore.rrule(::typeof(_rev_cumsum), xs::AbstractVector) + result = _rev_cumsum(xs) + function _rev_cumsum_pullback(ΔΩ) + ∂xs = @thunk cumsum(unthunk(ΔΩ)) + (NoTangent(), ∂xs) + end + return result, _rev_cumsum_pullback +end + + +# Equivalent to `cumprod(xs)``: +_exp_cumsum_log(xs::AbstractVector) = exp.(cumsum(log.(xs))) + +function ChainRulesCore.rrule(::typeof(_exp_cumsum_log), xs::AbstractVector) + result = _exp_cumsum_log(xs) + function _exp_cumsum_log_pullback(ΔΩ) + ∂xs = inv.(xs) .* _rev_cumsum(exp.(cumsum(log.(xs))) .* unthunk(ΔΩ)) + (NoTangent(), ∂xs) + end + return result, _exp_cumsum_log_pullback +end diff --git a/src/dirichlet.jl b/src/dirichlet.jl new file mode 100644 index 0000000..b26de58 --- /dev/null +++ b/src/dirichlet.jl @@ -0,0 +1,30 @@ +# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). + +MeasureBase.getdof(d::Dirichlet) = length(d) - 1 + +MeasureBase.transport_origin(ν::Dirichlet) = StdUniform()^getdof(ν) + + +function _dirichlet_beta_trafo(α::Real, β::Real, x::Real) + R = float(promote_type(typeof(α), typeof(β), typeof(x))) + convert(R, transport_def(Beta(α, β), StdUniform(), x))::R +end + +_a_times_one_minus_b(a::Real, b::Real) = a * (1 - b) + +function MeasureBase.from_origin(ν::Dirichlet, x) + # See M. J. Betancourt, "Cruising The Simplex: Hamiltonian Monte Carlo and the Dirichlet Distribution", + # https://arxiv.org/abs/1010.3436 + + # Sanity check (TODO - remove?): + @_adignore @argcheck length(ν) == length(x) + 1 + + αs = _dropfront(_rev_cumsum(ν.alpha)) + βs = _dropback(ν.alpha) + beta_v = fwddiff(_dirichlet_beta_trafo).(αs, βs, x) + beta_v_cp = _exp_cumsum_log(_pushfront(beta_v, 1)) + beta_v_ext = _pushback(beta_v, 0) + fwddiff(_a_times_one_minus_b).(beta_v_cp, beta_v_ext) +end + +# ToDo: MeasureBase.to_origin(ν::Dirichlet, y) diff --git a/src/dist_vartransform.jl b/src/dist_vartransform.jl new file mode 100644 index 0000000..b4adeb8 --- /dev/null +++ b/src/dist_vartransform.jl @@ -0,0 +1,16 @@ +# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). + +const _AnyStdUniform = Union{StandardUniform, Uniform} +const _AnyStdNormal = Union{StandardNormal, Normal} + +const _AnyStdDistribution = Union{_AnyStdUniform, _AnyStdNormal} + +_std_measure(::Type{<:_AnyStdUniform}) = StandardUniform +_std_measure(::Type{<:_AnyStdNormal}) = StandardNormal + +_std_measure(::Type{M}, ::StaticInt{1}) where {M<:_AnyStdDistribution} = M() +_std_measure(::Type{M}, dof::Integer) where {M<:_AnyStdDistribution} = M(dof) +_std_measure_for(::Type{M}, μ::Any) where {M<:_AnyStdDistribution} = _std_measure(_std_measure(M), getdof(μ)) + +MeasureBase.transport_to(::Type{NU}, μ) where {NU<:_AnyStdDistribution} = transport_to(_std_measure_for(NU, μ), μ) +MeasureBase.transport_to(ν, ::Type{MU}) where {MU<:_AnyStdDistribution} = transport_to(ν, _std_measure_for(MU, ν)) diff --git a/src/distribution_measure.jl b/src/distribution_measure.jl index 8e35864..5c010a6 100644 --- a/src/distribution_measure.jl +++ b/src/distribution_measure.jl @@ -1,3 +1,6 @@ +# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). + + """ struct DistributionMeasure <: AbstractMeasure @@ -14,11 +17,12 @@ struct DistributionMeasure{F<:VariateForm,S<:ValueSupport,D<:Distribution{F,S}} d::D end + @inline MeasureBase.AbstractMeasure(d::Distribution) = DistributionMeasure(d) @inline Base.convert(::Type{AbstractMeasure}, d::Distribution) = DistributionMeasure(d) -@inline Distributions.Distribution(m::DistributionMeasure) = m.distribution +@inline Distributions.Distribution(m::DistributionMeasure) = m.d @inline Distributions.Distribution{F}(m::DistributionMeasure{F}) where {F<:VariateForm} = Distribution(m) @inline Distributions.Distribution{F,S}(m::DistributionMeasure{F,S}) where {F<:VariateForm,S<:ValueSupport} = Distribution(m) @@ -26,40 +30,39 @@ end @inline Base.convert(::Type{Distribution{F}}, m::DistributionMeasure{F}) where {F<:VariateForm} = Distribution(m) @inline Base.convert(::Type{Distribution{F,S}}, m::DistributionMeasure{F,S}) where {F<:VariateForm,S<:ValueSupport} = Distribution(m) -@inline DensityInterface.densityof(m::DistributionMeasure) = DensityInterface.densityof(m.d) -@inline DensityInterface.densityof(m::DistributionMeasure, x) = DensityInterface.densityof(m.d, x) -@inline DensityInterface.logdensityof(m::DistributionMeasure) = DensityInterface.logdensityof(m.d) -@inline DensityInterface.logdensityof(m::DistributionMeasure, x) = DensityInterface.logdensityof(m.d, x) - -@inline MeasureBase.logdensity_def(m::DistributionMeasure, x) = DensityInterface.logdensityof(m.d, x) -@inline MeasureBase.unsafe_logdensityof(m::DistributionMeasure, x) = DensityInterface.logdensityof(m.d, x) -@inline MeasureBase.insupport(m::DistributionMeasure, x) = Distributions.insupport(m.x) +@inline DensityInterface.densityof(μ::DistributionMeasure) = DensityInterface.densityof(μ.d) +@inline DensityInterface.densityof(μ::DistributionMeasure, x) = DensityInterface.densityof(μ.d, x) +@inline DensityInterface.logdensityof(μ::DistributionMeasure) = DensityInterface.logdensityof(μ.d) +@inline DensityInterface.logdensityof(μ::DistributionMeasure, x) = DensityInterface.logdensityof(μ.d, x) -@inline MeasureBase.basemeasure(m::DistributionMeasure{<:ArrayLikeVariate{0},<:Continuous}) = Lebesgue() -@inline MeasureBase.basemeasure(m::DistributionMeasure{<:ArrayLikeVariate,<:Continuous}) = Lebesgue()^size(m.d) -@inline MeasureBase.basemeasure(m::DistributionMeasure{<:ArrayLikeVariate{0},<:Discrete}) = Counting() -@inline MeasureBase.basemeasure(m::DistributionMeasure{<:ArrayLikeVariate,<:Discrete}) = Counting()^size(m.d) +@inline MeasureBase.logdensity_def(μ::DistributionMeasure, x) = MeasureBase.logdensity_def(μ.d, x) +@inline MeasureBase.unsafe_logdensityof(μ::DistributionMeasure, x) = MeasureBase.unsafe_logdensityof(μ.d, x) +@inline MeasureBase.insupport(μ::DistributionMeasure, x) = MeasureBase.insupport(μ.d, x) +@inline MeasureBase.basemeasure(μ::DistributionMeasure) = MeasureBase.basemeasure(μ.d) +@inline MeasureBase.paramnames(μ::DistributionMeasure) = MeasureBase.paramnames(μ.d) +@inline MeasureBase.params(μ::DistributionMeasure) = MeasureBase.params(μ.d) +@inline MeasureBase.transport_origin(ν::DistributionMeasure) = ν.d +@inline MeasureBase.to_origin(::DistributionMeasure, y) = y +@inline MeasureBase.from_origin(::DistributionMeasure, x) = x -@inline MeasureBase.rootmeasure(m::DistributionMeasure) = MeasureBase.basemeasure(m) - -Base.rand(rng::AbstractRNG, ::Type{T}, m::DistributionMeasure) where {T<:Real} = _convert_numtype(T, rand(m.d)) +Base.rand(rng::AbstractRNG, ::Type{T}, m::DistributionMeasure) where {T<:Real} = convert_realtype(T, rand(m.d)) function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution{<:ArrayLikeVariate{0}}, sz::Dims) where {T<:Real} - _convert_numtype(T, reshape(rand(d, prod(sz)), sz...)) + convert_realtype(T, reshape(rand(d, prod(sz)), sz...)) end function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution{<:ArrayLikeVariate{1}}, sz::Dims) where {T<:Real} - _convert_numtype(T, reshape(rand(d, prod(sz)), size(d)..., sz...)) + convert_realtype(T, reshape(rand(d, prod(sz)), size(d)..., sz...)) end function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::ReshapedDistribution{N,<:Any,<:Distribution{<:ArrayLikeVariate{1}}}, sz::Dims) where {T<:Real,N} - _convert_numtype(T, reshape(rand(d.dist, prod(sz)), d.dims..., sz...)) + convert_realtype(T, reshape(rand(d.dist, prod(sz)), d.dims..., sz...)) end function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution, sz::Dims) where {T<:Real,N} - flatview(ArrayOfSimilarArrays(_convert_numtype(T, rand(d, sz)))) + flatview(ArrayOfSimilarArrays(convert_realtype(T, rand(d, sz)))) end function Base.rand(rng::AbstractRNG, ::Type{T}, m::PowerMeasure{<:DistributionMeasure{<:ArrayLikeVariate{0}}, NTuple{N,Base.OneTo{Int}}}) where {T<:Real,N} @@ -70,7 +73,3 @@ function Base.rand(rng::AbstractRNG, ::Type{T}, m::PowerMeasure{<:DistributionMe flat_data = _flat_powrand(rng, T, m.parent.d, map(length, m.axes)) ArrayOfSimilarArrays{T,M,N}(flat_data) end - - -@inline MeasureBase.paramnames(m::DistributionMeasure) = propertynames(m.d) -@inline MeasureBase.params(m::DistributionMeasure) = NamedTuple{MeasureBase.paramnames(m.d)}(Distributions.params(m.d)) diff --git a/src/measure_interface.jl b/src/measure_interface.jl new file mode 100644 index 0000000..7e6a884 --- /dev/null +++ b/src/measure_interface.jl @@ -0,0 +1,23 @@ +# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). + +@inline MeasureBase.logdensity_def(d::Distribution, x) = DensityInterface.logdensityof(d, x) +@inline MeasureBase.unsafe_logdensityof(d::Distribution, x) = DensityInterface.logdensityof(d, x) + +@inline MeasureBase.insupport(d::Distribution, x) = Distributions.insupport(d, x) + +@inline MeasureBase.basemeasure(d::Distribution{<:ArrayLikeVariate{0},<:Continuous}) = Lebesgue() +@inline MeasureBase.basemeasure(d::Distribution{<:ArrayLikeVariate,<:Continuous}) = Lebesgue()^size(d) +@inline MeasureBase.basemeasure(d::Distribution{<:ArrayLikeVariate{0},<:Discrete}) = Counting() +@inline MeasureBase.basemeasure(d::Distribution{<:ArrayLikeVariate,<:Discrete}) = Counting()^size(d) + +@inline MeasureBase.paramnames(d::Distribution) = propertynames(d) +@inline MeasureBase.params(d::Distribution) = NamedTuple{propertynames(d)}(Distributions.params(d)) + +@inline MeasureBase.testvalue(d::Distribution) = testvalue(basemeasure(d)) + + +@inline MeasureBase.basemeasure(d::Distributions.Poisson) = Counting(MeasureBase.BoundedInts(static(0), static(Inf))) +@inline MeasureBase.basemeasure(d::Distributions.Product{<:Any,<:Distributions.Poisson}) = Counting(MeasureBase.BoundedInts(static(0), static(Inf)))^size(d) + + +MeasureBase.∫(f, base::Distribution) = MeasureBase.∫(f, convert(AbstractMeasure, base)) diff --git a/src/product.jl b/src/product.jl new file mode 100644 index 0000000..1c2d0cc --- /dev/null +++ b/src/product.jl @@ -0,0 +1,20 @@ +# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). + +const _StdPowMeasure1 = PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}} +const _UniformProductDist1x{D} = Distributions.Product{Continuous,D,<:AbstractVector{D}} + + +MeasureBase.getdof(d::_UniformProductDist1x) = length(d) + + +function _product_dist_trafo_impl(νs, μs, x) + fwddiff(transport_def).(νs, μs, x) +end + +function MeasureBase.transport_def(ν::_StdPowMeasure1, μ::_UniformProductDist1x, x) + _product_dist_trafo_impl((ν.parent,), μ.v, x) +end + +function MeasureBase.transport_def(ν::_UniformProductDist1x, μ::_StdPowMeasure1, x) + _product_dist_trafo_impl(ν.v, (μ.parent,), x) +end diff --git a/src/reshaped.jl b/src/reshaped.jl new file mode 100644 index 0000000..b366aa5 --- /dev/null +++ b/src/reshaped.jl @@ -0,0 +1,9 @@ +# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). + +MeasureBase.getdof(μ::ReshapedDistribution) = MeasureBase.getdof(μ.dist) + +MeasureBase.transport_origin(μ::ReshapedDistribution) = μ.dist + +MeasureBase.to_origin(ν::ReshapedDistribution, y) = reshape(y, size(ν.dist)) + +MeasureBase.from_origin(ν::ReshapedDistribution, x) = reshape(x, ν.dims) diff --git a/src/standard_dist.jl b/src/standard_dist.jl new file mode 100644 index 0000000..dbe3ccf --- /dev/null +++ b/src/standard_dist.jl @@ -0,0 +1,194 @@ +# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). + +""" + struct StandardDist{D<:Distribution{Univariate,Continuous},N} <: Distributions.Distribution{ArrayLikeVariate{N},Continuous} + +Represents `D()` or a product distribution of `D()` in a dispatchable fashion. + +Constructor: +``` + StandardDist{Uniform}(size...) + StandardDist{Normal}(size...) +``` +""" +struct StandardDist{D<:Distribution{Univariate,Continuous},N,U<:Integer} <: Distributions.Distribution{ArrayLikeVariate{N},Continuous} + _size::NTuple{N,U} +end +export StandardDist + +StandardDist{D}() where {D<:Distribution{Univariate,Continuous}} = StandardDist{D,0,StaticInt{1}}(()) +StandardDist{D}(dims::Vararg{U,N}) where {D<:Distribution{Univariate,Continuous},N,U<:Integer} = StandardDist{D,N,U}((dims...,)) + + +const StandardUnivariateDist{D<:Distribution{Univariate,Continuous},U<:Integer} = StandardDist{D,0,U} +const StandardMultivariteDist{D<:Distribution{Univariate,Continuous},U<:Integer} = StandardDist{D,1,U} + + +function Base.show(io::IO, d::StandardDist{D}) where D + print(io, nameof(typeof(d)), "{", D, "}") + show(io, d._size) +end + + +@inline MeasureBase.transport_def(::MU, μ::MU, x) where {MU<:StandardDist{<:Any,0}} = x + +for (A, B) in [ + (Uniform, StdUniform), + (Exponential, StdExponential), + (Logistic, StdLogistic), + (Normal, StdNormal) +] + @eval begin + @inline MeasureBase.transport_origin(d::StandardDist{$A,0}) = $B() + @inline MeasureBase.transport_origin(d::StandardDist{$A,N}) where N = $B()^size(d) + end +end + +@inline MeasureBase.to_origin(ν::StandardDist, y) = y +@inline MeasureBase.from_origin(ν::StandardDist, x) = x + + +@inline nonstddist(::StandardDist{D,0}) where D = D(Distributions.params(D())...) +@inline function nonstddist(d::StandardDist{D,N}) where {D,N} + nonstd0 = nonstddist(StandardDist{D}()) + reshape(Distributions.product_distribution(fill(nonstd0, length(d))), size(d)) +end + + +(::Type{D}, d::StandardDist{D,0}) where {D<:Distribution{Univariate,Continuous}} = nonstddist(d) + +# TODO: Replace `fill` by `FillArrays.Fill` once Distributions fully supports this: +(::Type{Distributions.Product})(d::StandardDist{D,1}) where D = Distributions.Product(fill(StandardDist{D}(), length(d))) + +Base.convert(::Type{D}, d::StandardDist{D,0}) where {D<:Distribution{Univariate,Continuous}} = D(d) +Base.convert(::Type{Distributions.Product}, d::StandardDist{D,1}) where D = Distributions.Product(d) + + + +@inline Base.size(d::StandardDist) = d._size +@inline Base.length(d::StandardDist) = prod(size(d)) + +Base.eltype(::Type{StandardDist{D,N}}) where {D,N} = Float64 + +@inline Distributions.partype(d::StandardDist{D}) where D = Float64 + +@inline StatsBase.params(d::StandardDist) = () + +for f in ( + :(Base.minimum), + :(Base.maximum), + :(Statistics.mean), + :(Statistics.median), + :(StatsBase.mode), + :(Statistics.var), + :(Statistics.std), + :(StatsBase.skewness), + :(StatsBase.kurtosis), + :(Distributions.location), + :(Distributions.scale), +) + @eval begin + ($f)(d::StandardDist{D,0}) where D = ($f)(nonstddist(d)) + ($f)(d::StandardDist{D,N}) where {D,N} = Fill(($f)(StandardDist{D}()), size(d)...) + end +end + +StatsBase.modes(d::StandardDist) = [StatsBase.mode(d)] + +# ToDo: Define cov for N!=1? +Statistics.cov(d::StandardDist{D,1}) where D = Diagonal(Statistics.var(d)) +Distributions.invcov(d::StandardDist{D,1}) where D = Diagonal(Fill(inv(Statistics.var(StandardDist{D}())), length(d))) +Distributions.logdetcov(d::StandardDist{D,1}) where D = length(d) * log(Statistics.var(StandardDist{D}())) + +StatsBase.entropy(d::StandardDist{D,0}) where D = StatsBase.entropy(nonstddist(d)) +StatsBase.entropy(d::StandardDist{D,N}) where {D,N} = length(d) * StatsBase.entropy(StandardDist{D}()) + + +Distributions.insupport(d::StandardDist{D,0}, x::Real) where D = Distributions.insupport(nonstddist(d), x) + +function Distributions.insupport(d::StandardDist{D,N}, x::AbstractArray{<:Real,N}) where {D,N} + all(Base.Fix1(Distributions.insupport, StandardDist{D}()), checked_arg(d, x)) +end + + +@inline Distributions.logpdf(d::StandardDist{D,0}, x::U) where {D,U} = Distributions.logpdf(nonstddist(d), x) + +function Distributions.logpdf(d::StandardDist{D,N}, x::AbstractArray{<:Real,N}) where {D,N} + Distributions._logpdf(d, checked_arg(d, x)) +end + +function Distributions._logpdf(::StandardDist{D,1}, x::AbstractArray{<:Real,1}) where D + sum(Base.Fix1(Distributions.logpdf, StandardDist{D}()), x) +end + +function Distributions._logpdf(::StandardDist{D,2}, x::AbstractArray{<:Real,2}) where D + sum(Base.Fix1(Distributions.logpdf, StandardDist{D}()), x) +end + +function Distributions._logpdf(::StandardDist{D,N}, x::AbstractArray{<:Real,N}) where {D,N} + sum(Base.Fix1(Distributions.logpdf, StandardDist{D}()), x) +end + + + +Distributions.gradlogpdf(d::StandardDist{D,0}, x::Real) where D = Distributions.gradlogpdf(nonstddist(d), x) + +function Distributions.gradlogpdf(d::StandardDist{D,N}, x::AbstractArray{<:Real,N}) where {D,N} + Distributions.gradlogpdf.(StandardDist{D}(), checked_arg(d, x)) +end + + +#@inline Distributions.pdf(d::StandardDist{D,0}, x::U) where {D,U} = pdf(nonstddist(d), x) + +function Distributions.pdf(d::StandardDist{D,1}, x::AbstractVector{U}) where {D,U<:Real} + Distributions._pdf(d, checked_arg(d, x)) +end + +function Distributions._pdf(d::StandardDist{D,1}, x::AbstractVector{U}) where {D,U<:Real} + exp(Distributions._logpdf(d, x)) +end + +function Distributions.pdf(d::StandardDist{D,2}, x::AbstractMatrix{U}) where {D,U<:Real} + Distributions._pdf(d, checked_arg(d, x)) +end + +function Distributions._pdf(d::StandardDist{D,2}, x::AbstractMatrix{U}) where {D,U<:Real} + exp(Distributions._logpdf(d, x)) +end + +function Distributions.pdf(d::StandardDist{D,N}, x::AbstractArray{U,N}) where {D,N,U<:Real} + Distributions._pdf(d, checked_arg(d, x)) +end + +function Distributions._pdf(d::StandardDist{D,N}, x::AbstractArray{U,N}) where {D,N,U<:Real} + exp(Distributions._logpdf(d, x)) +end + + +for f in ( + :(Distributions.logcdf), + :(Distributions.cdf), + :(Distributions.logccdf), + :(Distributions.ccdf), + :(Distributions.quantile), + :(Distributions.cquantile), + :(Distributions.invlogcdf), + :(Distributions.invlogccdf), + :(Distributions.mgf), + :(Distributions.cf), +) + @eval begin + @inline ($f)(d::StandardDist, x::Real) = ($f)(nonstddist(d), x) + end +end + + +Base.rand(rng::AbstractRNG, d::StandardDist{D,0}) where D = rand(rng, nonstddist(d)) +Random.rand!(rng::AbstractRNG, d::StandardDist{D,0}, x::AbstractArray{<:Real,0}) where D = (x[] = rand(rng, d); return x) +Random.rand!(rng::AbstractRNG, d::StandardDist{D,N}, x::AbstractArray{<:Real,N}) where {D,N} = rand!(rng, StandardDist{D}(), x) + + +Distributions.truncated(d::StandardDist{D,0}, l::Real, u::Real) where D = Distributions.truncated(nonstddist(d), l, u) + +Distributions.product_distribution(dists::AbstractVector{<:StandardDist{D,0}}) where D = StandardDist{D}(size(dists)...) +Distributions.product_distribution(dists::AbstractArray{<:StandardDist{D,0}}) where D = StandardDist{D}(size(dists)...) diff --git a/src/standard_normal.jl b/src/standard_normal.jl new file mode 100644 index 0000000..cc3d981 --- /dev/null +++ b/src/standard_normal.jl @@ -0,0 +1,73 @@ +# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). + +""" + const StandardNormal{N} = StandardDist{Normal,N} + +The univariate standard normal distribution. +""" +const StandardNormal{N} = StandardDist{Normal,N} +export StandardNormal + +Distributions.Normal(d::StandardDist{Normal,0}) = Distributions.Normal() +Base.convert(::Type{Distributions.Normal}, d::StandardDist{Normal,1}) = Distributions.Normal(d) + +Distributions.MvNormal(d::StandardDist{Normal,1}) = MvNormal(PDMats.ScalMat(length(d), 1)) +Base.convert(::Type{Distributions.MvNormal}, d::StandardDist{Normal,1}) = Distributions.MvNormal(d) + +Base.minimum(d::StandardDist{Normal,0}) = -Inf +Base.maximum(d::StandardDist{Normal,0}) = +Inf + +Distributions.insupport(d::StandardDist{Normal,0}, x::Real) = !isnan(x) + +Distributions.location(d::StandardDist{Normal,0}) = Statistics.mean(d) +Distributions.scale(d::StandardDist{Normal,0}) = Statistics.var(d) + +Statistics.mean(d::StandardDist{Normal,0}) = 0 +Statistics.mean(d::StandardDist{Normal,N}) where N = Zeros{Int}(size(d)...) + +StatsBase.median(d::StandardDist{Normal}) = Statistics.mean(d) +StatsBase.mode(d::StandardDist{Normal}) = Statistics.mean(d) + +StatsBase.modes(d::StandardDist{Normal,0}) = Zeros{Int}(1) + +Statistics.var(d::StandardDist{Normal,0}) = 1 +Statistics.var(d::StandardDist{Normal,N}) where N = Ones{Int}(size(d)...) + +StatsBase.std(d::StandardDist{Normal,0}) = 1 +StatsBase.std(d::StandardDist{Normal,N}) where N = Ones{Int}(size(d)...) + +StatsBase.skewness(d::StandardDist{Normal,0}) = 0 +StatsBase.kurtosis(d::StandardDist{Normal,0}) = 0 + +StatsBase.entropy(d::StandardDist{Normal,0}) = muladd(log2π, 1/2, 1/2) + +Distributions.logpdf(d::StandardDist{Normal,0}, x::U) where {U<:Real} = muladd(abs2(x), -U(1)/U(2), -log2π/U(2)) +Distributions.pdf(d::StandardDist{Normal,0}, x::U) where {U<:Real} = invsqrt2π * exp(-abs2(x)/U(2)) + +@inline Distributions.gradlogpdf(d::StandardDist{Normal,0}, x::Real) = -x + +@inline Distributions.logcdf(d::StandardDist{Normal,0}, x::Real) = StatsFuns.normlogcdf(x) +@inline Distributions.cdf(d::StandardDist{Normal,0}, x::Real) = StatsFuns.normcdf(x) +@inline Distributions.logccdf(d::StandardDist{Normal,0}, x::Real) = StatsFuns.normlogccdf(x) +@inline Distributions.ccdf(d::StandardDist{Normal,0}, x::Real) = StatsFuns.normccdf(x) +@inline Distributions.quantile(d::StandardDist{Normal,0}, p::Real) = StatsFuns.norminvcdf(p) +@inline Distributions.cquantile(d::StandardDist{Normal,0}, p::Real) = StatsFuns.norminvccdf(p) +@inline Distributions.invlogcdf(d::StandardDist{Normal,0}, p::Real) = StatsFuns.norminvlogcdf(p) +@inline Distributions.invlogccdf(d::StandardDist{Normal,0}, p::Real) = StatsFuns.norminvlogccdf(p) + +Base.rand(rng::AbstractRNG, d::StandardDist{Normal,0}) = randn(rng) +Base.rand(rng::AbstractRNG, d::StandardDist{Normal,N}) where N = randn(rng, size(d)...) +Random.rand!(rng::AbstractRNG, d::StandardDist{Normal,N}, x::AbstractArray{<:Real,N}) where {D,N} = Random.randn!(rng, x) + +Distributions.invcov(d::StandardDist{Normal,1}) = Distributions.cov(d) +Distributions.logdetcov(d::StandardDist{Normal,1}) = 0 + + +function Distributions.sqmahal(d::StandardDist{Normal,N}, x::AbstractArray{<:Real,N}) where N + dot(x, checked_arg(d, x)) +end + +function Distributions. sqmahal!(r::AbstractVector, d::StandardDist{Normal,N}, x::AbstractMatrix) where N + x_cols = eachcol(checked_arg(d, first(eachcol(x)))) + r .= dot.(x_cols, x_cols) +end diff --git a/src/standard_uniform.jl b/src/standard_uniform.jl new file mode 100644 index 0000000..dbde74f --- /dev/null +++ b/src/standard_uniform.jl @@ -0,0 +1,75 @@ +# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). + +""" + const StandardUniform{N} = StandardDist{Uniform,N} + +The univariate standard uniform distribution. +""" +const StandardUniform{N} = StandardDist{Uniform,N} +export StandardUniform + +Distributions.Uniform(d::StandardDist{Uniform,0}) = Distributions.Uniform() +Base.convert(::Type{Distributions.Uniform}, d::StandardDist{Uniform,1}) = Distributions.Uniform(d) + +Base.minimum(::StandardDist{Uniform,0}) = 0 +Base.maximum(::StandardDist{Uniform,0}) = 1 + +Distributions.location(::StandardDist{Uniform,0}) = 0 +Distributions.scale(::StandardDist{Uniform,0}) = 1 + +Statistics.mean(d::StandardDist{Uniform,0}) = 1//2 +StatsBase.median(d::StandardDist{Uniform,0}) = Statistics.mean(d) +StatsBase.mode(d::StandardDist{Uniform,0}) = Statistics.mean(d) +StatsBase.modes(d::StandardDist{Uniform,0}) = Zeros{Int}(0) +StatsBase.modes(d::StandardDist{Uniform,N}) where N = Fill(Zeros{Int}(size(d))) + +Statistics.var(d::StandardDist{Uniform,0}) = 1//12 +StatsBase.std(d::StandardDist{Uniform,0}) = sqrt(Statistics.var(d)) +StatsBase.skewness(d::StandardDist{Uniform,0}) = 0 +StatsBase.kurtosis(d::StandardDist{Uniform,0}) = -6//5 + +StatsBase.entropy(d::StandardDist{Uniform,0}) = 0 + + +function Distributions.logpdf(d::StandardDist{Uniform,0}, x::U) where {U<:Real} + ifelse(Distributions.insupport(d, x), U(0), U(-Inf)) +end + +function Distributions.pdf(d::StandardDist{Uniform,0}, x::U) where {U<:Real} + ifelse(Distributions.insupport(d, x), one(U), zero(U)) +end + + +Distributions.logcdf(d::StandardDist{Uniform,0}, x::U) where {U<:Real} = log(Distributions.cdf(d, x)) + +function Distributions.cdf(d::StandardDist{Uniform,0}, x::U) where {U<:Real} + ifelse(x < zero(U), zero(U), ifelse(x < one(U), x, one(U))) +end + +Distributions.logccdf(d::StandardDist{Uniform,0}, x::U) where {U<:Real} = log(Distributions.ccdf(d, x)) + +Distributions.ccdf(d::StandardDist{Uniform,0}, x::U) where {U<:Real} = one(x) - Distributions.cdf(d, x) + + +function Distributions.quantile(d::StandardDist{Uniform,0}, p::U) where {U<:Real} + convert(float(U), p) +end + +function Distributions.cquantile(d::StandardDist{Uniform,0}, p::U) where {U<:Real} + y = Distributions.quantile(d, p) + one(y) - y +end + + +Distributions.mgf(d::StandardDist{Uniform,0}, t::Real) = Distributions.mgf(nonstddist(d), t) +Distributions.cf(d::StandardDist{Uniform,0}, t::Real) = Distributions.cf(nonstddist(d), t) + +Distributions.gradlogpdf(d::StandardDist{Uniform,0}, x::Real) = zero(x) + +function Distributions.gradlogpdf(d::StandardDist{Uniform,N}, x::AbstractArray{<:Real,N}) where N + zero(checked_arg(d, x)) +end + +Base.rand(rng::AbstractRNG, d::StandardDist{Uniform,0}) = rand(rng) +Base.rand(rng::AbstractRNG, d::StandardDist{Uniform,N}) where N = rand(rng, size(d)...) +Random.rand!(rng::AbstractRNG, d::StandardDist{Uniform,N}, x::AbstractArray{<:Real,N}) where {D,N} = rand!(rng, x) diff --git a/src/standardmv.jl b/src/standardmv.jl new file mode 100644 index 0000000..47634a6 --- /dev/null +++ b/src/standardmv.jl @@ -0,0 +1,18 @@ +# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). + + +MeasureBase.getdof(ν::MvNormal) = length(ν) + +MeasureBase.transport_origin(ν::MvNormal) = StandardDist{Normal}(length(ν)) + +function MeasureBase.from_origin(ν::MvNormal, x) + A = cholesky(ν.Σ).L + b = ν.μ + muladd(A, x, b) +end + +function MeasureBase.to_origin(ν::MvNormal, y) + A = cholesky(ν.Σ).L + b = ν.μ + A \ (y - b) +end diff --git a/src/stdnormal_measure.jl b/src/stdnormal_measure.jl new file mode 100644 index 0000000..287b20b --- /dev/null +++ b/src/stdnormal_measure.jl @@ -0,0 +1,18 @@ +struct StdNormal <: MeasureBase.StdMeasure end + +export StdNormal + +@inline MeasureBase.insupport(d::StdNormal, x) = true + +@inline MeasureBase.logdensity_def(::StdNormal, x) = -x^2 / 2 +@inline MeasureBase.basemeasure(::StdNormal) = WeightedMeasure(static(-0.5 * log2π), Lebesgue(ℝ)) + +@inline MeasureBase.getdof(::StdNormal) = static(1) + +@inline MeasureBase.transport_def(::StdUniform, μ::StdNormal, x) = StatsFuns.normcdf(x) +@inline MeasureBase.transport_def(::StdNormal, μ::StdUniform, x) = StatsFuns.norminvcdf(x) + +@inline Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdNormal) where {T} = randn(rng, T) + + +@inline MeasureBase.StdMeasure(::typeof(randn)) = StdNormal() diff --git a/src/univariate.jl b/src/univariate.jl new file mode 100644 index 0000000..55a255c --- /dev/null +++ b/src/univariate.jl @@ -0,0 +1,176 @@ +# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). + + +@inline MeasureBase.getdof(::Distribution{Univariate}) = static(1) + +@inline MeasureBase.check_dof(a::Distribution{Univariate}, b::Distribution{Univariate}) = nothing + + +# Use ForwardDiff for univariate transformations: +@inline function ChainRulesCore.rrule(::typeof(transport_def), ν::Distribution{Univariate}, μ::Distribution{Univariate}, x::Any) + ChainRulesCore.rrule(fwddiff(transport_def), ν, μ, x) +end +@inline function ChainRulesCore.rrule(::typeof(transport_def), ν::MeasureBase.StdMeasure, μ::Distribution{Univariate}, x::Any) + ChainRulesCore.rrule(fwddiff(transport_def), ν, μ, x) +end +@inline function ChainRulesCore.rrule(::typeof(transport_def), ν::Distribution{Univariate}, μ::MeasureBase.StdMeasure, x::Any) + ChainRulesCore.rrule(fwddiff(transport_def), ν, μ, x) +end + + +# Generic transformations to/from StdUniform via cdf/quantile: + + +_dist_params_numtype(d::Distribution) = promote_type(map(typeof, Distributions.params(d))...) + + +@inline _trafo_cdf(d::Distribution{Univariate,Continuous}, x::Real) = _trafo_cdf_impl(_dist_params_numtype(d), d, x) + +@inline _trafo_cdf_impl(::Type{<:Real}, d::Distribution{Univariate,Continuous}, x::Real) = Distributions.cdf(d, x) + +@inline function _trafo_cdf_impl(::Type{<:Union{Integer,AbstractFloat}}, d::Distribution{Univariate,Continuous}, x::ForwardDiff.Dual{TAG}) where {N,TAG} + x_v = ForwardDiff.value(x) + u = Distributions.cdf(d, x_v) + dudx = Distributions.pdf(d, x_v) + ForwardDiff.Dual{TAG}(u, dudx * ForwardDiff.partials(x)) +end + + +@inline _trafo_quantile(d::Distribution{Univariate,Continuous}, u::Real) = _trafo_quantile_impl(_dist_params_numtype(d), d, u) + +@inline _trafo_quantile_impl(::Type{<:Real}, d::Distribution{Univariate,Continuous}, u::Real) = _trafo_quantile_impl_generic(d, u) + +@inline function _trafo_quantile_impl(::Type{<:Union{Integer,AbstractFloat}}, d::Distribution{Univariate,Continuous}, u::ForwardDiff.Dual{TAG}) where {TAG} + x = _trafo_quantile_impl_generic(d, ForwardDiff.value(u)) + dxdu = inv(Distributions.pdf(d, x)) + ForwardDiff.Dual{TAG}(x, dxdu * ForwardDiff.partials(u)) +end + + +@inline _trafo_quantile_impl_generic(d::Distribution{Univariate,Continuous}, u::Real) = Distributions.quantile(d, u) + +# Workaround for Beta dist, ForwardDiff doesn't work for parameters: +@inline _trafo_quantile_impl_generic(d::Beta{T}, u::Real) where {T<:ForwardDiff.Dual} = convert(float(typeof(u)), NaN) +# Workaround for Beta dist, current quantile implementation only supports Float64: +@inline function _trafo_quantile_impl_generic(d::Beta{T}, u::Union{Integer,AbstractFloat}) where {T<:Union{Integer,AbstractFloat}} + Distributions.quantile(d, convert(promote_type(Float64, typeof(u)), u)) +end + +#= +# ToDo: + +# Workaround for rounding errors that can result in quantile values outside of support of Truncated: +@inline function _trafo_quantile_impl_generic(d::Truncated{<:Distribution{Univariate,Continuous}}, u::Real) + x = Distributions.quantile(d, u) + T = typeof(x) + min_x = T(minimum(d)) + max_x = T(maximum(d)) + if x < min_x && isapprox(x, min_x, atol = 4 * eps(T)) + min_x + elseif x > max_x && isapprox(x, max_x, atol = 4 * eps(T)) + max_x + else + x + end +end + +# Workaround for rounding errors that can result in quantile values outside of support of Truncated: +@inline function _trafo_quantile_impl_generic(d::Truncated{<:Distribution{Univariate,Continuous}}, u::Real) + x = Distributions.quantile(d, u) + T = typeof(x) + min_x = T(minimum(d)) + max_x = T(maximum(d)) + if x < min_x && isapprox(x, min_x, atol = 4 * eps(T)) + min_x + elseif x > max_x && isapprox(x, max_x, atol = 4 * eps(T)) + max_x + else + x + end +end +=# + + +@inline function _result_numtype(d::Distribution{Univariate}, x::T) where {T<:Real} + float(promote_type(T, eltype(Distributions.params(d)))) + # firsttype(first(typeof(x), promote_type(map(eltype, Distributions.params(d))...))) +end + + +@inline function MeasureBase.transport_def(::StdUniform, μ::Distribution{Univariate,Continuous}, x) + R = _result_numtype(μ, x) + if Distributions.insupport(μ, x) + y = _trafo_cdf(μ, x) + convert(R, y) + else + convert(R, NaN) + end +end + + +@inline function MeasureBase.transport_def(ν::Distribution{Univariate,Continuous}, ::StdUniform, x::T) where T + R = _result_numtype(ν, x) + TF = float(T) + if 0 <= x <= 1 + # Avoid x ≈ 0 and x ≈ 1 to avoid infinite variate values for target distributions with infinite support: + mod_x = ifelse(x == 0, zero(TF) + eps(TF), ifelse(x == 1, one(TF) - eps(TF), convert(TF, x))) + y = _trafo_quantile(ν, mod_x) + convert(R, y) + else + convert(R, NaN) + end +end + + +# Use standard measures as transformation origin for scaled/translated equivalents: + +function _origin_to_affine(ν::Distribution{Univariate}, y::T) where {T<:Real} + trg_offs, trg_scale = Distributions.location(ν), Distributions.scale(ν) + x = muladd(y, trg_scale, trg_offs) + convert(_result_numtype(ν, y), x) +end + +function _affine_to_origin(μ::Distribution{Univariate}, x::T) where {T<:Real} + src_offs, src_scale = Distributions.location(μ), Distributions.scale(μ) + y = (x - src_offs) / src_scale + convert(_result_numtype(μ, x), y) +end + +for (A, B) in [ + (Uniform, StdUniform), + (Logistic, StdLogistic), + (Normal, StdNormal) +] + @eval begin + @inline MeasureBase.transport_origin(::$A) = $B() + @inline MeasureBase.to_origin(ν::$A, y) = _affine_to_origin(ν, y) + @inline MeasureBase.from_origin(ν::$A, x) = _origin_to_affine(ν, x) + end +end + +@inline MeasureBase.transport_origin(::Exponential) = StdExponential() +@inline MeasureBase.to_origin(ν::Exponential, y) = Distributions.scale(ν) \ y +@inline MeasureBase.from_origin(ν::Exponential, x) = Distributions.scale(ν) * x + + + +# Transform between univariate and single-element power measure + +function MeasureBase.transport_def(ν::Distribution{Univariate}, μ::PowerMeasure{<:StdMeasure}, x) + return transport_def(ν, μ.parent, only(x)) +end + +function MeasureBase.transport_def(ν::PowerMeasure{<:StdMeasure}, μ::Distribution{Univariate}, x) + return Fill(transport_def(ν.parent, μ, only(x)), map(length, ν.axes)...) +end + + +# Transform between univariate and single-element standard multivariate + +function MeasureBase.transport_def(ν::Distribution{Univariate}, μ::StandardDist{D,1}, x) where D + return transport_def(ν, StandardDist{D}(), only(x)) +end + +function MeasureBase.transport_def(ν::StandardDist{D,1}, μ::Distribution{Univariate}, x) where D + return Fill(transport_def(StandardDist{D}(), μ, only(x)), size(ν)...) +end diff --git a/src/utils.jl b/src/utils.jl index 8f6ee58..e3f2756 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,5 +1,32 @@ -@inline _convert_numtype(::Type{T}, x::T) where {T<:Real} = x -@inline _convert_numtype(::Type{T}, x::AbstractArray{T}) where {T<:Real} = x -@inline _convert_numtype(::Type{T}, x::U) where {T<:Real,U<:Real} = T(X) -_convert_numtype(::Type{T}, x::AbstractArray{U}) where {T<:Real,U<:Real} = T.(x) -_convert_numtype(::Type{T}, x) where {T<:Real} = fmap(elem -> _convert_numtype(T, elem), x) +# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). + + +""" + convert_realtype(::Type{T}, x) where {T<:Real} + +Convert x to use `T` as it's underlying type for real numbers. +""" +function convert_realtype end + +_convert_realtype_pullback(ΔΩ) = NoTangent(), NoTangent, ΔΩ +ChainRulesCore.rrule(::typeof(convert_realtype), ::Type{T}, x) where T = convert_realtype(T, x), _convert_realtype_pullback + +@inline convert_realtype(::Type{T}, x::T) where {T<:Real} = x +@inline convert_realtype(::Type{T}, x::AbstractArray{T}) where {T<:Real} = x +@inline convert_realtype(::Type{T}, x::U) where {T<:Real,U<:Real} = T(x) +convert_realtype(::Type{T}, x::AbstractArray{U}) where {T<:Real,U<:Real} = T.(x) +convert_realtype(::Type{T}, x) where {T<:Real} = fmap(elem -> convert_realtype(T, elem), x) + + +""" + DistributionMeasures.firsttype(::Type{T}, ::Type{U}) where {T<:Real,U<:Real} + +Return the first type, but as a dual number type if the second one is dual. + +If `U <: ForwardDiff.Dual{tag,<:Real,N}`, returns `ForwardDiff.Dual{tag,T,N}`, +otherwise returns `T` +""" +function firsttype end + +firsttype(::Type{T}, ::Type{U}) where {T<:Real,U<:Real} = T +firsttype(::Type{T}, ::Type{<:ForwardDiff.Dual{tag,<:Real,N}}) where {T<:Real,tag,N} = ForwardDiff.Dual{tag,T,N} diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..aacf49d --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,16 @@ +[deps] +ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MeasureBase = "fa1605e6-acd5-459c-a1e6-7e635759db14" +PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/getjacobian.jl b/test/getjacobian.jl new file mode 100644 index 0000000..87de7b8 --- /dev/null +++ b/test/getjacobian.jl @@ -0,0 +1,34 @@ +# This file is a part of ChangesOfVariables.jl, licensed under the MIT License (MIT). + +import ForwardDiff + +torv_and_back(V::AbstractVector{<:Real}) = V, identity +torv_and_back(x::Real) = [x], V -> V[1] +torv_and_back(x::Complex) = [real(x), imag(x)], V -> Complex(V[1], V[2]) +torv_and_back(x::NTuple{N}) where N = [x...], V -> ntuple(i -> V[i], Val(N)) + +function torv_and_back(x::Ref) + xval = x[] + V, to_xval = torv_and_back(xval) + back_to_ref(V) = Ref(to_xval(V)) + return (V, back_to_ref) +end + +torv_and_back(A::AbstractArray{<:Real}) = vec(A), V -> reshape(V, size(A)) + +function torv_and_back(A::AbstractArray{Complex{T}, N}) where {T<:Real, N} + RA = cat(real.(A), imag.(A), dims = N+1) + V, to_array = torv_and_back(RA) + function back_to_complex(V) + RA = to_array(V) + Complex.(view(RA, map(_ -> :, size(A))..., 1), view(RA, map(_ -> :, size(A))..., 2)) + end + return (V, back_to_complex) +end + + +function getjacobian(f, x) + V, to_x = torv_and_back(x) + vf(V) = torv_and_back(f(to_x(V)))[1] + ForwardDiff.jacobian(vf, V) +end diff --git a/test/runtests.jl b/test/runtests.jl index 2fd851e..e6b14ba 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,5 +2,11 @@ using DistributionMeasures using Test @testset "DistributionMeasures.jl" begin - # Write your tests here. + include("test_autodiff_utils.jl") + include("test_measure_interface.jl") + include("test_distribution_measure.jl") + include("test_standard_dist.jl") + include("test_standard_uniform.jl") + include("test_standard_normal.jl") + include("test_transport.jl") end diff --git a/test/test_autodiff_utils.jl b/test/test_autodiff_utils.jl new file mode 100644 index 0000000..af1e49f --- /dev/null +++ b/test/test_autodiff_utils.jl @@ -0,0 +1,19 @@ +# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). + +using DistributionMeasures +using Test + +using LinearAlgebra +using Distributions, ArraysOfArrays +import ForwardDiff, Zygote + + +@testset "trafo_utils" begin + xs = rand(5) + @test Zygote.jacobian(DistributionMeasures._pushfront, xs, 42)[1] ≈ ForwardDiff.jacobian(xs -> DistributionMeasures._pushfront(xs, 1), xs) + @test Zygote.jacobian(DistributionMeasures._pushfront, xs, 42)[2] ≈ vec(ForwardDiff.jacobian(x -> DistributionMeasures._pushfront(xs, x[1]), [42])) + @test Zygote.jacobian(DistributionMeasures._pushback, xs, 42)[1] ≈ ForwardDiff.jacobian(xs -> DistributionMeasures._pushback(xs, 1), xs) + @test Zygote.jacobian(DistributionMeasures._pushback, xs, 42)[2] ≈ vec(ForwardDiff.jacobian(x -> DistributionMeasures._pushback(xs, x[1]), [42])) + @test Zygote.jacobian(DistributionMeasures._rev_cumsum, xs)[1] ≈ ForwardDiff.jacobian(DistributionMeasures._rev_cumsum, xs) + @test Zygote.jacobian(DistributionMeasures._exp_cumsum_log, xs)[1] ≈ ForwardDiff.jacobian(DistributionMeasures._exp_cumsum_log, xs) ≈ ForwardDiff.jacobian(cumprod, xs) +end diff --git a/test/test_distribution_measure.jl b/test/test_distribution_measure.jl new file mode 100644 index 0000000..b797624 --- /dev/null +++ b/test/test_distribution_measure.jl @@ -0,0 +1,54 @@ +# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). + +using DistributionMeasures +using Test + +import Distributions +using Distributions: Distribution +import MeasureBase +using MeasureBase: AbstractMeasure + +@testset "Measure interface" begin + d = Distributions.Weibull() + @test @inferred(AbstractMeasure(d)) isa AbstractMeasure + @test @inferred(AbstractMeasure(d)) isa DistributionMeasure + @test @inferred(convert(AbstractMeasure, d)) isa AbstractMeasure + @test @inferred(convert(AbstractMeasure, d)) isa DistributionMeasure + @test @inferred(Distribution(AbstractMeasure(d))) === d + @test @inferred(convert(Distribution, convert(AbstractMeasure, d))) === d + + + c0 = AbstractMeasure(Distributions.Weibull(0.7, 1.3)) + c1 = AbstractMeasure(Distributions.MvNormal([0.7, 0.9], [1.4 0.5; 0.5 1.1])) + + d0 = AbstractMeasure(Distributions.Poisson(0.7)) + d1 = AbstractMeasure(Distributions.product_distribution(Distributions.Poisson.([0.7, 1.4]))) + + for μ in [c0, c1, d0, d1] + d = Distribution(μ) + x = rand(μ) + @test @inferred(MeasureBase.logdensity_def(μ, x)) == Distributions.logpdf(d, x) + @test @inferred(MeasureBase.unsafe_logdensityof(μ, x)) == Distributions.logpdf(d, x) + + MeasureBase.Interface.test_interface(d) + end + + @test @inferred(MeasureBase.basemeasure(c0)) == MeasureBase.Lebesgue(MeasureBase.ℝ) + @test @inferred(MeasureBase.basemeasure(c1)) == MeasureBase.Lebesgue(MeasureBase.ℝ) ^ 2 + + @test @inferred(MeasureBase.insupport(c0, 3)) == true + @test @inferred(MeasureBase.insupport(c0, -3)) == false + @test @inferred(MeasureBase.insupport(c1, [0.1, 0.2])) == true + @test @inferred(MeasureBase.insupport(d0, 3)) == true + @test @inferred(MeasureBase.insupport(d0, 3.2)) == false + @test @inferred(MeasureBase.insupport(d1, [1, 2])) == true + @test @inferred(MeasureBase.insupport(d1, [1.1, 2.2])) == false + + @test MeasureBase.paramnames(c0) == (:α, :θ) + if VERSION >= v"1.8" + @test @inferred(MeasureBase.params(c0)) == (α = 0.7, θ = 1.3) + else + # v1.6 can't type-infer this: + @test (MeasureBase.params(c0)) == (α = 0.7, θ = 1.3) + end +end diff --git a/test/test_measure_interface.jl b/test/test_measure_interface.jl new file mode 100644 index 0000000..e7488a8 --- /dev/null +++ b/test/test_measure_interface.jl @@ -0,0 +1,44 @@ +# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). + +using DistributionMeasures +using Test + +import Distributions +import MeasureBase + +@testset "Measure interface" begin + c0 = Distributions.Weibull(0.7, 1.3) + c1 = Distributions.MvNormal([0.7, 0.9], [1.4 0.5; 0.5 1.1]) + + d0 = Distributions.Poisson(0.7) + d1 = Distributions.product_distribution(Distributions.Poisson.([0.7, 1.4])) + + for d in [c0, c1, d0, d1] + x = rand(d) + @test @inferred(MeasureBase.logdensity_def(d, x)) == Distributions.logpdf(d, x) + @test @inferred(MeasureBase.unsafe_logdensityof(d, x)) == Distributions.logpdf(d, x) + + MeasureBase.Interface.test_interface(d) + end + + @test @inferred(MeasureBase.basemeasure(c0)) == MeasureBase.Lebesgue(MeasureBase.ℝ) + @test @inferred(MeasureBase.basemeasure(c1)) == MeasureBase.Lebesgue(MeasureBase.ℝ) ^ 2 + + @test @inferred(MeasureBase.insupport(c0, 3)) == true + @test @inferred(MeasureBase.insupport(c0, -3)) == false + @test @inferred(MeasureBase.insupport(c1, [0.1, 0.2])) == true + @test @inferred(MeasureBase.insupport(d0, 3)) == true + @test @inferred(MeasureBase.insupport(d0, 3.2)) == false + @test @inferred(MeasureBase.insupport(d1, [1, 2])) == true + @test @inferred(MeasureBase.insupport(d1, [1.1, 2.2])) == false + + @test MeasureBase.paramnames(c0) == (:α, :θ) + if VERSION >= v"1.8" + @test @inferred(MeasureBase.params(c0)) == (α = 0.7, θ = 1.3) + else + # v1.6 can't type-infer this: + @test (MeasureBase.params(c0)) == (α = 0.7, θ = 1.3) + end + + @test MeasureBase.∫(x -> Distributions.Normal(x, 0), Distributions.Normal()) isa MeasureBase.DensityMeasure +end diff --git a/test/test_standard_dist.jl b/test/test_standard_dist.jl new file mode 100644 index 0000000..bd6773c --- /dev/null +++ b/test/test_standard_dist.jl @@ -0,0 +1,129 @@ +# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). + +using DistributionMeasures +using Test + +using Random, Statistics, LinearAlgebra +using Distributions, PDMats +using StableRNGs +import ForwardDiff, ChainRulesTestUtils + + +@testset "standard_dist" begin + stblrng() = StableRNG(789990641) + + for (D, sz, dref) in [ + (Uniform, (), Uniform()), + (Uniform, (5,), product_distribution(fill(Uniform(0.0, 1.0), 5))), + (Uniform, (2, 3), reshape(product_distribution(fill(Uniform(0.0, 1.0), 6)), 2, 3)), + (Normal, (), Normal()), + (Normal, (), Normal(0., 1.0)), + (Normal, (5,), MvNormal(Diagonal(fill(1.0, 5)))), + (Normal, (2, 3), reshape(MvNormal(Diagonal(fill(1.0, 6))), 2, 3)), + (Exponential, (), Exponential()), + (Exponential, (5,), product_distribution(fill(Exponential(1.0), 5))), + (Exponential, (2, 3), reshape(product_distribution(fill(Exponential(1.0), 6)), 2, 3)), + ] + @testset "StandardDist{$D}($(join(sz,",")))" begin + N = length(sz) + + @test @inferred(StandardDist{D}(sz...)) isa StandardDist{D} + @test @inferred(StandardDist{D}(sz...)) isa StandardDist{D} + @test @inferred(size(StandardDist{D}(sz...))) == size(dref) + @test @inferred(size(StandardDist{D}(sz...))) == size(dref) + + d = StandardDist{D}(sz...) + + if size(d) == () + @test @inferred(DistributionMeasures.nonstddist(d)) == dref + end + + @test @inferred(length(d)) == length(dref) + @test @inferred(size(d)) == size(dref) + + @test @inferred(eltype(typeof(d))) == eltype(typeof(dref)) + @test @inferred(eltype(d)) == eltype(dref) + + @test @inferred(Distributions.params(d)) == () + @test @inferred(partype(d)) == partype(dref) + + for f in [minimum, maximum, mean, median, mode, modes, var, std, skewness, kurtosis, location, scale, entropy] + supported_by_dref = try f(dref); true catch MethodError; false; end + if supported_by_dref + @test @inferred(f(d)) ≈ f(dref) + end + end + + for x in [rand(dref) for i in 1:10] + ref_gradlogpdf = try + gradlogpdf(dref, x) + catch MethodError + ForwardDiff.gradient(x -> logpdf(dref, x), x) + end + @test @inferred(gradlogpdf(d, x)) ≈ ref_gradlogpdf + @test @inferred(logpdf(d, x)) ≈ logpdf(dref, x) + @test @inferred(pdf(d, x)) ≈ pdf(dref, x) + end + + if size(d) == () + for x in [minimum(dref), quantile(dref, 1//3), quantile(dref, 1//2), quantile(dref, 2//3), maximum(dref)] + for f in [logpdf, pdf, gradlogpdf, logcdf, cdf, logccdf, ccdf] + @test @inferred(f(d, x)) ≈ f(dref, x) + end + end + + for x in [0, 1//3, 1//2, 2//3, 1] + for f in [quantile, cquantile] + @test @inferred(f(d, x)) ≈ f(dref, x) + end + end + + for x in log.([0, 1//3, 1//2, 2//3, 1]) + for f in [invlogcdf, invlogccdf] + @test @inferred(f(d, x)) ≈ f(dref, x) + end + end + + for p in [0.0, 0.25, 0.75, 1.0] + @test @inferred(quantile(d, p)) == quantile(dref, p) + @test @inferred(cquantile(d, p)) == cquantile(dref, p) + end + + for t in [-3, 0, 3] + @test isapprox(@inferred(mgf(d, t)), mgf(dref, t), rtol = 1e-5) + @test isapprox(@inferred(cf(d, t)), cf(dref, t), rtol = 1e-5) + end + + @test @inferred(truncated(d, quantile(dref, 1//3), quantile(dref, 2//3))) == truncated(dref, quantile(dref, 1//3), quantile(dref, 2//3)) + + @test @inferred(product_distribution(fill(d, 3))) == StandardDist{typeof(d)}(3) + @test @inferred(product_distribution(fill(d, 3, 4))) == StandardDist{typeof(d)}(3, 4) + end + + if length(size(d)) == 1 + @test @inferred(convert(Distributions.Product, d)) isa Distributions.Product + d_as_prod = convert(Distributions.Product, d) + @test d_as_prod.v == fill(StandardDist{D}(), size(d)...) + end + + @test @inferred(rand(stblrng(), d)) == rand(stblrng(), d) + @test @inferred(rand(stblrng(), d, 5)) == rand(stblrng(), d, 5) + + @test @inferred(rand(stblrng(), d)) == rand(stblrng(), dref) + @test @inferred(rand(stblrng(), d, 5)) == rand(stblrng(), dref, 5) + @test @inferred(rand!(stblrng(), d, zeros(size(d)...))) == rand!(stblrng(), dref, zeros(size(dref)...)) + if length(size(d)) == 1 + @test @inferred(rand!(stblrng(), d, zeros(size(d)..., 5))) == rand!(stblrng(), dref, zeros(size(dref)..., 5)) + end + end + end + + @testset "StandardDist{Normal}()" begin + # TODO: Add @inferred + d = StandardDist{Normal}(4) + d_uv = StandardDist{Normal}() + dref = MvNormal(Diagonal(fill(1.0, 4))) + @test (MvNormal(d)) == dref + @test (Base.convert(MvNormal, d)) == dref + end +end diff --git a/test/test_standard_normal.jl b/test/test_standard_normal.jl new file mode 100644 index 0000000..8e21db6 --- /dev/null +++ b/test/test_standard_normal.jl @@ -0,0 +1,130 @@ +# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). + +using DistributionMeasures +using Test + +using Random, Statistics, LinearAlgebra +using Distributions, PDMats +using StableRNGs + + +@testset "StandardDist{Normal}" begin + stblrng() = StableRNG(789990641) + + @testset "StandardDist{Normal,0}" begin + @test @inferred(Normal(StandardDist{Normal}())) isa Normal{Float64} + @test @inferred(Normal(StandardDist{Normal}())) == Normal() + @test @inferred(convert(Normal, StandardDist{Normal}())) == Normal() + + d = StandardDist{Normal}() + dref = Normal() + + @test @inferred(minimum(d)) == minimum(dref) + @test @inferred(maximum(d)) == maximum(dref) + + @test @inferred(Distributions.params(d)) == () + @test @inferred(partype(d)) == partype(dref) + + @test @inferred(location(d)) == location(dref) + @test @inferred(scale(d)) == scale(dref) + + @test @inferred(eltype(typeof(d))) == eltype(typeof(dref)) + @test @inferred(eltype(d)) == eltype(dref) + + @test @inferred(length(d)) == length(dref) + @test @inferred(size(d)) == size(dref) + + @test @inferred(mean(d)) == mean(dref) + @test @inferred(median(d)) == median(dref) + @test @inferred(mode(d)) == mode(dref) + @test @inferred(modes(d)) ≈ modes(dref) + + @test @inferred(var(d)) == var(dref) + @test @inferred(std(d)) == std(dref) + @test @inferred(skewness(d)) == skewness(dref) + @test @inferred(kurtosis(d)) == kurtosis(dref) + + @test @inferred(entropy(d)) == entropy(dref) + + for x in [-Inf, -1.3, 0.0, 1.3, +Inf] + @test @inferred(gradlogpdf(d, x)) == gradlogpdf(dref, x) + + @test @inferred(logpdf(d, x)) == logpdf(dref, x) + @test @inferred(pdf(d, x)) == pdf(dref, x) + @test @inferred(logcdf(d, x)) == logcdf(dref, x) + @test @inferred(cdf(d, x)) == cdf(dref, x) + @test @inferred(logccdf(d, x)) == logccdf(dref, x) + @test @inferred(ccdf(d, x)) == ccdf(dref, x) + end + + for p in [0.0, 0.25, 0.75, 1.0] + @test @inferred(quantile(d, p)) == quantile(dref, p) + @test @inferred(cquantile(d, p)) == cquantile(dref, p) + end + + for t in [-3, 0, 3] + @test @inferred(mgf(d, t)) == mgf(dref, t) + @test @inferred(cf(d, t)) == cf(dref, t) + end + + @test @inferred(rand(stblrng(), d)) == rand(stblrng(), dref) + @test @inferred(rand!(stblrng(), d, fill(0.0))) == rand!(stblrng(), dref, fill(0.0)) + @test @inferred(rand(stblrng(), d, 5)) == rand(stblrng(), dref, 5) + + @test @inferred(truncated(StandardDist{Normal}(), -2.2f0, 3.1f0)) isa Truncated{Normal{Float64}} + @test truncated(StandardDist{Normal}(), -2.2f0, 3.1f0) == truncated(Normal(0.0, 1.0), -2.2f0, 3.1f0) + + @test @inferred(product_distribution(fill(StandardDist{Normal}(), 3))) isa StandardDist{Normal,1} + @test product_distribution(fill(StandardDist{Normal}(), 3)) == StandardDist{Normal}(3) + end + + + @testset "StandardDist{Normal,1}" begin + @test @inferred(StandardDist{Normal}(3)) isa StandardDist{Normal,1} + @test @inferred(StandardDist{Normal}(3)) isa StandardDist{Normal,1} + @test @inferred(StandardDist{Normal}(3)) isa StandardDist{Normal,1} + + @test @inferred(MvNormal(StandardDist{Normal}(3))) isa MvNormal{Int} + @test @inferred(MvNormal(StandardDist{Normal}(3))) == MvNormal(ScalMat(3, 1.0)) + @test @inferred(convert(MvNormal, StandardDist{Normal}(3))) == MvNormal(ScalMat(3, 1.0)) + + d = StandardDist{Normal}(3) + dref = MvNormal(ScalMat(3, 1.0)) + + @test @inferred(eltype(typeof(d))) == eltype(typeof(dref)) + @test @inferred(eltype(d)) == eltype(dref) + + @test @inferred(length(d)) == length(dref) + @test @inferred(size(d)) == size(dref) + + @test @inferred(Distributions.params(d)) == () + @test @inferred(partype(d)) == partype(dref) + + @test @inferred(mean(d)) == mean(dref) + @test @inferred(var(d)) == var(dref) + @test @inferred(cov(d)) == cov(dref) + + @test @inferred(mode(d)) == mode(dref) + @test @inferred(modes(d)) == modes(dref) + + @test @inferred(invcov(d)) == invcov(dref) + @test @inferred(logdetcov(d)) == logdetcov(dref) + + @test @inferred(entropy(d)) == entropy(dref) + + for x in fill.([-Inf, -1.3, 0.0, 1.3, +Inf], 3) + # Distributions.insupport is inconsistent at +- Inf between Normal and MvNormal + if !any(isinf, x) + @test @inferred(Distributions.insupport(d, x)) == Distributions.insupport(dref, x) + end + @test @inferred(logpdf(d, x)) == logpdf(dref, x) + @test @inferred(pdf(d, x)) == pdf(dref, x) + @test @inferred(sqmahal(d, x)) == sqmahal(dref, x) + @test @inferred(gradlogpdf(d, x)) == gradlogpdf(dref, x) + end + + @test @inferred(rand(stblrng(), d)) == rand(stblrng(), d) + @test @inferred(rand!(stblrng(), d, zeros(3))) == rand!(stblrng(), d, zeros(3)) + @test @inferred(rand!(stblrng(), d, zeros(3, 10))) == rand!(stblrng(), d, zeros(3, 10)) + end +end diff --git a/test/test_standard_uniform.jl b/test/test_standard_uniform.jl new file mode 100644 index 0000000..f742d57 --- /dev/null +++ b/test/test_standard_uniform.jl @@ -0,0 +1,119 @@ +# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). + +using DistributionMeasures +using Test + +using Random, Statistics, LinearAlgebra +using Distributions, PDMats +using StableRNGs +using FillArrays +using ForwardDiff + + +@testset "StandardDist{Uniform}" begin + stblrng() = StableRNG(789990641) + + @testset "StandardDist{Uniform,0}" begin + @test @inferred(Uniform(StandardDist{Uniform}())) isa Uniform{Float64} + @test @inferred(Uniform(StandardDist{Uniform}())) == Uniform() + @test @inferred(convert(Uniform, StandardDist{Uniform}())) == Uniform() + + d = StandardDist{Uniform}() + dref = Uniform() + + @test @inferred(minimum(d)) == minimum(dref) + @test @inferred(maximum(d)) == maximum(dref) + + @test @inferred(Distributions.params(d)) == () + @test @inferred(partype(d)) == partype(dref) + + @test @inferred(location(d)) == location(dref) + @test @inferred(scale(d)) == scale(dref) + + @test @inferred(eltype(typeof(d))) == eltype(typeof(dref)) + @test @inferred(eltype(d)) == eltype(dref) + + @test @inferred(length(d)) == length(dref) + @test @inferred(size(d)) == size(dref) + + @test @inferred(mean(d)) == mean(dref) + @test @inferred(median(d)) == median(dref) + @test @inferred(mode(d)) == mode(dref) + @test @inferred(modes(d)) ≈ modes(dref) + + @test @inferred(var(d)) ≈ var(dref) + @test @inferred(std(d)) ≈ std(dref) + @test @inferred(skewness(d)) == skewness(dref) + @test @inferred(kurtosis(d)) ≈ kurtosis(dref) + + @test @inferred(entropy(d)) == entropy(dref) + + for x in [-0.5, 0.0, 0.25, 0.75, 1.0, 1.5] + @test @inferred(logpdf(d, x)) == logpdf(dref, x) + @test @inferred(pdf(d, x)) == pdf(dref, x) + @test @inferred(logcdf(d, x)) == logcdf(dref, x) + @test @inferred(cdf(d, x)) == cdf(dref, x) + @test @inferred(logccdf(d, x)) == logccdf(dref, x) + @test @inferred(ccdf(d, x)) == ccdf(dref, x) + end + + for p in [0.0, 0.25, 0.75, 1.0] + @test @inferred(quantile(d, p)) == quantile(dref, p) + @test @inferred(cquantile(d, p)) == cquantile(dref, p) + end + + for t in [-3, 0, 3] + @test @inferred(mgf(d, t)) == mgf(dref, t) + @test @inferred(cf(d, t)) == cf(dref, t) + end + + @test @inferred(rand(stblrng(), d)) == rand(stblrng(), dref) + @test @inferred(rand!(stblrng(), d, fill(0.0))) == rand!(stblrng(), dref, fill(0.0)) + @test @inferred(rand(stblrng(), d, 5)) == rand(stblrng(), dref, 5) + + @test @inferred(truncated(StandardDist{Uniform}(), -0.5f0, 0.7f0)) isa Uniform{Float64} + @test truncated(StandardDist{Uniform}(), -0.5f0, 0.7f0) == Uniform(0.0f0, 0.7f0) + @test truncated(StandardDist{Uniform}(), 0.2f0, 0.7f0) == Uniform(0.2f0, 0.7f0) + + @test @inferred(product_distribution(fill(StandardDist{Uniform}(), 3))) isa DistributionMeasures.StandardDist{Uniform,1} + @test product_distribution(fill(StandardDist{Uniform}(), 3)) == DistributionMeasures.StandardDist{Uniform}(3) + end + + + @testset "StandardDist{Uniform,1}" begin + d = DistributionMeasures.StandardDist{Uniform}(3) + dref = product_distribution(fill(Uniform(), 3)) + + @test @inferred(eltype(typeof(d))) == eltype(typeof(dref)) + @test @inferred(eltype(d)) == eltype(dref) + + @test @inferred(length(d)) == length(dref) + @test @inferred(size(d)) == size(dref) + + @test @inferred(Distributions.params(d)) == () + @test @inferred(partype(d)) == partype(dref) + + @test @inferred(mean(d)) == mean(dref) + @test @inferred(var(d)) ≈ var(dref) + @test @inferred(cov(d)) ≈ cov(dref) + + @test @inferred(mode(d)) == [0.5, 0.5, 0.5] + @test @inferred(modes(d)) == fill([0, 0,0 ]) + + @test @inferred(invcov(d)) == inv(cov(dref)) + @test @inferred(logdetcov(d)) == logdet(cov(dref)) + + @test @inferred(entropy(d)) == entropy(dref) + + for x in fill.([-Inf, -1.3, 0.0, 1.3, +Inf], 3) + @test @inferred(Distributions.insupport(d, x)) == Distributions.insupport(dref, x) + @test @inferred(logpdf(d, x)) == logpdf(dref, x) + @test @inferred(pdf(d, x)) == pdf(dref, x) + @test @inferred(gradlogpdf(d, x)) == ForwardDiff.gradient(x -> logpdf(d, x), x) + end + + @test @inferred(rand(stblrng(), d)) == rand(stblrng(), d) + @test @inferred(rand!(stblrng(), d, zeros(3))) == rand!(stblrng(), d, zeros(3)) + @test @inferred(rand!(stblrng(), d, zeros(3, 10))) == rand!(stblrng(), d, zeros(3, 10)) + end +end diff --git a/test/test_transport.jl b/test/test_transport.jl new file mode 100644 index 0000000..17ea4f6 --- /dev/null +++ b/test/test_transport.jl @@ -0,0 +1,149 @@ +# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). + +using DistributionMeasures +using Test + +using LinearAlgebra +using InverseFunctions, ChangesOfVariables +using Distributions, ArraysOfArrays +import ForwardDiff, Zygote + +using MeasureBase: transport_to, transport_def, transport_origin +using MeasureBase: StdExponential +using DistributionMeasures: _trafo_cdf, _trafo_quantile + +include("getjacobian.jl") + + +@testset "test_distribution_transform" begin + function test_back_and_forth(trg, src) + @testset "transform $(typeof(trg).name) <-> $(typeof(src).name)" begin + x = rand(src) + y = transport_def(trg, src, x) + src_v_reco = transport_def(src, trg, y) + + @test x ≈ src_v_reco + + f = x -> transport_def(trg, src, x) + ref_ladj = logpdf(src, x) - logpdf(trg, y) + @test ref_ladj ≈ logabsdet(getjacobian(f, x))[1] + end + end + + reshaped_rand(d::Distribution{Univariate}, n) = rand(d, n) + reshaped_rand(d::Distribution{Multivariate}, n) = nestedview(rand(d, n)) + + function test_dist_trafo_moments(trg, src) + unshaped(x) = first(torv_and_back(x)) + @testset "check moments of trafo $(typeof(trg).name) <- $(typeof(src).name)" begin + X = reshaped_rand(src, 10^5) + Y = transport_to(trg, src).(X) + Y_ref = reshaped_rand(trg, 10^6) + @test isapprox(mean(unshaped.(Y)), mean(unshaped.(Y_ref)), rtol = 0.5) + @test isapprox(cov(unshaped.(Y)), cov(unshaped.(Y_ref)), rtol = 0.5) + end + end + + @testset "transforms-tests" begin + stduvuni = StandardDist{Uniform}() + stduvnorm = StandardDist{Uniform}() + + uniform1 = Uniform(-5.0, -0.01) + uniform2 = Uniform(0.01, 5.0) + + normal1 = Normal(-10, 1) + normal2 = Normal(10, 5) + + stdmvnorm1 = StandardDist{Normal}(1) + stdmvnorm2 = StandardDist{Normal}(2) + + stdmvuni2 = StandardDist{Uniform}(2) + + standnorm2_reshaped = reshape(stdmvnorm2, 1, 2) + + mvnorm = MvNormal([0.3, -2.9], [1.7 0.5; 0.5 2.3]) + beta = Beta(3,1) + gamma = Gamma(0.1,0.7) + dirich = Dirichlet([0.1,4]) + + test_back_and_forth(stduvuni, stduvuni) + test_back_and_forth(stduvnorm, stduvnorm) + test_back_and_forth(stduvuni, stduvnorm) + test_back_and_forth(stduvnorm, stduvuni) + + test_back_and_forth(stdmvuni2, stdmvuni2) + test_back_and_forth(stdmvnorm2, stdmvnorm2) + test_back_and_forth(stdmvuni2, stdmvnorm2) + test_back_and_forth(stdmvnorm2, stdmvuni2) + + test_back_and_forth(beta, stduvnorm) + test_back_and_forth(gamma, stduvnorm) + test_back_and_forth(gamma, beta) + + test_back_and_forth(mvnorm, stdmvuni2) + test_back_and_forth(stdmvuni2, mvnorm) + + test_back_and_forth(mvnorm, standnorm2_reshaped) + test_back_and_forth(standnorm2_reshaped, mvnorm) + test_back_and_forth(stdmvnorm2, standnorm2_reshaped) + test_back_and_forth(standnorm2_reshaped, standnorm2_reshaped) + + test_dist_trafo_moments(normal2, normal1) + test_dist_trafo_moments(uniform2, uniform1) + + test_dist_trafo_moments(beta, stduvnorm) + test_dist_trafo_moments(gamma, stduvnorm) + + test_dist_trafo_moments(mvnorm, stdmvnorm2) + test_dist_trafo_moments(dirich, stdmvnorm1) + + let + mvuni = product_distribution([Uniform(), Uniform()]) + + x = rand() + @test_throws ArgumentError transport_to(stduvnorm, mvnorm)(x) + @test_throws ArgumentError transport_to(stduvnorm, stdmvnorm1)(x) + @test_throws ArgumentError transport_to(stduvnorm, stdmvnorm2)(x) + + x = rand(2) + @test_throws ArgumentError transport_to(stduvnorm, mvnorm)(x) + @test_throws ArgumentError transport_to(stduvnorm, stdmvnorm1)(x) + @test_throws ArgumentError transport_to(stduvnorm, stdmvnorm2)(x) + end + end + + @testset "Custom cdf and quantile for dual numbers" begin + Dual = ForwardDiff.Dual + + @test _trafo_cdf(Normal(Dual(0, 1, 0, 0), Dual(1, 0, 1, 0)), Dual(0.5, 0, 0, 1)) == cdf(Normal(Dual(0, 1, 0, 0), Dual(1, 0, 1, 0)), Dual(0.5, 0, 0, 1)) + @test _trafo_cdf(Normal(0, 1), Dual(0.5, 1)) == cdf(Normal(0, 1), Dual(0.5, 1)) + + @test _trafo_quantile(Normal(0, 1), Dual(0.5, 1)) == quantile(Normal(0, 1), Dual(0.5, 1)) + @test _trafo_quantile(Normal(Dual(0, 1, 0, 0), Dual(1, 0, 1, 0)), Dual(0.5, 0, 0, 1)) == quantile(Normal(Dual(0, 1, 0, 0), Dual(1, 0, 1, 0)), Dual(0.5, 0, 0, 1)) + end + + @testset "trafo autodiff pullbacks" begin + x = [0.6, 0.7, 0.8, 0.9] + f = transport_to(Dirichlet([3.0, 4.0, 5.0, 6.0, 7.0]), Uniform) + @test isapprox(ForwardDiff.jacobian(f, x), Zygote.jacobian(f, x)[1], rtol = 10^-4) + f = inverse(transport_to(Normal, Dirichlet([3.0, 4.0, 5.0, 6.0, 7.0]))) + @test isapprox(ForwardDiff.jacobian(f, x), Zygote.jacobian(f, x)[1], rtol = 10^-4) + end + + + @testset "transport_to autosel" begin + for (M,R) in [ + (StandardNormal, StandardNormal) + (Normal, StandardNormal) + (StandardUniform, StandardUniform) + (Uniform, StandardUniform) + ] + @test @inferred(transport_to(M, Weibull())) == transport_to(R(), Weibull()) + @test @inferred(transport_to(Weibull(), M)) == transport_to(Weibull(), R()) + @test @inferred(transport_to(M, MvNormal(float(I(5))))) == transport_to(R(5), MvNormal(float(I(5)))) + @test @inferred(transport_to(MvNormal(float(I(5))), M)) == transport_to(MvNormal(float(I(5))), R(5)) + @test @inferred(transport_to(M, StdExponential()^(2,3))) == transport_to(R(6), StdExponential()^(2,3)) + @test @inferred(transport_to(StdExponential()^(2,3), M)) == transport_to(StdExponential()^(2,3), R(6)) + end + end +end