Skip to content

Commit cfa0f3b

Browse files
committed
Improve pushforward implementation
Rename TransformVolCorr and subtypes (with backward compatibility).
1 parent a03d7ea commit cfa0f3b

File tree

4 files changed

+179
-79
lines changed

4 files changed

+179
-79
lines changed

src/MeasureBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ using DensityInterface: FuncDensity, LogFuncDensity
2121
using DensityInterface
2222

2323
using InverseFunctions
24+
using InverseFunctions: FunctionWithInverse
2425
using ChangesOfVariables
2526

2627
import Base.iterate
Lines changed: 160 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,46 @@
1-
# TODO: Compare with ChangesOfVariables.jl
1+
"""
2+
abstract type PushFwdStyle
3+
4+
Provides the behavior of a measure's [`rootmeasure`](@ref) under a
5+
pushforward. Either [`AdaptRootMeasure()`](@ref) or
6+
[`PushfwdRootMeasure()`](@ref)
7+
"""
8+
abstract type PushFwdStyle end
9+
export PushFwdStyle
10+
11+
const TransformVolCorr = PushFwdStyle
12+
13+
"""
14+
AdaptRootMeasure()
215
3-
using InverseFunctions: FunctionWithInverse
16+
Indicates that when applying a pushforward to a measure, it's
17+
[`rootmeasure`](@ref) not not be pushed forward. Instead, the root measure
18+
should be kept just "reshaped" to the new measurable space if necessary.
19+
20+
Density calculations for pushforward measures constructed with
21+
`AdaptRootMeasure()` will take take the volume element of variate
22+
transform (typically via the log-abs-det-Jacobian of the transform) into
23+
account.
24+
"""
25+
struct AdaptRootMeasure <: TransformVolCorr end
26+
export AdaptRootMeasure
27+
28+
const WithVolCorr = AdaptRootMeasure
29+
30+
"""
31+
PushfwdRootMeasure()
32+
33+
Indicates than when applying a pushforward to a measure, it's
34+
[`rootmeasure`](@ref) should be pushed forward with the same function.
35+
36+
Density calculations for pushforward measures constructed with
37+
`PushfwdRootMeasure()` will ignore the volume element of the variate
38+
transform.
39+
"""
40+
struct PushfwdRootMeasure <: TransformVolCorr end
41+
export PushfwdRootMeasure
42+
43+
const NoVolCorr = PushfwdRootMeasure
444

545
abstract type AbstractTransformedMeasure <: AbstractMeasure end
646

@@ -19,23 +59,37 @@ function parent(::AbstractTransformedMeasure) end
1959
export PushforwardMeasure
2060

2161
"""
22-
struct PushforwardMeasure{F,I,M,VC<:TransformVolCorr} <: AbstractPushforward
62+
struct PushforwardMeasure{F,I,M,S<:PushFwdStyle} <: AbstractPushforward
2363
f :: F
2464
finv :: I
2565
origin :: M
26-
volcorr :: VC
66+
style :: S
2767
end
2868
2969
Users should not call `PushforwardMeasure` directly. Instead call or add
3070
methods to `pushfwd`.
3171
"""
32-
struct PushforwardMeasure{F,I,M,VC<:TransformVolCorr} <: AbstractPushforward
72+
struct PushforwardMeasure{F,I,M,S<:PushFwdStyle} <: AbstractPushforward
3373
f::F
3474
finv::I
3575
origin::M
36-
volcorr::VC
76+
style::S
77+
78+
function PushforwardMeasure{F,I,M,S}(f::F, finv::I, origin::M, style::S) where {F,I,M,S<:PushFwdStyle}
79+
new{F,I,M,S}(f, finv, origin, style)
80+
end
81+
82+
function PushforwardMeasure(f, finv, origin::M, style::S) where {M,S<:PushFwdStyle}
83+
new{Core.Typeof(f),Core.Typeof(finv),M,S}(f, finv, origin, style)
84+
end
3785
end
3886

87+
const _NonBijectivePusfwdMeasure{M<:PushforwardMeasure,S<:PushFwdStyle} = Union{
88+
PushforwardMeasure{<:Any,<:NoInverse,M,S},
89+
PushforwardMeasure{<:NoInverse,<:Any,M,S},
90+
PushforwardMeasure{<:NoInverse,<:NoInverse,M,S},
91+
}
92+
3993
gettransform::PushforwardMeasure) = ν.f
4094
parent::PushforwardMeasure) = ν.origin
4195

