Skip to content

Add variate transport #62

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 70 commits into from
Jun 19, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
25fa2e1
Add ChainRulesCore to dependencies
oschulz Jun 15, 2022
a3e1bbf
Add InverseFunctions and ChangesOfVariables to deps
oschulz Jun 15, 2022
2015ccd
Add require_insupport
oschulz Jun 15, 2022
e86cf2c
Add effndof and require_same_effndof
oschulz Jun 15, 2022
2e8e7d6
Add check_varshape
oschulz Jun 15, 2022
5019b14
Add vartransform
oschulz Jun 15, 2022
6ee4e3d
Remove "measure-like" terminology
oschulz Jun 15, 2022
2217a42
Remove requirement for vartransform to return a Function
oschulz Jun 15, 2022
bb0257d
Remove check_varshape
oschulz Jun 15, 2022
15952bb
Rename effndof to getdof
oschulz Jun 15, 2022
d5635fe
Separate vartransform and vartransform_def
oschulz Jun 15, 2022
57c83b3
Fix default vartransform_def
oschulz Jun 15, 2022
b2817f0
Remove select_vartransform_intermediate
oschulz Jun 15, 2022
66f5990
Fix check_dof
oschulz Jun 16, 2022
de96842
Export getdof
oschulz Jun 16, 2022
9101015
Export vartransform
oschulz Jun 16, 2022
45fe220
Implement getdof for measures
oschulz Jun 16, 2022
e7fff14
Remove StdNormal
oschulz Jun 16, 2022
680e8f8
FIXUP implement getdof
oschulz Jun 16, 2022
7d9e11e
FIXUP implement getdof
oschulz Jun 16, 2022
c6368bb
Add StdLogistic
oschulz Jun 16, 2022
9372871
Implement vartransform_def for StdMeasure
oschulz Jun 16, 2022
70f1e30
FIXUP vartransform_def for StdMeasure
oschulz Jun 16, 2022
147ba2a
Add _vartransform_intermediate
oschulz Jun 16, 2022
0f26323
FIXUP Implement vartransform_def for StdMeasure
oschulz Jun 16, 2022
6ffa92c
Fix insupport for StdLogistic
oschulz Jun 16, 2022
101f947
Fix StdLogistic
oschulz Jun 16, 2022
1193e05
FIXUP StdMeasure vartransform
oschulz Jun 16, 2022
d0509e9
Fix check_dof
oschulz Jun 16, 2022
9556e31
Add checked_var
oschulz Jun 16, 2022
acde08b
WIP Add vartransform tests
oschulz Jun 16, 2022
5a16bef
FIXUP vartransform tests
oschulz Jun 16, 2022
ebb7ddb
Fix rand for StdUniform
oschulz Jun 16, 2022
7b56f08
FIXUP vartransform tests
oschulz Jun 16, 2022
23e41c7
Use checked_var at VarTransformation input stage
oschulz Jun 16, 2022
6f5b246
FIX vartransform tests
oschulz Jun 16, 2022
f5ebe6d
Add defaults for check_dof and checked_var
oschulz Jun 16, 2022
cbb6873
Add vartransform_origin for WeightedMeasure
oschulz Jun 16, 2022
1c52ed0
Fix deps
oschulz Jun 16, 2022
7b954a5
Fix tests
oschulz Jun 16, 2022
5a12523
WIP Add PushforwardMeasure
oschulz Jun 16, 2022
520562c
WIP improve PushforwardMeasure
oschulz Jun 16, 2022
10b12dc
WIP improve PushforwardMeasure
oschulz Jun 16, 2022
bfda82b
WIP improve PushforwardMeasure
oschulz Jun 16, 2022
62eabb0
WIP improve PushforwardMeasure
oschulz Jun 16, 2022
a3a7b00
FIX PushforwardMeasure
oschulz Jun 16, 2022
39bf7b0
Allow PushforwardMeasure to bypass checked_var
oschulz Jun 17, 2022
da7ecc6
Test PushforwardMeasure
oschulz Jun 17, 2022
75e1fb3
Fix docstring of NoDOF
oschulz Jun 17, 2022
6250b20
Add test_vartransform to Interface
oschulz Jun 17, 2022
9bfa9f9
FIXUP _default_checked_var
oschulz Jun 17, 2022
d4f0246
FIXUP vartransform_origin docs and defaults
oschulz Jun 17, 2022
0512930
Run vartransform tests
oschulz Jun 17, 2022
14890bb
Improve vartransform_origin def for WeightedMeasure
oschulz Jun 17, 2022
b32f34d
Add vartransform stdmeasure autodim
oschulz Jun 17, 2022
f12f1b5
Specialize equality for VarTransformation
oschulz Jun 18, 2022
b3dfe13
Don't call check_dof so often
oschulz Jun 18, 2022
1088fc1
Improve checked_var for PowerMeasure
oschulz Jun 18, 2022
1c2311c
Fix check_dof and require_insupport rrules
oschulz Jun 18, 2022
1fb805c
Test getdof
oschulz Jun 18, 2022
24affe6
Document TransformVolCorr
oschulz Jun 18, 2022
35fae34
Fix transform variable naming inconsistencies
oschulz Jun 18, 2022
d673692
Specialize gotdof for inferrably empty power measures
oschulz Jun 18, 2022
3965732
Add trafos for Dirac
oschulz Jun 18, 2022
5b5eb7f
Support logdensity calculation on empty power measures
oschulz Jun 19, 2022
e70f294
Improve test_vartransform
oschulz Jun 19, 2022
e85e54f
Fix and test vartransform for Dirac
oschulz Jun 19, 2022
8f2da10
Rename vartransform to transport_to
oschulz Jun 19, 2022
343d20b
Rename vartransform_origin
oschulz Jun 19, 2022
a8faa66
Increase package version to v0.11.0
oschulz Jun 19, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ authors = ["Chad Scherrer <[email protected]> and contributors"]
version = "0.10.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Expand All @@ -24,11 +27,14 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"

