-
Notifications
You must be signed in to change notification settings - Fork 4
Add variate transport #62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 6 commits
Commits
Show all changes
70 commits
Select commit
Hold shift + click to select a range
25fa2e1
Add ChainRulesCore to dependencies
oschulz a3e1bbf
Add InverseFunctions and ChangesOfVariables to deps
oschulz 2015ccd
Add require_insupport
oschulz e86cf2c
Add effndof and require_same_effndof
oschulz 2e8e7d6
Add check_varshape
oschulz 5019b14
Add vartransform
oschulz 6ee4e3d
Remove "measure-like" terminology
oschulz 2217a42
Remove requirement for vartransform to return a Function
oschulz bb0257d
Remove check_varshape
oschulz 15952bb
Rename effndof to getdof
oschulz d5635fe
Separate vartransform and vartransform_def
oschulz 57c83b3
Fix default vartransform_def
oschulz b2817f0
Remove select_vartransform_intermediate
oschulz 66f5990
Fix check_dof
oschulz de96842
Export getdof
oschulz 9101015
Export vartransform
oschulz 45fe220
Implement getdof for measures
oschulz e7fff14
Remove StdNormal
oschulz 680e8f8
FIXUP implement getdof
oschulz 7d9e11e
FIXUP implement getdof
oschulz c6368bb
Add StdLogistic
oschulz 9372871
Implement vartransform_def for StdMeasure
oschulz 70f1e30
FIXUP vartransform_def for StdMeasure
oschulz 147ba2a
Add _vartransform_intermediate
oschulz 0f26323
FIXUP Implement vartransform_def for StdMeasure
oschulz 6ffa92c
Fix insupport for StdLogistic
oschulz 101f947
Fix StdLogistic
oschulz 1193e05
FIXUP StdMeasure vartransform
oschulz d0509e9
Fix check_dof
oschulz 9556e31
Add checked_var
oschulz acde08b
WIP Add vartransform tests
oschulz 5a16bef
FIXUP vartransform tests
oschulz ebb7ddb
Fix rand for StdUniform
oschulz 7b56f08
FIXUP vartransform tests
oschulz 23e41c7
Use checked_var at VarTransformation input stage
oschulz 6f5b246
FIX vartransform tests
oschulz f5ebe6d
Add defaults for check_dof and checked_var
oschulz cbb6873
Add vartransform_origin for WeightedMeasure
oschulz 1c52ed0
Fix deps
oschulz 7b954a5
Fix tests
oschulz 5a12523
WIP Add PushforwardMeasure
oschulz 520562c
WIP improve PushforwardMeasure
oschulz 10b12dc
WIP improve PushforwardMeasure
oschulz bfda82b
WIP improve PushforwardMeasure
oschulz 62eabb0
WIP improve PushforwardMeasure
oschulz a3a7b00
FIX PushforwardMeasure
oschulz 39bf7b0
Allow PushforwardMeasure to bypass checked_var
oschulz da7ecc6
Test PushforwardMeasure
oschulz 75e1fb3
Fix docstring of NoDOF
oschulz 6250b20
Add test_vartransform to Interface
oschulz 9bfa9f9
FIXUP _default_checked_var
oschulz d4f0246
FIXUP vartransform_origin docs and defaults
oschulz 0512930
Run vartransform tests
oschulz 14890bb
Improve vartransform_origin def for WeightedMeasure
oschulz b32f34d
Add vartransform stdmeasure autodim
oschulz f12f1b5
Specialize equality for VarTransformation
oschulz b3dfe13
Don't call check_dof so often
oschulz 1088fc1
Improve checked_var for PowerMeasure
oschulz 1c2311c
Fix check_dof and require_insupport rrules
oschulz 1fb805c
Test getdof
oschulz 24affe6
Document TransformVolCorr
oschulz 35fae34
Fix transform variable naming inconsistencies
oschulz d673692
Specialize gotdof for inferrably empty power measures
oschulz 3965732
Add trafos for Dirac
oschulz 5b5eb7f
Support logdensity calculation on empty power measures
oschulz e70f294
Improve test_vartransform
oschulz e85e54f
Fix and test vartransform for Dirac
oschulz 8f2da10
Rename vartransform to transport_to
oschulz 343d20b
Rename vartransform_origin
oschulz a8faa66
Increase package version to v0.11.0
oschulz File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,11 +4,14 @@ authors = ["Chad Scherrer <[email protected]> and contributors"] | |
version = "0.10.0" | ||
|
||
[deps] | ||
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | ||
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" | ||
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" | ||
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" | ||
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" | ||
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" | ||
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" | ||
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" | ||
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" | ||
|
@@ -24,11 +27,14 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | |
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" | ||
|
||
[compat] | ||
ChainRulesCore = "1" | ||
ChangesOfVariables = "0.1.3" | ||
Compat = "3.35, 4" | ||
ConstructionBase = "1.3" | ||
DensityInterface = "0.4" | ||
FillArrays = "0.12, 0.13" | ||
IfElse = "0.1" | ||
InverseFunctions = "0.1.7" | ||
IrrationalConstants = "0.1" | ||
LogExpFunctions = "0.3" | ||
LogarithmicNumbers = "1" | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
""" | ||
effndof(μ) | ||
|
||
Returns the effective number of degrees of freedom of variates of | ||
measure-like object `μ`. | ||
oschulz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
The effective NDOF my differ from the length of the variates. For example, | ||
the effective NDOF for a Dirichlet distribution with variates of length `n` | ||
is `n - 1`. | ||
|
||
Also see [`require_same_effndof`](@ref). | ||
""" | ||
function effndof end | ||
|
||
|
||
""" | ||
MeasureBase.require_same_effndof(a, b)::Nothing | ||
|
||
Check if `a` and `b` have the same effective number of degrees of freedom | ||
according to [`MeasureBase.effndof`](@ref). | ||
""" | ||
function require_same_effndof end | ||
|
||
ChainRulesCore.rrule(::typeof(require_same_effndof), a, b) = nothing, _nogradient_pullback2 | ||
|
||
function require_same_effndof(a, b) | ||
oschulz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
trg_d_n = effndof(ν) | ||
src_d_n = effndof(μ) | ||
if trg_d_n != src_d_n | ||
throw(ArgumentError("Can't convert to $(typeof(ν).name) with $(trg_d_n) eff. DOF from $(typeof(μ).name) with $(src_d_n) eff. DOF")) | ||
end | ||
return nothing | ||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
""" | ||
inssupport(m, x) | ||
insupport(m) | ||
|
||
`insupport(m,x)` computes whether `x` is in the support of `m`. | ||
|
||
`insupport(m)` returns a function, and satisfies | ||
|
||
insupport(m)(x) == insupport(m, x) | ||
""" | ||
function insupport end | ||
|
||
|
||
""" | ||
MeasureBase.require_insupport(μ, x)::Nothing | ||
|
||
Checks if `x` is in the support of distribution/measure `μ`, throws an | ||
`ArgumentError` if not. | ||
""" | ||
function require_insupport end | ||
|
||
_check_insupport_pullback(ΔΩ) = NoTangent(), ZeroTangent() | ||
function ChainRulesCore.rrule(::typeof(require_insupport), μ, x) | ||
return require_insupport(μ, x), _check_insupport_pullback | ||
end | ||
|
||
function require_insupport(μ, x::AbstractArray{T,N}) where {T,N} | ||
if !insupport(μ, x) | ||
throw(ArgumentError("x is not within the support of μ")) | ||
end | ||
return nothing | ||
end | ||
|
||
|
||
""" | ||
MeasureBase.check_varshape(μ, x)::Nothing | ||
oschulz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Checks if `x` has the correct shape/size for a variate of measure-like object | ||
`μ`, throws an `ArgumentError` if not. | ||
""" | ||
function check_varshape end | ||
|
||
_check_varshape_pullback(ΔΩ) = NoTangent(), ZeroTangent() | ||
ChainRulesCore.rrule(::typeof(check_varshape), μ, x) = check_varshape(μ, x), _check_varshape_pullback |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
""" | ||
struct MeasureBase.NoTransformOrigin{MU} | ||
|
||
Indicates that no (default) pullback measure is available for measures of | ||
type `MU`. | ||
|
||
See [`MeasureBase.vartransform_origin`](@ref). | ||
""" | ||
struct NoTransformOrigin{MU} end | ||
|
||
|
||
""" | ||
MeasureBase.vartransform_origin(μ) | ||
oschulz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Default measure to pullback to resp. pushforward from when transforming | ||
oschulz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
between `μ` and another measure. | ||
""" | ||
function vartransform_origin end | ||
|
||
vartransform_origin(m::M) where M = NoTransformOrigin{M}() | ||
|
||
|
||
""" | ||
MeasureBase.from_origin(μ, y) | ||
|
||
Push `y` from `MeasureBase.vartransform_origin(μ)` forward to `μ`. | ||
""" | ||
function from_origin end | ||
|
||
from_origin(m::M) where M = NoTransformOrigin{M}() | ||
|
||
|
||
""" | ||
MeasureBase.to_origin(μ, x) | ||
|
||
Pull `x` from `μ` back to `MeasureBase.vartransform_origin(μ)`. | ||
""" | ||
function to_origin end | ||
|
||
to_origin(m::M) where M = NoTransformOrigin{M}() | ||
|
||
|
||
""" | ||
struct MeasureBase.NoVarTransform{NU,MU} end | ||
|
||
Indicates that no transformation from a measure of type `MU` to a measure of | ||
type `NU` could be found. | ||
""" | ||
struct NoVarTransform{NU,MU} end | ||
|
||
|
||
""" | ||
f = vartransform(ν, μ)::Function | ||
oschulz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Generates a [measurable function](https://en.wikipedia.org/wiki/Measurable_function) | ||
`f` that transforms values distributed according to measure-like object `μ` to | ||
values distributed according to a measure-like object `ν`. | ||
|
||
y = vartransform(ν, μ, x) | ||
|
||
Transforms a value `x` distributed according to `μ` to a value `y` distributed | ||
according to `ν`. | ||
|
||
The [pushforward measure](https://en.wikipedia.org/wiki/Pushforward_measure) | ||
from `μ` under `f` is is equivalent to `ν`. | ||
|
||
If terms of random values this implies that `f(rand(μ))` is equivalent to | ||
`rand(ν)` (if `rand(μ)` and `rand(ν)` are supported). | ||
|
||
The resulting function `f` should support | ||
`ChangesOfVariables.with_logabsdet_jacobian(f, x)` if mathematically well-defined, | ||
oschulz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
so that densities of `ν` can be derived from densities of `μ` via `f` (using | ||
appropriate base measures). | ||
|
||
Returns NoTransformOrigin{typeof(ν),typeof(μ)} if no transformation from | ||
`μ` to `ν` can be found. | ||
|
||
To add transformation rules for a measure-like type `MyMeasure`, specialize | ||
|
||
* `MeasureBase.vartransform(ν::SomeStdMeasure, μ::CustomMeasure, x) = ...` | ||
* `MeasureBase.vartransform(ν::MyMeasure, μ::SomeStdMeasure, x) = ...` | ||
|
||
and/or | ||
|
||
* `MeasureBase.vartransform_origin(ν::MyMeasure) = SomeMeasure(...)` | ||
* `MeasureBase.from_origin(μ::MyMeasure, y) = x` | ||
* `MeasureBase.to_origin(μ::MyMeasure, x) = y` | ||
|
||
and ensure `MeasureBase.effndof(μ::MyMeasure)` is defined correctly. | ||
|
||
If no direct transformation rule is available, `vartransform(ν, μ, x)` uses | ||
the following strategy: | ||
|
||
* Evaluate [`vartransform_origin`](@ref) for μ and ν. If both have an origin, | ||
select one as an intermediate measure using | ||
[`select_vartransform_intermediate`](@ref). Try to transform from `μ` to | ||
that intermediate measure and then to `ν` origin(s) of `μ` and/or `ν` if | ||
available. | ||
|
||
* If all else fails, try to transform from μ to a standard multivariate | ||
uniform measure and then to ν. | ||
""" | ||
function vartransform end | ||
|
||
|
||
function _vartransform_with_intermediate(ν, m, μ, x) | ||
x_m = vartransform(m, μ, x) | ||
_vartransform_with_intermediate_step2(ν, m, x_m) | ||
end | ||
|
||
@inline _vartransform_with_intermediate_step2(ν, m, x_m) = vartransform(ν, m, x_m) | ||
@inline _vartransform_with_intermediate_step2(ν, m, x_m::NoTransformOrigin) = x_m | ||
|
||
function _vartransform_with_intermediate(ν, m::NoTransformOrigin, μ, x) | ||
_vartransform_with_intermediate(ν, StdUniform()^effndof(μ), μ, x) | ||
end | ||
|
||
|
||
# Prevent endless recursion: | ||
_vartransform_with_intermediate(::NU, ::NU, ::MU, x) where {NU,MU} = NoVarTransform{NU,MU}() | ||
_vartransform_with_intermediate(::NU, ::MU, ::MU, x) where {NU,MU} = NoVarTransform{NU,MU}() | ||
|
||
function vartransform(ν, μ, x) | ||
require_same_effndof(ν, μ) | ||
m = vartransform_intermediate(vartransform_origin(ν), vartransform_origin(μ)) | ||
_vartransform_with_intermediate(ν, m, μ, x) | ||
end | ||
|
||
vartransform(::Any, ::Any, x::NoTransformOrigin) = x | ||
|
||
|
||
""" | ||
struct VarTransformation <: Function | ||
|
||
Transforms a variate from one measure-like object to a variate of another. | ||
|
||
In general users should not call `VarTransformation` directly, call | ||
[`vartransform`](@ref) instead. | ||
""" | ||
struct VarTransformation{NU,MU} <: Function | ||
ν::NU | ||
μ::MU | ||
|
||
function VarTransformation{NU,MU}(ν::NU, μ::MU) where {NU,MU} | ||
require_same_effndof(ν, μ) | ||
return new{NU,MU}(ν, μ) | ||
end | ||
|
||
function VarTransformation(ν::NU, μ::MU) where {NU,MU} | ||
require_same_effndof(ν, μ) | ||
return new{NU,MU}(ν, μ) | ||
end | ||
end | ||
|
||
vartransform(ν, μ) = VarTransformation(ν, μ) | ||
|
||
|
||
(f::VarTransformation)(x) = vartransform(f.ν, f.μ, x) | ||
|
||
InverseFunctions.inverse(f::VarTransformation) = VarTransformation(f.μ, f.ν) | ||
|
||
|
||
function ChangesOfVariables.with_logabsdet_jacobian(f::VarTransformation, x) | ||
y = f(x) | ||
logpdf_src = logdensityof(f.μ, x) | ||
logpdf_trg = logdensityof(f.ν, y) | ||
ladj = logpdf_src - logpdf_trg | ||
# If logpdf_src and logpdf_trg are -Inf setting lafj to zero is safe: | ||
fixed_ladj = logpdf_src == logpdf_trg == -Inf ? zero(ladj) : ladj | ||
return y, fixed_ladj | ||
end | ||
|
||
|
||
Base.:(∘)(::typeof(identity), f::VarTransformation) = f | ||
Base.:(∘)(f::VarTransformation, ::typeof(identity)) = f | ||
|
||
function Base.:∘(outer::VarTransformation, inner::VarTransformation) | ||
if !(outer.μ == inner.ν || isequal(outer.μ, inner.ν) || outer.μ ≈ inner.ν) | ||
throw(ArgumentError("Cannot compose VarTransformation if source of outer doesn't equal target of inner.")) | ||
end | ||
return VarTransformation(outer.ν, inner.μ) | ||
end | ||
|
||
|
||
function Base.show(io::IO, f::VarTransformation) | ||
print(io, Base.typename(typeof(f)).name, "(") | ||
show(io, f.ν) | ||
print(io, ", ") | ||
show(io, f.μ) | ||
print(io, ")") | ||
end | ||
|
||
Base.show(io::IO, M::MIME"text/plain", f::VarTransformation) = show(io, f) | ||
|
||
|
||
|
||
|
||
|
||
""" | ||
MeasureBase.select_vartransform_intermediate(a, b) | ||
|
||
Selects one of two candidate pullback measures `a, b` to use as an | ||
intermediate in variate transformations. | ||
|
||
See [`MeasureBase.vartransform_intermediate`](@ref). | ||
""" | ||
function select_vartransform_intermediate end | ||
oschulz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
select_vartransform_intermediate(nu, ::NoTransformOrigin) = nu | ||
select_vartransform_intermediate(::NoTransformOrigin, mu) = mu | ||
select_vartransform_intermediate(::NoTransformOrigin, mu::NoTransformOrigin) = mu | ||
|
||
# Ensure forward and inverse transformation use the same intermediate: | ||
@generated function select_vartransform_intermediate(a, b) | ||
return nameof(a) < nameof(b) ? :a : :b | ||
end |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.