@@ -45,55 +99,94 @@ end
4599

46100
# TODO: THIS IS ALMOST CERTAINLY WRONG
47101
# @inline function logdensity_rel(
48-
# ν::PushforwardMeasure{FF1,IF1,M1,<:WithVolCorr},
49-
# β::PushforwardMeasure{FF2,IF2,M2,<:WithVolCorr},
102+
# ν::PushforwardMeasure{FF1,IF1,M1,<:AdaptRootMeasure},
103+
# β::PushforwardMeasure{FF2,IF2,M2,<:AdaptRootMeasure},
50104
# y,
51105
# ) where {FF1,IF1,M1,FF2,IF2,M2}
52106
# x = β.inv_f(y)
53107
# f = ν.inv_f ∘ β.f
54108
# inv_f = β.inv_f ∘ ν.f
55-
# logdensity_rel(pushfwd(f, inv_f, ν.origin, WithVolCorr()), β.origin, x)
109+
# logdensity_rel(pushfwd(f, inv_f, ν.origin, AdaptRootMeasure()), β.origin, x)
56110
# end
57111

112+
# TODO: Would profit from custom pullback:
113+
function _combine_logd_with_ladj(logd_orig::Real, ladj::Real)
114+
logd_result = logd_orig + ladj
115+
R = typeof(logd_result)
116+
117+
if isnan(logd_result) && isneginf(logd_orig) && isposinf(ladj)
118+
# Zero μ wins against infinite volume:
119+
R(-Inf)::R
120+
elseif isfinite(logd_orig) && isneginf(ladj)
121+
# Maybe also for isneginf(logd_orig) && isfinite(ladj) ?
122+
# Return constant -Inf to prevent problems with ForwardDiff:
123+
#R(-Inf)
124+
near_neg_inf(R)::R # Avoids AdvancedHMC warnings
125+
else
126+
logd_result::R
127+
end
128+
end
129+
130+
function logdensityof(
131+
@nospecialize::_NonBijectivePusfwdMeasure{M,<:PushfwdRootMeasure}),
132+
@nospecialize(v::Any)
133+
) where {M}
134+
throw(
135+
ArgumentError(
136+
"Can't calculate densities for non-bijective pushforward measure $(nameof(M))",
137+
),
138+
)
139+
end
140+
141+
function logdensityof(
142+
@nospecialize::_NonBijectivePusfwdMeasure{M,<:AdaptRootMeasure}),
143+
@nospecialize(v::Any)
144+
) where {M}
145+
throw(
146+
ArgumentError(
147+
"Can't calculate densities for non-bijective pushforward measure $(nameof(M))",
148+
),
149+
)
150+
end
151+
58152
for func in [:logdensityof, :logdensity_def]
59-
@eval @inline function $func::PushforwardMeasure{F,I,M,<:WithVolCorr}, y) where {F,I,M}
60-
f = ν.f
61-
finv = ν.finv
62-
x_orig, inv_ladj = with_logabsdet_jacobian(unwrap(finv), y)
63-
logd_orig = $func.origin, x_orig)
64-
logd = float(logd_orig + inv_ladj)
65-
neginf = oftype(logd, -Inf)
66-
return ifelse(
67-
# Zero density wins against infinite volume:
68-
(isnan(logd) && logd_orig == -Inf && inv_ladj == +Inf) ||
69-
# Maybe also for (logd_orig == -Inf) && isfinite(inv_ladj) ?
70-
# Return constant -Inf to prevent problems with ForwardDiff:
71-
(isfinite(logd_orig) && (inv_ladj == -Inf)),
72-
neginf,
73-
logd,
74-
)
153+
@eval function $func::PushforwardMeasure{F,I,M,<:AdaptRootMeasure}, y) where {F,I,M}
154+
f_inv = unwrap.finv)
155+
x, inv_ladj = with_logabsdet_jacobian(f_inv, y)
156+
logd_orig = $func.origin, x)
157+
return _combine_logd_with_ladj(logd_orig, inv_ladj)
75158
end
76159

77-
@eval @inline function $func::PushforwardMeasure{F,I,M,<:NoVolCorr}, y) where {F,I,M}
78-
x = ν.finv(y)
79-
return $func.origin, x)
160+
@eval function $func::PushforwardMeasure{F,I,M,<:PushfwdRootMeasure}, y) where {F,I,M}
161+
f_inv = unwrap.finv)
162+
x = f_inv(y)
163+
logd_orig = $func.origin, x)
164+
return logd_orig
80165
end
81166
end
82167