[compat]
ChainRulesCore = "1"
ChangesOfVariables = "0.1.3"
Compat = "3.35, 4"
ConstructionBase = "1.3"
DensityInterface = "0.4"
FillArrays = "0.12, 0.13"
IfElse = "0.1"
InverseFunctions = "0.1.7"
IrrationalConstants = "0.1"
LogExpFunctions = "0.3"
LogarithmicNumbers = "1"
Expand Down
21 changes: 8 additions & 13 deletions src/MeasureBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@ import DensityInterface: densityof
import DensityInterface: DensityKind
using DensityInterface

using InverseFunctions
using ChangesOfVariables

import Base.iterate
import ConstructionBase
using ConstructionBase: constructorof

using PrettyPrinting
const Pretty = PrettyPrinting

using ChainRulesCore
using FillArrays
using Static

Expand All @@ -32,21 +36,10 @@ export logdensity_def
export basemeasure
export basekernel
export productmeasure

"""
inssupport(m, x)
insupport(m)

`insupport(m,x)` computes whether `x` is in the support of `m`.

`insupport(m)` returns a function, and satisfies

insupport(m)(x) == insupport(m, x)
"""
function insupport end

export insupport

include("insupport.jl")

abstract type AbstractMeasure end

using Static: @constprop
Expand Down Expand Up @@ -94,6 +87,8 @@ using Compat

using IrrationalConstants

include("effndof.jl")
include("vartransform.jl")
include("schema.jl")
include("splat.jl")
include("proxies.jl")
Expand Down
33 changes: 33 additions & 0 deletions src/effndof.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""
effndof(μ)

Returns the effective number of degrees of freedom of variates of
measure-like object `μ`.

The effective NDOF my differ from the length of the variates. For example,
the effective NDOF for a Dirichlet distribution with variates of length `n`
is `n - 1`.

Also see [`require_same_effndof`](@ref).
"""
function effndof end


"""
MeasureBase.require_same_effndof(a, b)::Nothing

Check if `a` and `b` have the same effective number of degrees of freedom
according to [`MeasureBase.effndof`](@ref).
"""
function require_same_effndof end

ChainRulesCore.rrule(::typeof(require_same_effndof), a, b) = nothing, _nogradient_pullback2

function require_same_effndof(a, b)
trg_d_n = effndof(ν)
src_d_n = effndof(μ)
if trg_d_n != src_d_n
throw(ArgumentError("Can't convert to $(typeof(ν).name) with $(trg_d_n) eff. DOF from $(typeof(μ).name) with $(src_d_n) eff. DOF"))
end
return nothing
end
44 changes: 44 additions & 0 deletions src/insupport.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
inssupport(m, x)
insupport(m)

`insupport(m,x)` computes whether `x` is in the support of `m`.

`insupport(m)` returns a function, and satisfies

insupport(m)(x) == insupport(m, x)
"""
function insupport end


