-
Notifications
You must be signed in to change notification settings - Fork 4
Add variate transport #62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 60 commits
25fa2e1
a3e1bbf
2015ccd
e86cf2c
2e8e7d6
5019b14
6ee4e3d
2217a42
bb0257d
15952bb
d5635fe
57c83b3
b2817f0
66f5990
de96842
9101015
45fe220
e7fff14
680e8f8
7d9e11e
c6368bb
9372871
70f1e30
147ba2a
0f26323
6ffa92c
101f947
1193e05
d0509e9
9556e31
acde08b
5a16bef
ebb7ddb
7b56f08
23e41c7
6f5b246
f5ebe6d
cbb6873
1c52ed0
7b954a5
5a12523
520562c
10b12dc
bfda82b
62eabb0
a3a7b00
39bf7b0
da7ecc6
75e1fb3
6250b20
9bfa9f9
d4f0246
0512930
14890bb
b32f34d
f12f1b5
b3dfe13
1088fc1
1c2311c
1fb805c
24affe6
35fae34
d673692
3965732
5b5eb7f
e70f294
e85e54f
8f2da10
343d20b
a8faa66
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,11 +4,14 @@ authors = ["Chad Scherrer <[email protected]> and contributors"] | |
version = "0.10.0" | ||
|
||
[deps] | ||
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | ||
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" | ||
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" | ||
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" | ||
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" | ||
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" | ||
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" | ||
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" | ||
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" | ||
|
@@ -24,11 +27,14 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | |
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" | ||
|
||
[compat] | ||
ChainRulesCore = "1" | ||
ChangesOfVariables = "0.1.3" | ||
Compat = "3.35, 4" | ||
ConstructionBase = "1.3" | ||
DensityInterface = "0.4" | ||
FillArrays = "0.12, 0.13" | ||
IfElse = "0.1" | ||
InverseFunctions = "0.1.7" | ||
IrrationalConstants = "0.1" | ||
LogExpFunctions = "0.3" | ||
LogarithmicNumbers = "1" | ||
|
@@ -42,6 +48,7 @@ julia = "1.3" | |
|
||
[extras] | ||
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" | ||
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" | ||
|
||
[targets] | ||
test = ["Aqua"] | ||
test = ["Aqua", "ChainRulesTestUtils"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,3 +13,99 @@ function params(::AbstractTransformedMeasure) end | |
function paramnames(::AbstractTransformedMeasure) end | ||
|
||
function parent(::AbstractTransformedMeasure) end | ||
|
||
|
||
abstract type TransformVolCorr end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you have a reason to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
struct WithVolCorr <: TransformVolCorr end | ||
struct NoVolCorr <: TransformVolCorr end | ||
|
||
|
||
export PushforwardMeasure | ||
|
||
""" | ||
struct PushforwardMeasure{FF,IF,MU,VC<:TransformVolCorr} <: AbstractPushforward | ||
f :: FF | ||
inv_f :: IF | ||
origin :: MU | ||
volcorr :: VC | ||
end | ||
""" | ||
struct PushforwardMeasure{FF,IF,M,VC<:TransformVolCorr} <: AbstractPushforward | ||
f::FF | ||
inv_f::IF | ||
origin::M | ||
volcorr::VC | ||
end | ||
|
||
gettransform(ν::PushforwardMeasure) = ν.f | ||
parent(ν::PushforwardMeasure) = ν.origin | ||
|
||
|
||
function Pretty.tile(ν::PushforwardMeasure) | ||
Pretty.list_layout(Pretty.tile.([ν.f, ν.inv_f, ν.origin]); prefix = :PushforwardMeasure) | ||
end | ||
|
||
|
||
@inline function logdensity_def(ν::PushforwardMeasure{FF,IF,M,<:WithVolCorr}, y) where {FF,IF,M} | ||
x_orig, inv_ladj = with_logabsdet_jacobian(ν.inv_f, y) | ||
logd_orig = logdensity_def(ν.origin, x_orig) | ||
logd = float(logd_orig + inv_ladj) | ||
neginf = oftype(logd, -Inf) | ||
return ifelse( | ||
# Zero density wins against infinite volume: | ||
(isnan(logd) && logd_orig == -Inf && inv_ladj == +Inf) || | ||
# Maybe also for (logd_orig == -Inf) && isfinite(inv_ladj) ? | ||
# Return constant -Inf to prevent problems with ForwardDiff: | ||
(isfinite(logd_orig) && (inv_ladj == -Inf)), | ||
neginf, | ||
logd | ||
) | ||
end | ||
|
||
@inline function logdensity_def(ν::PushforwardMeasure{FF,IF,M,<:NoVolCorr}, y) where {FF,IF,M} | ||
x_orig = to_origin(ν, y) | ||
return logdensity_def(ν.origin, x_orig) | ||
end | ||
|
||
|
||
insupport(ν::PushforwardMeasure, y) = insupport(vartransform_origin(ν), to_origin(ν, y)) | ||
|
||
testvalue(ν::PushforwardMeasure) = from_origin(ν, testvalue(vartransform_origin(ν))) | ||
|
||
@inline function basemeasure(ν::PushforwardMeasure) | ||
PushforwardMeasure(ν.f, ν.inv_f, basemeasure(vartransform_origin(ν)), NoVolCorr()) | ||
end | ||
|
||
|
||
_pushfwd_dof(::Type{MU}, ::Type, dof) where MU = NoDOF{MU}() | ||
_pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where MU = dof | ||
|
||
# Assume that DOF are preserved if with_logabsdet_jacobian is functional: | ||
@inline function getdof(ν::MU) where {MU<:PushforwardMeasure} | ||
T = Core.Compiler.return_type(testvalue, Tuple{typeof(ν.origin)}) | ||
R = Core.Compiler.return_type(with_logabsdet_jacobian, Tuple{typeof(ν.f), T}) | ||
_pushfwd_dof(MU, R, getdof(ν.origin)) | ||
end | ||
|
||
# Bypass `checked_var`, would require potentially costly transformation: | ||
@inline checked_var(::PushforwardMeasure, x) = x | ||
|
||
|
||
@inline vartransform_origin(ν::PushforwardMeasure) = ν.origin | ||
@inline to_origin(ν::PushforwardMeasure, x) = ν.inv_f(x) | ||
@inline from_origin(ν::PushforwardMeasure, y) = ν.f(y) | ||
|
||
function Base.rand(rng::AbstractRNG, ::Type{T}, ν::PushforwardMeasure) where T | ||
return from_origin(ν, rand(rng, T, vartransform_origin(ν))) | ||
end | ||
|
||
|
||
export pushfwd | ||
|
||
""" | ||
pushfwd(f, μ, volcorr = WithVolCorr()) | ||
|
||
Return the [pushforward measure](https://en.wikipedia.org/wiki/Pushforward_measure) | ||
from `μ` the [measurable function](https://en.wikipedia.org/wiki/Measurable_function) `f`. | ||
""" | ||
pushfwd(f, μ, volcorr = WithVolCorr()) = PushforwardMeasure(f, inverse(f), μ, volcorr) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
""" | ||
MeasureBase.NoDOF{MU} | ||
|
||
Indicates that there is no way to compute degrees of freedom of a measure | ||
of type `MU` with the given information, e.g. because the DOF are not | ||
a global property of the measure. | ||
""" | ||
struct NoDOF{MU} end | ||
|
||
|
||
""" | ||
getdof(μ) | ||
|
||
Returns the effective number of degrees of freedom of variates of | ||
measure `μ`. | ||
|
||
The effective NDOF my differ from the length of the variates. For example, | ||
the effective NDOF for a Dirichlet distribution with variates of length `n` | ||
is `n - 1`. | ||
|
||
Also see [`check_dof`](@ref). | ||
""" | ||
function getdof end | ||
|
||
# Prevent infinite recursion: | ||
@inline _default_getdof(::Type{MU}, ::MU) where MU = NoDOF{MU} | ||
@inline _default_getdof(::Type{MU}, mu_base) where MU = getdof(mu_base) | ||
|
||
@inline getdof(μ::MU) where MU = _default_getdof(MU, basemeasure(μ)) | ||
|
||
|
||
""" | ||
MeasureBase.check_dof(ν, μ)::Nothing | ||
|
||
Check if `ν` and `μ` have the same effective number of degrees of freedom | ||
according to [`MeasureBase.getdof`](@ref). | ||
""" | ||
function check_dof end | ||
|
||
function check_dof(ν, μ) | ||
n_ν = getdof(ν) | ||
n_μ = getdof(μ) | ||
if n_ν != n_μ | ||
throw(ArgumentError("Measure ν of type $(nameof(typeof(ν))) has $(n_ν) DOF but μ of type $(nameof(typeof(μ))) has $(n_μ) DOF")) | ||
end | ||
return nothing | ||
end | ||
|
||
_check_dof_pullback(ΔΩ) = NoTangent(), NoTangent(), NoTangent() | ||
ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_dof_pullback | ||
|
||
|
||
""" | ||
MeasureBase.NoVarCheck{MU,T} | ||
|
||
Indicates that there is no way to check of a values of type `T` are | ||
variate of measures of type `MU`. | ||
""" | ||
struct NoVarCheck{MU,T} end | ||
|
||
|
||
""" | ||
MeasureBase.checked_var(μ::MU, x::T)::T | ||
|
||
Return `x` if `x` is a valid variate of `μ`, throw an `ArgumentError` if not, | ||
return `NoVarCheck{MU,T}()` if not check can be performed. | ||
""" | ||
function checked_var end | ||
|
||
# Prevent infinite recursion: | ||
@propagate_inbounds _default_checked_var(::Type{MU}, ::MU, ::T) where {MU,T} = NoVarCheck{MU,T} | ||
@propagate_inbounds _default_checked_var(::Type{MU}, mu_base, x) where MU = checked_var(mu_base, x) | ||
|
||
@propagate_inbounds checked_var(mu::MU, x) where MU = _default_checked_var(MU, basemeasure(mu), x) | ||
|
||
_checked_var_pullback(ΔΩ) = NoTangent(), NoTangent(), ΔΩ | ||
ChainRulesCore.rrule(::typeof(checked_var), ν, x) = checked_var(ν, x), _checked_var_pullback |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
""" | ||
inssupport(m, x) | ||
insupport(m) | ||
|
||
`insupport(m,x)` computes whether `x` is in the support of `m`. | ||
|
||
`insupport(m)` returns a function, and satisfies | ||
|
||
insupport(m)(x) == insupport(m, x) | ||
""" | ||
function insupport end | ||
|
||
|
||
""" | ||
MeasureBase.require_insupport(μ, x)::Nothing | ||
|
||
Checks if `x` is in the support of distribution/measure `μ`, throws an | ||
`ArgumentError` if not. | ||
""" | ||
function require_insupport end | ||
|
||
_require_insupport_pullback(ΔΩ) = NoTangent(), ZeroTangent() | ||
function ChainRulesCore.rrule(::typeof(require_insupport), μ, x) | ||
return require_insupport(μ, x), _require_insupport_pullback | ||
end | ||
|
||
function require_insupport(μ, x) | ||
if !insupport(μ, x) | ||
throw(ArgumentError("x is not within the support of μ")) | ||
end | ||
return nothing | ||
end |
Uh oh!
There was an error while loading. Please reload this page.