83-
insupport(ν::PushforwardMeasure, y) = insupport(ν.origin, ν.finv(y))
168+
insupport(m::PushforwardMeasure, x) = insupport(transport_origin(m), to_origin(m, x))
84169

85170
function testvalue(::Type{T}, ν::PushforwardMeasure) where {T}
86171
ν.f(testvalue(T, parent(ν)))
87172
end
88173

89174
@inline function basemeasure::PushforwardMeasure)
90-
pushfwd.f, basemeasure(parent(ν)), NoVolCorr())
175+
pushfwd.f, basemeasure(parent(ν)), PushfwdRootMeasure())
176+
end
177+
178+
function rootmeasure(m::PushforwardMeasure{F,I,M,PushfwdRootMeasure}) where {F,I,M}
179+
pushfwd(m.f, rootmeasure(m.origin))
180+
end
181+
function rootmeasure(m::PushforwardMeasure{F,I,M,AdaptRootMeasure}) where {F,I,M}
182+
rootmeasure(m.origin)
91183
end
92184

93185
_pushfwd_dof(::Type{MU}, ::Type, dof) where {MU} = NoDOF{MU}()
94186
_pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where {MU} = dof
95187

96188
@inline getdof::MU) where {MU<:PushforwardMeasure} = getdof.origin)
189+
@inline getdof(m::_NonBijectivePusfwdMeasure) = MeasureBase.NoDOF{typeof(m)}()
97190

98191
# Bypass `checked_arg`, would require potentially costly transformation:
99192
@inline checked_arg(::PushforwardMeasure, x) = x
@@ -102,47 +195,53 @@ _pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where {MU} = dof
102195
@inline from_origin::PushforwardMeasure, x) = ν.f(x)
103196
@inline to_origin::PushforwardMeasure, y) = ν.finv(y)
104197

198+
massof(m::PushforwardMeasure) = massof(transport_origin(m))
199+
105200
function Base.rand(rng::AbstractRNG, ::Type{T}, ν::PushforwardMeasure) where {T}
106-
return ν.f(rand(rng, T, parent(ν)))
201+
return ν.f(rand(rng, T, ν.origin))
107202
end
108203

109204
###############################################################################
110205
# pushfwd
111206

112-
export pushfwd
113-
114207
"""
115-
pushfwd(f, μ, volcorr = WithVolCorr())
208+
pushfwd(f, μ, style = AdaptRootMeasure())
116209
117210
Return the [pushforward
118211
measure](https://en.wikipedia.org/wiki/Pushforward_measure) from `μ` the
119212
[measurable function](https://en.wikipedia.org/wiki/Measurable_function) `f`.
120213
121214
To manually specify an inverse, call
122-
`pushfwd(InverseFunctions.setinverse(f, finv), μ, volcorr)`.
215+
`pushfwd(InverseFunctions.setinverse(f, finv), μ, style)`.
123216
"""
124-
function pushfwd(f, μ, volcorr::TransformVolCorr = WithVolCorr())
125-
PushforwardMeasure(f, inverse(f), μ, volcorr)
126-
end
127-
128-
function pushfwd(f, μ::PushforwardMeasure, volcorr::TransformVolCorr = WithVolCorr())
129-
_pushfwd_of_pushfwd(f, μ, μ.volcorr, volcorr)
130-
end
217+
function pushfwd end
218+
export pushfwd
131219

132-
# Either both WithVolCorr or both NoVolCorr, so we can merge them
133-
function _pushfwd_of_pushfwd(f, μ::PushforwardMeasure, ::V, v::V) where {V}
134-
pushfwd(fchain((μ.f, f)), μ.origin, v)
220+
@inline pushfwd(f, μ) = _pushfwd_impl(f, μ, AdaptRootMeasure())
221+
@inline pushfwd(f, μ, style::AdaptRootMeasure) = _pushfwd_impl(f, μ, style)
222+
@inline pushfwd(f, μ, style::PushfwdRootMeasure) = _pushfwd_impl(f, μ, style)
223+
224+
_pushfwd_impl(f, μ, style) = PushforwardMeasure(f, inverse(f), μ, style)
225+
226+
function _pushfwd_impl(
227+
f,
228+
μ::PushforwardMeasure{F,I,M,S},
229+
style::S,
230+
) where {F,I,M,S<:PushFwdStyle}
231+
orig_μ = μ.origin
232+
new_f = fcomp(f, μ.f)
233+
new_f_inv = fcomp.finv, inverse(f))
234+
PushforwardMeasure(new_f, new_f_inv, orig_μ, style)
135235
end
136236