"""
MeasureBase.require_insupport(μ, x)::Nothing

Checks if `x` is in the support of distribution/measure `μ`, throws an
`ArgumentError` if not.
"""
function require_insupport end

_check_insupport_pullback(ΔΩ) = NoTangent(), ZeroTangent()
function ChainRulesCore.rrule(::typeof(require_insupport), μ, x)
return require_insupport(μ, x), _check_insupport_pullback
end

function require_insupport(μ, x::AbstractArray{T,N}) where {T,N}
if !insupport(μ, x)
throw(ArgumentError("x is not within the support of μ"))
end
return nothing
end


"""
MeasureBase.check_varshape(μ, x)::Nothing

Checks if `x` has the correct shape/size for a variate of measure-like object
`μ`, throws an `ArgumentError` if not.
"""
function check_varshape end

_check_varshape_pullback(ΔΩ) = NoTangent(), ZeroTangent()
ChainRulesCore.rrule(::typeof(check_varshape), μ, x) = check_varshape(μ, x), _check_varshape_pullback
216 changes: 216 additions & 0 deletions src/vartransform.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
"""
struct MeasureBase.NoTransformOrigin{MU}

Indicates that no (default) pullback measure is available for measures of
type `MU`.

See [`MeasureBase.vartransform_origin`](@ref).
"""
struct NoTransformOrigin{MU} end


"""
MeasureBase.vartransform_origin(μ)

Default measure to pullback to resp. pushforward from when transforming
between `μ` and another measure.
"""
function vartransform_origin end

vartransform_origin(m::M) where M = NoTransformOrigin{M}()


"""
MeasureBase.from_origin(μ, y)

Push `y` from `MeasureBase.vartransform_origin(μ)` forward to `μ`.
"""
function from_origin end

from_origin(m::M) where M = NoTransformOrigin{M}()


"""
MeasureBase.to_origin(μ, x)

Pull `x` from `μ` back to `MeasureBase.vartransform_origin(μ)`.
"""
function to_origin end

to_origin(m::M) where M = NoTransformOrigin{M}()


"""
struct MeasureBase.NoVarTransform{NU,MU} end

Indicates that no transformation from a measure of type `MU` to a measure of
type `NU` could be found.
"""
struct NoVarTransform{NU,MU} end


"""
f = vartransform(ν, μ)::Function

Generates a [measurable function](https://en.wikipedia.org/wiki/Measurable_function)
`f` that transforms values distributed according to measure-like object `μ` to
values distributed according to a measure-like object `ν`.

y = vartransform(ν, μ, x)

Transforms a value `x` distributed according to `μ` to a value `y` distributed
according to `ν`.

The [pushforward measure](https://en.wikipedia.org/wiki/Pushforward_measure)
from `μ` under `f` is is equivalent to `ν`.

If terms of random values this implies that `f(rand(μ))` is equivalent to
`rand(ν)` (if `rand(μ)` and `rand(ν)` are supported).

The resulting function `f` should support
`ChangesOfVariables.with_logabsdet_jacobian(f, x)` if mathematically well-defined,
so that densities of `ν` can be derived from densities of `μ` via `f` (using
appropriate base measures).

Returns NoTransformOrigin{typeof(ν),typeof(μ)} if no transformation from
`μ` to `ν` can be found.

To add transformation rules for a measure-like type `MyMeasure`, specialize

* `MeasureBase.vartransform(ν::SomeStdMeasure, μ::CustomMeasure, x) = ...`
* `MeasureBase.vartransform(ν::MyMeasure, μ::SomeStdMeasure, x) = ...`

and/or

* `MeasureBase.vartransform_origin(ν::MyMeasure) = SomeMeasure(...)`
* `MeasureBase.from_origin(μ::MyMeasure, y) = x`
* `MeasureBase.to_origin(μ::MyMeasure, x) = y`

and ensure `MeasureBase.effndof(μ::MyMeasure)` is defined correctly.

If no direct transformation rule is available, `vartransform(ν, μ, x)` uses
the following strategy:

* Evaluate [`vartransform_origin`](@ref) for μ and ν. If both have an origin,
select one as an intermediate measure using
[`select_vartransform_intermediate`](@ref). Try to transform from `μ` to
that intermediate measure and then to `ν` origin(s) of `μ` and/or `ν` if
available.

* If all else fails, try to transform from μ to a standard multivariate
uniform measure and then to ν.
"""
function vartransform end


