Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
indent = 4
margin = 92
always_for_in = true
whitespace_typedefs = false
whitespace_ops_in_indices = false
remove_extra_newlines = true
import_to_using = false
pipe_to_function_call = false
short_to_long_function_def = true
always_use_return = false
whitespace_in_kwargs = true
annotate_untyped_fields_with_any = false
format_docstrings = false
align_struct_field = true
align_conditional = true
align_assignment = true
align_pair_arrow = true
conditional_to_if = true
normalize_line_endings = "unix"
align_matrix = false
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jobs:
version:
- '1.6'
- '1.7'
- '1.8'
- 'nightly'
os:
- ubuntu-latest
Expand Down
32 changes: 16 additions & 16 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
using DistributionMeasures
using Documenter

DocMeta.setdocmeta!(DistributionMeasures, :DocTestSetup, :(using DistributionMeasures); recursive=true)
DocMeta.setdocmeta!(
DistributionMeasures,
:DocTestSetup,
:(using DistributionMeasures);
recursive = true,
)

makedocs(;
modules=[DistributionMeasures],
authors="Chad Scherrer <[email protected]> and contributors",
repo="https://github.com/cscherrer/DistributionMeasures.jl/blob/{commit}{path}#{line}",
sitename="DistributionMeasures.jl",
format=Documenter.HTML(;
prettyurls=get(ENV, "CI", "false") == "true",
canonical="https://cscherrer.github.io/DistributionMeasures.jl",
assets=String[],
modules = [DistributionMeasures],
authors = "Chad Scherrer <[email protected]> and contributors",
repo = "https://github.com/cscherrer/DistributionMeasures.jl/blob/{commit}{path}#{line}",
sitename = "DistributionMeasures.jl",
format = Documenter.HTML(;
prettyurls = get(ENV, "CI", "false") == "true",
canonical = "https://cscherrer.github.io/DistributionMeasures.jl",
assets = String[],
),
pages=[
"Home" => "index.md",
],
pages = ["Home" => "index.md"],
)

deploydocs(;
repo="github.com/cscherrer/DistributionMeasures.jl",
devbranch="main",
)
deploydocs(; repo = "github.com/cscherrer/DistributionMeasures.jl", devbranch = "main")
9 changes: 3 additions & 6 deletions src/DistributionMeasures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ import Random
using Random: AbstractRNG, rand!

import DensityInterface
using DensityInterface: logdensityof
using DensityInterface: logdensityof, densityof

import MeasureBase
using MeasureBase: AbstractMeasure, Lebesgue, Counting, ℝ
using MeasureBase: StdMeasure, StdUniform, StdExponential, StdLogistic
using MeasureBase: StdMeasure, StdUniform, StdExponential, StdLogistic, StdNormal
using MeasureBase: PowerMeasure, WeightedMeasure
using MeasureBase: basemeasure, testvalue
using MeasureBase: getdof, checked_arg
using MeasureBase: transport_to, transport_def, transport_origin, from_origin, to_origin
using MeasureBase: NoTransformOrigin, NoTransport
using MeasureBase: NoTransportOrigin, NoTransport

import Distributions
using Distributions: Distribution, VariateForm, ValueSupport, ContinuousDistribution
Expand Down Expand Up @@ -55,7 +55,6 @@ export MeasureLike
include("utils.jl")
include("autodiff_utils.jl")
include("measure_interface.jl")
include("stdnormal_measure.jl")
include("standard_dist.jl")
include("standard_uniform.jl")
include("standard_normal.jl")
Expand All @@ -67,9 +66,7 @@ include("product.jl")
include("reshaped.jl")
include("dirichlet.jl")

export StdNormal
export DistributionMeasure
export StandardDist


end # module
21 changes: 9 additions & 12 deletions src/autodiff_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,53 +2,51 @@

@inline _adignore_call(f) = f()
@inline _adignore_call_pullback(@nospecialize ΔΩ) = (NoTangent(), NoTangent())
ChainRulesCore.rrule(::typeof(_adignore_call), f) = _adignore_call(f), _adignore_call_pullback
function ChainRulesCore.rrule(::typeof(_adignore_call), f)
_adignore_call(f), _adignore_call_pullback
end

macro _adignore(expr)
:(_adignore_call(() -> $(esc(expr))))
end


function _pushfront(v::AbstractVector, x)
T = promote_type(eltype(v), typeof(x))
r = similar(v, T, length(eachindex(v)) + 1)
r[firstindex(r)] = x
r[firstindex(r)+1:lastindex(r)] = v
r[(firstindex(r)+1):lastindex(r)] = v
r
end

function ChainRulesCore.rrule(::typeof(_pushfront), v::AbstractVector, x)
result = _pushfront(v, x)
function _pushfront_pullback(thunked_ΔΩ)
ΔΩ = unthunk(thunked_ΔΩ)
(NoTangent(), ΔΩ[firstindex(ΔΩ)+1:lastindex(ΔΩ)], ΔΩ[firstindex(ΔΩ)])
(NoTangent(), ΔΩ[(firstindex(ΔΩ)+1):lastindex(ΔΩ)], ΔΩ[firstindex(ΔΩ)])
end
return result, _pushfront_pullback
end


function _pushback(v::AbstractVector, x)
T = promote_type(eltype(v), typeof(x))
r = similar(v, T, length(eachindex(v)) + 1)
r[lastindex(r)] = x
r[firstindex(r):lastindex(r)-1] = v
r[firstindex(r):(lastindex(r)-1)] = v
r
end

function ChainRulesCore.rrule(::typeof(_pushback), v::AbstractVector, x)
result = _pushback(v, x)
function _pushback_pullback(thunked_ΔΩ)
ΔΩ = unthunk(thunked_ΔΩ)
(NoTangent(), ΔΩ[firstindex(ΔΩ):lastindex(ΔΩ)-1], ΔΩ[lastindex(ΔΩ)])
(NoTangent(), ΔΩ[firstindex(ΔΩ):(lastindex(ΔΩ)-1)], ΔΩ[lastindex(ΔΩ)])
end
return result, _pushback_pullback
end

_dropfront(v::AbstractVector) = v[(firstindex(v)+1):lastindex(v)]

_dropfront(v::AbstractVector) = v[firstindex(v)+1:lastindex(v)]

_dropback(v::AbstractVector) = v[firstindex(v):lastindex(v)-1]

_dropback(v::AbstractVector) = v[firstindex(v):(lastindex(v)-1)]

_rev_cumsum(xs::AbstractVector) = reverse(cumsum(reverse(xs)))

Expand All @@ -61,7 +59,6 @@ function ChainRulesCore.rrule(::typeof(_rev_cumsum), xs::AbstractVector)
return result, _rev_cumsum_pullback
end


# Equivalent to `cumprod(xs)``:
_exp_cumsum_log(xs::AbstractVector) = exp.(cumsum(log.(xs)))

Expand Down
1 change: 0 additions & 1 deletion src/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ MeasureBase.getdof(d::Dirichlet) = length(d) - 1

MeasureBase.transport_origin(ν::Dirichlet) = StdUniform()^getdof(ν)


function _dirichlet_beta_trafo(α::Real, β::Real, x::Real)
R = float(promote_type(typeof(α), typeof(β), typeof(x)))
convert(R, transport_def(Beta(α, β), StdUniform(), x))::R
Expand Down
20 changes: 13 additions & 7 deletions src/dist_vartransform.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT).

