-
Notifications
You must be signed in to change notification settings - Fork 1
Fully implement DistributionMeasure and add var transforms #6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 77 commits
Commits
Show all changes
90 commits
Select commit
Hold shift + click to select a range
16af94c
STASH vartransform
oschulz 9225c1a
STASH require_insupport
oschulz 6ea552b
STASH
oschulz 21326f2
Require MeasureBase v0.10
oschulz fb82354
Add ChainRulesCore, ForwardDiff and ForwardDiffPullbacks
oschulz d365698
STASH
oschulz 5db5f62
FIXUP deps
oschulz ea9a6d6
Add Statistics, StatsBase and StatsFuns to deps
oschulz 0f37407
STASH
oschulz b537945
STASH
oschulz f8aafab
STASH
oschulz 7627b05
Add IrrationalConstants and FillArrays to deps
oschulz 806c201
Use Distributions pdf, logpdf, etc.
oschulz 148e819
Implement MeasureBase interface for distributions
oschulz 60c42df
STASH
oschulz 8e94e01
Use Random.rand!
oschulz adf70b3
STASH
oschulz 349fd32
FIXUP no using pdf, logpdf, etc.
oschulz 77313fb
Use FillArrays Ones and Zeros
oschulz c43ece2
Add PDMats to deps
oschulz cef3ce8
STASH stddist
oschulz aef8533
STASH
oschulz 3629814
Using LinearAlgebra.Diagonal
oschulz b8ad6b5
STASH
oschulz 57e450a
STASH
oschulz 4d2854d
STASH
oschulz 91a40c6
STASH
oschulz 29d8d2a
STASH
oschulz 1d61c95
STASH
oschulz 612de3c
Add dependencies
oschulz d639ad2
STASH
oschulz 3f8e342
FIXUP deps
oschulz 0539b0c
FIXUP deps
oschulz 2c019fc
STASH
oschulz e76abe2
STASH
oschulz 7762601
STASH
oschulz c58384e
FIXUP test ad utils
oschulz 2d74e87
Use test Project.toml
oschulz 9307bb2
STASH test vtrafo
oschulz 9e7b9e4
FIXUP test Project.toml
oschulz aa42bdf
STASH
oschulz 3e82bf8
STASH vartransform_def
oschulz 79ce816
FIXUP deps
oschulz dd6ab04
STASH trafotests
oschulz 0e15881
STASH
oschulz 39035ec
STASH deps
oschulz 0f40e4d
STASH
oschulz 4f26564
Add Static to deps
oschulz 11a763b
MeasureBase deps
oschulz 19cc23f
Add StdNormal measure
oschulz 3bc9040
Fixup StdNormal measure
oschulz 0848251
vartransform for StdNormal
oschulz 53ce9a5
FIXUP deps
oschulz ab5fc63
STASH transforms
oschulz 6b339b3
Export StdNormal
oschulz 7cea2d7
FIXUP deps
oschulz a876c50
Fix vartransform for StdNormal
oschulz 3b0ed28
FIXUP getjacobian in tests
oschulz 9395827
STASH vartransform_def for StdNormal
oschulz 60ed6ce
STASH
oschulz a675872
STASH measure interface for dists
oschulz d6814fd
FIXUP
oschulz aa92f70
STASH
oschulz c5edd8b
STASH a lot
oschulz 5495b9d
Test vartransform auto dim-sel
oschulz 21b1bdf
STASH trafos
oschulz 50d184d
STASH vartransform
oschulz 6010987
STASH vartransform tests and dirichlet
oschulz 5e72697
STASH vartrafo tests
oschulz fb45ddc
STASH transform tests
oschulz d2b7306
Fix transform var naming
oschulz 8852754
Rename vartransform to transport_to
oschulz e4f3cdb
Rename vartransform_origin
oschulz ea77b3b
Require MeasureBase v0.11
oschulz 3474573
Add MeasureBase to test deps
oschulz 184bc00
Fix transport tests
oschulz 8c48244
fix some typos
cscherrer 40757e2
Fixes and more tests
oschulz 2c3b691
Fixes and tests
oschulz f68e214
STASH tests
oschulz 30c746c
STASH
oschulz 81a6627
FIX Prefix scale
oschulz e130d25
Fix transport_def for StandardDist
oschulz 2e433c3
Fixes
oschulz 9153836
Fix tests
oschulz a673c26
Adapt to MeasureBase v0.12
oschulz 58a791c
Allow for static-sized StandardDist
oschulz 4b22c6b
Fix product transport
oschulz 40077dd
Increase package version to v0.2.0
oschulz 319f97d
Support RN integral notation for distributions
oschulz File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,28 +4,43 @@ authors = ["Chad Scherrer <[email protected]> and contributors"] | |
version = "0.1.0" | ||
|
||
[deps] | ||
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" | ||
ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018" | ||
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | ||
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" | ||
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" | ||
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" | ||
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" | ||
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | ||
ForwardDiffPullbacks = "450a3b6d-2448-4ee1-8e34-e4eb8713b605" | ||
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" | ||
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" | ||
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
MeasureBase = "fa1605e6-acd5-459c-a1e6-7e635759db14" | ||
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" | ||
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" | ||
|
||
[compat] | ||
ArgCheck = "1, 2" | ||
ArraysOfArrays = "0.5" | ||
ChainRulesCore = "1" | ||
ChangesOfVariables = "0.1" | ||
DensityInterface = "0.4" | ||
Distributions = "0.25" | ||
FillArrays = "0.12, 0.13" | ||
ForwardDiff = "0.9, 0.10" | ||
ForwardDiffPullbacks = "0.2" | ||
Functors = "0.2" | ||
InverseFunctions = "0.1" | ||
MeasureBase = "0.9" | ||
IrrationalConstants = "0.1" | ||
MeasureBase = "0.11" | ||
PDMats = "0.11" | ||
Static = "0.5, 0.6" | ||
StatsBase = "0.32, 0.33" | ||
StatsFuns = "0.9, 1" | ||
julia = "1.6" | ||
|
||
[extras] | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
|
||
[targets] | ||
test = ["Test"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,74 @@ | ||
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). | ||
|
||
module DistributionMeasures | ||
|
||
using LinearAlgebra: Diagonal, dot, cholesky | ||
|
||
import Random | ||
using Random: AbstractRNG | ||
using Random: AbstractRNG, rand! | ||
|
||
import DensityInterface | ||
using DensityInterface: logdensityof | ||
|
||
import MeasureBase | ||
using MeasureBase: AbstractMeasure, Lebesgue, Counting | ||
using MeasureBase: PowerMeasure | ||
using MeasureBase: AbstractMeasure, Lebesgue, Counting, ℝ | ||
using MeasureBase: StdMeasure, StdUniform, StdExponential, StdLogistic | ||
using MeasureBase: PowerMeasure, WeightedMeasure | ||
using MeasureBase: getdof, checked_var | ||
using MeasureBase: transport_to, transport_def, transport_origin, from_origin, to_origin | ||
using MeasureBase: NoTransformOrigin, NoVarTransform | ||
|
||
import Distributions | ||
using Distributions: Distribution, VariateForm, ValueSupport | ||
using Distributions: ArrayLikeVariate, Continuous, Discrete | ||
using Distributions: Distribution, VariateForm, ValueSupport, ContinuousDistribution | ||
using Distributions: Univariate, Multivariate, ArrayLikeVariate, Continuous, Discrete | ||
using Distributions: Uniform, Exponential, Logistic, Normal | ||
using Distributions: MvNormal, Beta, Dirichlet | ||
using Distributions: ReshapedDistribution | ||
|
||
import Statistics | ||
import StatsBase | ||
import StatsFuns | ||
import PDMats | ||
|
||
using IrrationalConstants: log2π, invsqrt2π | ||
|
||
using Static: True, False, StaticInt, static | ||
using FillArrays: Fill, Ones, Zeros | ||
|
||
import ChainRulesCore | ||
using ChainRulesCore: ZeroTangent, NoTangent, unthunk, @thunk | ||
|
||
import ForwardDiff | ||
using ForwardDiffPullbacks: fwddiff | ||
|
||
import Functors | ||
using Functors: fmap | ||
|
||
using ArgCheck: @argcheck | ||
|
||
using ArraysOfArrays: ArrayOfSimilarArrays, flatview | ||
|
||
const MeasureLike = Union{AbstractMeasure,Distribution} | ||
export MeasureLike | ||
|
||
include("utils.jl") | ||
include("autodiff_utils.jl") | ||
include("measure_interface.jl") | ||
include("stdnormal_measure.jl") | ||
include("standard_dist.jl") | ||
include("standard_uniform.jl") | ||
include("standard_normal.jl") | ||
include("distribution_measure.jl") | ||
include("dist_vartransform.jl") | ||
include("univariate.jl") | ||
include("standardmv.jl") | ||
include("product.jl") | ||
include("reshaped.jl") | ||
include("dirichlet.jl") | ||
|
||
|
||
const MeasureLike = Union{AbstractMeasure,Distribution} | ||
|
||
export MeasureLike, DistributionMeasure | ||
export StdNormal | ||
export DistributionMeasure | ||
export StandardDist | ||
|
||
|
||
end # module |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). | ||
|
||
@inline _adignore_call(f) = f() | ||
@inline _adignore_call_pullback(@nospecialize ΔΩ) = (NoTangent(), NoTangent()) | ||
ChainRulesCore.rrule(::typeof(_adignore_call), f) = _adignore_call(f), _adignore_call_pullback | ||
|
||
macro _adignore(expr) | ||
:(_adignore_call(() -> $(esc(expr)))) | ||
end | ||
|
||
|
||
function _pushfront(v::AbstractVector, x) | ||
T = promote_type(eltype(v), typeof(x)) | ||
r = similar(v, T, length(eachindex(v)) + 1) | ||
r[firstindex(r)] = x | ||
r[firstindex(r)+1:lastindex(r)] = v | ||
r | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(_pushfront), v::AbstractVector, x) | ||
result = _pushfront(v, x) | ||
function _pushfront_pullback(thunked_ΔΩ) | ||
ΔΩ = unthunk(thunked_ΔΩ) | ||
(NoTangent(), ΔΩ[firstindex(ΔΩ)+1:lastindex(ΔΩ)], ΔΩ[firstindex(ΔΩ)]) | ||
end | ||
return result, _pushfront_pullback | ||
end | ||
|
||
|
||
function _pushback(v::AbstractVector, x) | ||
T = promote_type(eltype(v), typeof(x)) | ||
r = similar(v, T, length(eachindex(v)) + 1) | ||
r[lastindex(r)] = x | ||
r[firstindex(r):lastindex(r)-1] = v | ||
r | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(_pushback), v::AbstractVector, x) | ||
result = _pushback(v, x) | ||
function _pushback_pullback(thunked_ΔΩ) | ||
ΔΩ = unthunk(thunked_ΔΩ) | ||
(NoTangent(), ΔΩ[firstindex(ΔΩ):lastindex(ΔΩ)-1], ΔΩ[lastindex(ΔΩ)]) | ||
end | ||
return result, _pushback_pullback | ||
end | ||
|
||
|
||
_dropfront(v::AbstractVector) = v[firstindex(v)+1:lastindex(v)] | ||
|
||
_dropback(v::AbstractVector) = v[firstindex(v):lastindex(v)-1] | ||
|
||
|
||
_rev_cumsum(xs::AbstractVector) = reverse(cumsum(reverse(xs))) | ||
|
||
function ChainRulesCore.rrule(::typeof(_rev_cumsum), xs::AbstractVector) | ||
result = _rev_cumsum(xs) | ||
function _rev_cumsum_pullback(ΔΩ) | ||
∂xs = @thunk cumsum(unthunk(ΔΩ)) | ||
(NoTangent(), ∂xs) | ||
end | ||
return result, _rev_cumsum_pullback | ||
end | ||
|
||
|
||
# Equivalent to `cumprod(xs)``: | ||
_exp_cumsum_log(xs::AbstractVector) = exp.(cumsum(log.(xs))) | ||
|
||
function ChainRulesCore.rrule(::typeof(_exp_cumsum_log), xs::AbstractVector) | ||
result = _exp_cumsum_log(xs) | ||
function _exp_cumsum_log_pullback(ΔΩ) | ||
∂xs = inv.(xs) .* _rev_cumsum(exp.(cumsum(log.(xs))) .* unthunk(ΔΩ)) | ||
(NoTangent(), ∂xs) | ||
end | ||
return result, _exp_cumsum_log_pullback | ||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). | ||
|
||
MeasureBase.getdof(d::Dirichlet) = length(d) - 1 | ||
|
||
MeasureBase.transport_origin(ν::Dirichlet) = StdUniform()^getdof(ν) | ||
|
||
|
||
function _dirichlet_beta_trafo(α::Real, β::Real, x::Real) | ||
R = float(promote_type(typeof(α), typeof(β), typeof(x))) | ||
convert(R, transport_def(Beta(α, β), StdUniform(), x))::R | ||
end | ||
|
||
_a_times_one_minus_b(a::Real, b::Real) = a * (1 - b) | ||
|
||
function MeasureBase.from_origin(ν::Dirichlet, x) | ||
# See M. J. Betancourt, "Cruising The Simplex: Hamiltonian Monte Carlo and the Dirichlet Distribution", | ||
# https://arxiv.org/abs/1010.3436 | ||
|
||
# Sanity check (TODO - remove?): | ||
@_adignore @argcheck length(ν) == length(x) + 1 | ||
|
||
αs = _dropfront(_rev_cumsum(ν.alpha)) | ||
βs = _dropback(ν.alpha) | ||
beta_v = fwddiff(_dirichlet_beta_trafo).(αs, βs, x) | ||
beta_v_cp = _exp_cumsum_log(_pushfront(beta_v, 1)) | ||
beta_v_ext = _pushback(beta_v, 0) | ||
fwddiff(_a_times_one_minus_b).(beta_v_cp, beta_v_ext) | ||
end | ||
|
||
# ToDo: MeasureBase.to_origin(ν::Dirichlet, y) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). | ||
|
||
const _AnyStdUniform = Union{StandardUniform, Uniform} | ||
const _AnyStdNormal = Union{StandardNormal, Normal} | ||
|
||
const _AnyStdDistribution = Union{_AnyStdUniform, _AnyStdNormal} | ||
|
||
_std_measure(::Type{<:_AnyStdUniform}) = StandardUniform | ||
_std_measure(::Type{<:_AnyStdNormal}) = StandardNormal | ||
|
||
_std_measure(::Type{M}, ::StaticInt{1}) where {M<:_AnyStdDistribution} = M() | ||
_std_measure(::Type{M}, dof::Integer) where {M<:_AnyStdDistribution} = M(dof) | ||
_std_measure_for(::Type{M}, μ::Any) where {M<:_AnyStdDistribution} = _std_measure(_std_measure(M), getdof(μ)) | ||
|
||
MeasureBase.transport_to(::Type{NU}, μ) where {NU<:_AnyStdDistribution} = transport_to(_std_measure_for(NU, μ), μ) | ||
MeasureBase.transport_to(ν, ::Type{MU}) where {MU<:_AnyStdDistribution} = transport_to(ν, _std_measure_for(MU, ν)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). | ||
|
||
@inline MeasureBase.logdensity_def(d::Distribution, x) = DensityInterface.logdensityof(d, x) | ||
@inline MeasureBase.unsafe_logdensityof(d::Distribution, x) = DensityInterface.logdensityof(d, x) | ||
|
||
@inline MeasureBase.insupport(d::Distribution, x) = Distributions.insupport(d, x) | ||
|
||
@inline MeasureBase.basemeasure(d::Distribution{<:ArrayLikeVariate{0},<:Continuous}) = Lebesgue() | ||
@inline MeasureBase.basemeasure(d::Distribution{<:ArrayLikeVariate,<:Continuous}) = Lebesgue()^size(d) | ||
@inline MeasureBase.basemeasure(d::Distribution{<:ArrayLikeVariate{0},<:Discrete}) = Counting() | ||
@inline MeasureBase.basemeasure(d::Distribution{<:ArrayLikeVariate,<:Discrete}) = Counting()^size(d) | ||
|
||
@inline MeasureBase.insupport(d::Distribution, x) = Distributions.insupport(d, x) | ||
@inline MeasureBase.paramnames(d::Distribution) = propertynames(d) | ||
@inline MeasureBase.params(d::Distribution) = NamedTuple{MeasureBase.paramnames(m.d)}(Distributions.params(d)) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). | ||
|
||
const _StdPowMeasure1 = PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}} | ||
const _UniformProductDist1 = Distributions.Product{Continuous,<:Uniform,<:AbstractVector{<:Uniform}} | ||
|
||
|
||
MeasureBase.getdof(d::_UniformProductDist1) = length(d) | ||
|
||
|
||
function _product_dist_trafo_impl(νs, μs, x) | ||
fwddiff(transport_def).(νs, μs, x) | ||
end | ||
|
||
function MeasureBase.transport_def(ν::_StdPowMeasure1, μ::_UniformProductDist1, x) | ||
_product_dist_trafo_impl((ν.parent,), μ.v, x) | ||
end | ||
|
||
function MeasureBase.transport_def(ν::_UniformProductDist1, μ::_StdPowMeasure1, x) | ||
_product_dist_trafo_impl(ν.v, (μ.parent,), x) | ||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# This file is a part of DistributionMeasures.jl, licensed under the MIT License (MIT). | ||
|
||
MeasureBase.getdof(μ::ReshapedDistribution) = MeasureBase.getdof(μ.dist) | ||
|
||
MeasureBase.transport_origin(μ::ReshapedDistribution) = μ.dist | ||
|
||
MeasureBase.to_origin(ν::ReshapedDistribution, y) = reshape(y, size(ν.dist)) | ||
|
||
MeasureBase.from_origin(ν::ReshapedDistribution, x) = reshape(x, ν.dims) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.