137-
function _pushfwd_of_pushfwd(f, μ::PushforwardMeasure, _, v)
138-
PushforwardMeasure(f, inverse(f), μ, v)
139-
end
237+
_pushfwd_impl(::typeof(identity), μ, ::AdaptRootMeasure) = μ
238+
_pushfwd_impl(::typeof(identity), μ, ::PushfwdRootMeasure) = μ
140239

141240
###############################################################################
142241
# pullback
143242

144243
"""
145-
pullback(f, μ, volcorr = WithVolCorr())
244+
pullbck(f, μ, style = AdaptRootMeasure())
146245
147246
A _pullback_ is a dual concept to a _pushforward_. While a pushforward needs a
148247
map _from_ the support of a measure, a pullback requires a map _into_ the
@@ -154,8 +253,17 @@ in terms of the inverse function; the "forward" function is not used at all. In
154253
some cases, we may be focusing on log-density (and not, for example, sampling).
155254
156255
To manually specify an inverse, call
157-
`pullback(InverseFunctions.setinverse(f, finv), μ, volcorr)`.
256+
`pullbck(InverseFunctions.setinverse(f, finv), μ, style)`.
158257
"""
159-
function pullback(f, μ, volcorr::TransformVolCorr = WithVolCorr())
160-
pushfwd(setinverse(inverse(f), f), μ, volcorr)
258+
function pullback end
259+
export pullback
260+
261+
@inline pullbck(f, μ) = _pullback_impl(f, μ, AdaptRootMeasure())
262+
@inline pullbck(f, μ, style::AdaptRootMeasure) = _pullback_impl(f, μ, style)
263+
@inline pullbck(f, μ, style::PushfwdRootMeasure) = _pullback_impl(f, μ, style)
264+
265+
function _pullback_impl(f, μ, style = AdaptRootMeasure())
266+
pushfwd(setinverse(inverse(f), f), μ, style)
161267
end
268+
269+
@deprecate pullback(f, μ, style::PushFwdStyle = AdaptRootMeasure()) pullbck(f, μ, style)

src/transport.jl

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -274,30 +274,3 @@ function Base.show(io::IO, f::TransportFunction)
274274
end
275275

276276
Base.show(io::IO, M::MIME"text/plain", f::TransportFunction) = show(io, f)
277-
278-
"""
279-
abstract type TransformVolCorr
280-
281-
Provides control over density correction by transform volume element.
282-
Either [`NoVolCorr()`](@ref) or [`WithVolCorr()`](@ref)
283-
"""
284-
abstract type TransformVolCorr end
285-
286-
"""
287-
NoVolCorr()
288-
289-
Indicate that density calculations should ignore the volume element of
290-
variate transformations. Should only be used in special cases in which
291-
the volume element has already been taken into account in a different
292-
way.
293-
"""
294-
struct NoVolCorr <: TransformVolCorr end
295-
296-
"""
297-
WithVolCorr()
298-
299-
Indicate that density calculations should take the volume element of
300-
variate transformations into account (typically via the
301-
log-abs-det-Jacobian of the transform).
302-
"""
303-
struct WithVolCorr <: TransformVolCorr end

src/utils.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,21 @@ using InverseFunctions: FunctionWithInverse
164164

165165
unwrap(f) = f
166166
unwrap(f::FunctionWithInverse) = f.f
167+
168+
169+
fcomp(f, g) = fchain(g, f)
170+
fcomp(::typeof(identity), g) = g
171+
fcomp(f, ::typeof(identity)) = f
172+
fcomp(::typeof(identity), ::typeof(identity)) = identity
173+
174+
175+
near_neg_inf(::Type{T}) where T<:Real = T(-1E38) # Still fits into Float32
176+
177+
isneginf(x) = isinf(x) && x < 0
178+
isposinf(x) = isinf(x) && x > 0
179+
180+
isapproxzero(x::T) where T<:Real = x zero(T)
181+
isapproxzero(A::AbstractArray) = all(isapproxzero, A)
182+
183+
isapproxone(x::T) where T<:Real = x one(T)
184+
isapproxone(A::AbstractArray) = all(isapproxone, A)

0 commit comments

Comments
 (0)