const _AnyStdUniform = Union{StandardUniform, Uniform}
const _AnyStdNormal = Union{StandardNormal, Normal}
const _AnyStdUniform = Union{StandardUniform,Uniform}
const _AnyStdNormal = Union{StdNormal,Normal}

const _AnyStdDistribution = Union{_AnyStdUniform, _AnyStdNormal}
const _AnyStdDistribution = Union{_AnyStdUniform,_AnyStdNormal}

_std_measure(::Type{<:_AnyStdUniform}) = StandardUniform
_std_measure(::Type{<:_AnyStdNormal}) = StandardNormal
_std_measure(::Type{<:_AnyStdNormal}) = StdNormal

_std_measure(::Type{M}, ::StaticInt{1}) where {M<:_AnyStdDistribution} = M()
_std_measure(::Type{M}, dof::Integer) where {M<:_AnyStdDistribution} = M(dof)
_std_measure_for(::Type{M}, μ::Any) where {M<:_AnyStdDistribution} = _std_measure(_std_measure(M), getdof(μ))
function _std_measure_for(::Type{M}, μ::Any) where {M<:_AnyStdDistribution}
_std_measure(_std_measure(M), getdof(μ))
end

MeasureBase.transport_to(::Type{NU}, μ) where {NU<:_AnyStdDistribution} = transport_to(_std_measure_for(NU, μ), μ)
MeasureBase.transport_to(ν, ::Type{MU}) where {MU<:_AnyStdDistribution} = transport_to(ν, _std_measure_for(MU, ν))
function MeasureBase.transport_to(::Type{NU}, μ) where {NU<:_AnyStdDistribution}
transport_to(_std_measure_for(NU, μ), μ)
end
function MeasureBase.transport_to(ν, ::Type{MU}) where {MU<:_AnyStdDistribution}
transport_to(ν, _std_measure_for(MU, ν))
end
97 changes: 76 additions & 21 deletions src/distribution_measure.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT).