function _vartransform_with_intermediate(ν, m, μ, x)
x_m = vartransform(m, μ, x)
_vartransform_with_intermediate_step2(ν, m, x_m)
end

@inline _vartransform_with_intermediate_step2(ν, m, x_m) = vartransform(ν, m, x_m)
@inline _vartransform_with_intermediate_step2(ν, m, x_m::NoTransformOrigin) = x_m

function _vartransform_with_intermediate(ν, m::NoTransformOrigin, μ, x)
_vartransform_with_intermediate(ν, StdUniform()^effndof(μ), μ, x)
end


# Prevent endless recursion:
_vartransform_with_intermediate(::NU, ::NU, ::MU, x) where {NU,MU} = NoVarTransform{NU,MU}()
_vartransform_with_intermediate(::NU, ::MU, ::MU, x) where {NU,MU} = NoVarTransform{NU,MU}()

function vartransform(ν, μ, x)
require_same_effndof(ν, μ)
m = vartransform_intermediate(vartransform_origin(ν), vartransform_origin(μ))
_vartransform_with_intermediate(ν, m, μ, x)
end

vartransform(::Any, ::Any, x::NoTransformOrigin) = x


"""
struct VarTransformation <: Function

Transforms a variate from one measure-like object to a variate of another.

In general users should not call `VarTransformation` directly, call
[`vartransform`](@ref) instead.
"""
struct VarTransformation{NU,MU} <: Function
ν::NU
μ::MU

function VarTransformation{NU,MU}(ν::NU, μ::MU) where {NU,MU}
require_same_effndof(ν, μ)
return new{NU,MU}(ν, μ)
end

function VarTransformation(ν::NU, μ::MU) where {NU,MU}
require_same_effndof(ν, μ)
return new{NU,MU}(ν, μ)
end
end

vartransform(ν, μ) = VarTransformation(ν, μ)


(f::VarTransformation)(x) = vartransform(f.ν, f.μ, x)

InverseFunctions.inverse(f::VarTransformation) = VarTransformation(f.μ, f.ν)


function ChangesOfVariables.with_logabsdet_jacobian(f::VarTransformation, x)
y = f(x)
logpdf_src = logdensityof(f.μ, x)
logpdf_trg = logdensityof(f.ν, y)
ladj = logpdf_src - logpdf_trg
# If logpdf_src and logpdf_trg are -Inf setting lafj to zero is safe:
fixed_ladj = logpdf_src == logpdf_trg == -Inf ? zero(ladj) : ladj
return y, fixed_ladj
end


Base.:(∘)(::typeof(identity), f::VarTransformation) = f
Base.:(∘)(f::VarTransformation, ::typeof(identity)) = f

function Base.:∘(outer::VarTransformation, inner::VarTransformation)
if !(outer.μ == inner.ν || isequal(outer.μ, inner.ν) || outer.μ ≈ inner.ν)
throw(ArgumentError("Cannot compose VarTransformation if source of outer doesn't equal target of inner."))
end
return VarTransformation(outer.ν, inner.μ)
end


function Base.show(io::IO, f::VarTransformation)
print(io, Base.typename(typeof(f)).name, "(")
show(io, f.ν)
print(io, ", ")
show(io, f.μ)
print(io, ")")
end

Base.show(io::IO, M::MIME"text/plain", f::VarTransformation) = show(io, f)





"""
MeasureBase.select_vartransform_intermediate(a, b)

Selects one of two candidate pullback measures `a, b` to use as an
intermediate in variate transformations.

See [`MeasureBase.vartransform_intermediate`](@ref).
"""
function select_vartransform_intermediate end

select_vartransform_intermediate(nu, ::NoTransformOrigin) = nu
select_vartransform_intermediate(::NoTransformOrigin, mu) = mu
select_vartransform_intermediate(::NoTransformOrigin, mu::NoTransformOrigin) = mu

# Ensure forward and inverse transformation use the same intermediate:
@generated function select_vartransform_intermediate(a, b)
return nameof(a) < nameof(b) ? :a : :b
end