"""
struct DistributionMeasure <: AbstractMeasure

Expand All @@ -13,31 +12,58 @@ to `AbstractMeasure` conversions.
Use `convert(Distribution, m::DistributionMeasure)` or
`Distribution(m::DistributionMeasure)` to convert back to a `Distribution`.
"""
struct DistributionMeasure{F<:VariateForm,S<:ValueSupport,D<:Distribution{F,S}} <: AbstractMeasure
struct DistributionMeasure{F<:VariateForm,S<:ValueSupport,D<:Distribution{F,S}} <:
AbstractMeasure
d::D
end


@inline MeasureBase.AbstractMeasure(d::Distribution) = DistributionMeasure(d)

@inline Base.convert(::Type{AbstractMeasure}, d::Distribution) = DistributionMeasure(d)

@inline Distributions.Distribution(m::DistributionMeasure) = m.d
@inline Distributions.Distribution{F}(m::DistributionMeasure{F}) where {F<:VariateForm} = Distribution(m)
@inline Distributions.Distribution{F,S}(m::DistributionMeasure{F,S}) where {F<:VariateForm,S<:ValueSupport} = Distribution(m)
@inline function Distributions.Distribution{F}(
m::DistributionMeasure{F},
) where {F<:VariateForm}
Distribution(m)
end
@inline function Distributions.Distribution{F,S}(
m::DistributionMeasure{F,S},
) where {F<:VariateForm,S<:ValueSupport}
Distribution(m)
end

@inline Base.convert(::Type{Distribution}, m::DistributionMeasure) = Distribution(m)
@inline Base.convert(::Type{Distribution{F}}, m::DistributionMeasure{F}) where {F<:VariateForm} = Distribution(m)
@inline Base.convert(::Type{Distribution{F,S}}, m::DistributionMeasure{F,S}) where {F<:VariateForm,S<:ValueSupport} = Distribution(m)

@inline function Base.convert(
::Type{Distribution{F}},
m::DistributionMeasure{F},
) where {F<:VariateForm}
Distribution(m)
end
@inline function Base.convert(
::Type{Distribution{F,S}},
m::DistributionMeasure{F,S},
) where {F<:VariateForm,S<:ValueSupport}
Distribution(m)
end

@inline DensityInterface.densityof(μ::DistributionMeasure) = DensityInterface.densityof(μ.d)
@inline DensityInterface.densityof(μ::DistributionMeasure, x) = DensityInterface.densityof(μ.d, x)
@inline DensityInterface.logdensityof(μ::DistributionMeasure) = DensityInterface.logdensityof(μ.d)
@inline DensityInterface.logdensityof(μ::DistributionMeasure, x) = DensityInterface.logdensityof(μ.d, x)
@inline function DensityInterface.densityof(μ::DistributionMeasure, x)
DensityInterface.densityof(μ.d, x)
end
@inline function DensityInterface.logdensityof(μ::DistributionMeasure)
DensityInterface.logdensityof(μ.d)
end
@inline function DensityInterface.logdensityof(μ::DistributionMeasure, x)
DensityInterface.logdensityof(μ.d, x)
end

@inline MeasureBase.logdensity_def(μ::DistributionMeasure, x) = MeasureBase.logdensity_def(μ.d, x)
@inline MeasureBase.unsafe_logdensityof(μ::DistributionMeasure, x) = MeasureBase.unsafe_logdensityof(μ.d, x)
@inline function MeasureBase.logdensity_def(μ::DistributionMeasure, x)
MeasureBase.logdensity_def(μ.d, x)
end
@inline function MeasureBase.unsafe_logdensityof(μ::DistributionMeasure, x)
MeasureBase.unsafe_logdensityof(μ.d, x)
end
@inline MeasureBase.insupport(μ::DistributionMeasure, x) = MeasureBase.insupport(μ.d, x)
@inline MeasureBase.basemeasure(μ::DistributionMeasure) = MeasureBase.basemeasure(μ.d)
@inline MeasureBase.paramnames(μ::DistributionMeasure) = MeasureBase.paramnames(μ.d)
Expand All @@ -46,30 +72,59 @@ end
@inline MeasureBase.to_origin(::DistributionMeasure, y) = y
@inline MeasureBase.from_origin(::DistributionMeasure, x) = x

function Base.rand(rng::AbstractRNG, ::Type{T}, m::DistributionMeasure) where {T<:Real}
convert_realtype(T, rand(m.d))
end

Base.rand(rng::AbstractRNG, ::Type{T}, m::DistributionMeasure) where {T<:Real} = convert_realtype(T, rand(m.d))

function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution{<:ArrayLikeVariate{0}}, sz::Dims) where {T<:Real}
function _flat_powrand(
rng::AbstractRNG,
::Type{T},
d::Distribution{<:ArrayLikeVariate{0}},
sz::Dims,
) where {T<:Real}
convert_realtype(T, reshape(rand(d, prod(sz)), sz...))
end

function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution{<:ArrayLikeVariate{1}}, sz::Dims) where {T<:Real}
function _flat_powrand(
rng::AbstractRNG,
::Type{T},
d::Distribution{<:ArrayLikeVariate{1}},
sz::Dims,
) where {T<:Real}
convert_realtype(T, reshape(rand(d, prod(sz)), size(d)..., sz...))
end

function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::ReshapedDistribution{N,<:Any,<:Distribution{<:ArrayLikeVariate{1}}}, sz::Dims) where {T<:Real,N}
function _flat_powrand(
rng::AbstractRNG,
::Type{T},
d::ReshapedDistribution{N,<:Any,<:Distribution{<:ArrayLikeVariate{1}}},
sz::Dims,
) where {T<:Real,N}
convert_realtype(T, reshape(rand(d.dist, prod(sz)), d.dims..., sz...))
end

function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution, sz::Dims) where {T<:Real,N}
function _flat_powrand(
rng::AbstractRNG,
::Type{T},
d::Distribution,
sz::Dims,
) where {T<:Real}
flatview(ArrayOfSimilarArrays(convert_realtype(T, rand(d, sz))))
end

function Base.rand(rng::AbstractRNG, ::Type{T}, m::PowerMeasure{<:DistributionMeasure{<:ArrayLikeVariate{0}}, NTuple{N,Base.OneTo{Int}}}) where {T<:Real,N}
function Base.rand(
rng::AbstractRNG,
::Type{T},
m::PowerMeasure{<:DistributionMeasure{<:ArrayLikeVariate{0}},NTuple{N,Base.OneTo{Int}}},
) where {T<:Real,N}
_flat_powrand(rng, T, m.parent.d, map(length, m.axes))
end

function Base.rand(rng::AbstractRNG, ::Type{T}, m::PowerMeasure{<:DistributionMeasure{<:ArrayLikeVariate{M}}, NTuple{N,Base.OneTo{Int}}}) where {T<:Real,M,N}
function Base.rand(
rng::AbstractRNG,
::Type{T},
m::PowerMeasure{<:DistributionMeasure{<:ArrayLikeVariate{M}},NTuple{N,Base.OneTo{Int}}},
) where {T<:Real,M,N}
flat_data = _flat_powrand(rng, T, m.parent.d, map(length, m.axes))
ArrayOfSimilarArrays{T,M,N}(flat_data)
end
Loading