1- # TODO : Compare with ChangesOfVariables.jl
1+ """
2+ abstract type PushFwdStyle
3+
4+ Provides the behavior of a measure's [`rootmeasure`](@ref) under a
5+ pushforward. Either [`AdaptRootMeasure()`](@ref) or
6+ [`PushfwdRootMeasure()`](@ref)
7+ """
8+ abstract type PushFwdStyle end
9+ export PushFwdStyle
10+
11+ const TransformVolCorr = PushFwdStyle
12+
13+ """
14+ AdaptRootMeasure()
215
3- using InverseFunctions: FunctionWithInverse
16+ Indicates that when applying a pushforward to a measure, it's
17+ [`rootmeasure`](@ref) not not be pushed forward. Instead, the root measure
18+ should be kept just "reshaped" to the new measurable space if necessary.
19+
20+ Density calculations for pushforward measures constructed with
21+ `AdaptRootMeasure()` will take take the volume element of variate
22+ transform (typically via the log-abs-det-Jacobian of the transform) into
23+ account.
24+ """
25+ struct AdaptRootMeasure <: TransformVolCorr end
26+ export AdaptRootMeasure
27+
28+ const WithVolCorr = AdaptRootMeasure
29+
30+ """
31+ PushfwdRootMeasure()
32+
33+ Indicates than when applying a pushforward to a measure, it's
34+ [`rootmeasure`](@ref) should be pushed forward with the same function.
35+
36+ Density calculations for pushforward measures constructed with
37+ `PushfwdRootMeasure()` will ignore the volume element of the variate
38+ transform.
39+ """
40+ struct PushfwdRootMeasure <: TransformVolCorr end
41+ export PushfwdRootMeasure
42+
43+ const NoVolCorr = PushfwdRootMeasure
444
545abstract type AbstractTransformedMeasure <: AbstractMeasure end
646
@@ -19,23 +59,37 @@ function parent(::AbstractTransformedMeasure) end
1959export PushforwardMeasure
2060
2161"""
22- struct PushforwardMeasure{F,I,M,VC<:TransformVolCorr } <: AbstractPushforward
62+ struct PushforwardMeasure{F,I,M,S<:PushFwdStyle } <: AbstractPushforward
2363 f :: F
2464 finv :: I
2565 origin :: M
26- volcorr :: VC
66+ style :: S
2767 end
2868
2969 Users should not call `PushforwardMeasure` directly. Instead call or add
3070 methods to `pushfwd`.
3171"""
32- struct PushforwardMeasure{F,I,M,VC <: TransformVolCorr } <: AbstractPushforward
72+ struct PushforwardMeasure{F,I,M,S <: PushFwdStyle } <: AbstractPushforward
3373 f:: F
3474 finv:: I
3575 origin:: M
36- volcorr:: VC
76+ style:: S
77+
78+ function PushforwardMeasure {F,I,M,S} (f:: F , finv:: I , origin:: M , style:: S ) where {F,I,M,S<: PushFwdStyle }
79+ new {F,I,M,S} (f, finv, origin, style)
80+ end
81+
82+ function PushforwardMeasure (f, finv, origin:: M , style:: S ) where {M,S<: PushFwdStyle }
83+ new {Core.Typeof(f),Core.Typeof(finv),M,S} (f, finv, origin, style)
84+ end
3785end
3886
87+ const _NonBijectivePusfwdMeasure{M<: PushforwardMeasure ,S<: PushFwdStyle } = Union{
88+ PushforwardMeasure{<: Any ,<: NoInverse ,M,S},
89+ PushforwardMeasure{<: NoInverse ,<: Any ,M,S},
90+ PushforwardMeasure{<: NoInverse ,<: NoInverse ,M,S},
91+ }
92+
3993gettransform (ν:: PushforwardMeasure ) = ν. f
4094parent (ν:: PushforwardMeasure ) = ν. origin
4195
4599
46100# TODO : THIS IS ALMOST CERTAINLY WRONG
47101# @inline function logdensity_rel(
48- # ν::PushforwardMeasure{FF1,IF1,M1,<:WithVolCorr },
49- # β::PushforwardMeasure{FF2,IF2,M2,<:WithVolCorr },
102+ # ν::PushforwardMeasure{FF1,IF1,M1,<:AdaptRootMeasure },
103+ # β::PushforwardMeasure{FF2,IF2,M2,<:AdaptRootMeasure },
50104# y,
51105# ) where {FF1,IF1,M1,FF2,IF2,M2}
52106# x = β.inv_f(y)
53107# f = ν.inv_f ∘ β.f
54108# inv_f = β.inv_f ∘ ν.f
55- # logdensity_rel(pushfwd(f, inv_f, ν.origin, WithVolCorr ()), β.origin, x)
109+ # logdensity_rel(pushfwd(f, inv_f, ν.origin, AdaptRootMeasure ()), β.origin, x)
56110# end
57111
112+ # TODO : Would profit from custom pullback:
113+ function _combine_logd_with_ladj (logd_orig:: Real , ladj:: Real )
114+ logd_result = logd_orig + ladj
115+ R = typeof (logd_result)
116+
117+ if isnan (logd_result) && isneginf (logd_orig) && isposinf (ladj)
118+ # Zero μ wins against infinite volume:
119+ R (- Inf ):: R
120+ elseif isfinite (logd_orig) && isneginf (ladj)
121+ # Maybe also for isneginf(logd_orig) && isfinite(ladj) ?
122+ # Return constant -Inf to prevent problems with ForwardDiff:
123+ # R(-Inf)
124+ near_neg_inf (R):: R # Avoids AdvancedHMC warnings
125+ else
126+ logd_result:: R
127+ end
128+ end
129+
130+ function logdensityof (
131+ @nospecialize (μ:: _NonBijectivePusfwdMeasure{M,<:PushfwdRootMeasure} ),
132+ @nospecialize (v:: Any )
133+ ) where {M}
134+ throw (
135+ ArgumentError (
136+ " Can't calculate densities for non-bijective pushforward measure $(nameof (M)) " ,
137+ ),
138+ )
139+ end
140+
141+ function logdensityof (
142+ @nospecialize (μ:: _NonBijectivePusfwdMeasure{M,<:AdaptRootMeasure} ),
143+ @nospecialize (v:: Any )
144+ ) where {M}
145+ throw (
146+ ArgumentError (
147+ " Can't calculate densities for non-bijective pushforward measure $(nameof (M)) " ,
148+ ),
149+ )
150+ end
151+
58152for func in [:logdensityof , :logdensity_def ]
59- @eval @inline function $func (ν:: PushforwardMeasure{F,I,M,<:WithVolCorr} , y) where {F,I,M}
60- f = ν. f
61- finv = ν. finv
62- x_orig, inv_ladj = with_logabsdet_jacobian (unwrap (finv), y)
63- logd_orig = $ func (ν. origin, x_orig)
64- logd = float (logd_orig + inv_ladj)
65- neginf = oftype (logd, - Inf )
66- return ifelse (
67- # Zero density wins against infinite volume:
68- (isnan (logd) && logd_orig == - Inf && inv_ladj == + Inf ) ||
69- # Maybe also for (logd_orig == -Inf) && isfinite(inv_ladj) ?
70- # Return constant -Inf to prevent problems with ForwardDiff:
71- (isfinite (logd_orig) && (inv_ladj == - Inf )),
72- neginf,
73- logd,
74- )
153+ @eval function $func (ν:: PushforwardMeasure{F,I,M,<:AdaptRootMeasure} , y) where {F,I,M}
154+ f_inv = unwrap (ν. finv)
155+ x, inv_ladj = with_logabsdet_jacobian (f_inv, y)
156+ logd_orig = $ func (ν. origin, x)
157+ return _combine_logd_with_ladj (logd_orig, inv_ladj)
75158 end
76159
77- @eval @inline function $func (ν:: PushforwardMeasure{F,I,M,<:NoVolCorr} , y) where {F,I,M}
78- x = ν. finv (y)
79- return $ func (ν. origin, x)
160+ @eval function $func (ν:: PushforwardMeasure{F,I,M,<:PushfwdRootMeasure} , y) where {F,I,M}
161+ f_inv = unwrap (ν. finv)
162+ x = f_inv (y)
163+ logd_orig = $ func (ν. origin, x)
164+ return logd_orig
80165 end
81166end
82167
83- insupport (ν :: PushforwardMeasure , y ) = insupport (ν . origin, ν . finv (y ))
168+ insupport (m :: PushforwardMeasure , x ) = insupport (transport_origin (m), to_origin (m, x ))
84169
85170function testvalue (:: Type{T} , ν:: PushforwardMeasure ) where {T}
86171 ν. f (testvalue (T, parent (ν)))
87172end
88173
89174@inline function basemeasure (ν:: PushforwardMeasure )
90- pushfwd (ν. f, basemeasure (parent (ν)), NoVolCorr ())
175+ pushfwd (ν. f, basemeasure (parent (ν)), PushfwdRootMeasure ())
176+ end
177+
178+ function rootmeasure (m:: PushforwardMeasure{F,I,M,PushfwdRootMeasure} ) where {F,I,M}
179+ pushfwd (m. f, rootmeasure (m. origin))
180+ end
181+ function rootmeasure (m:: PushforwardMeasure{F,I,M,AdaptRootMeasure} ) where {F,I,M}
182+ rootmeasure (m. origin)
91183end
92184
93185_pushfwd_dof (:: Type{MU} , :: Type , dof) where {MU} = NoDOF {MU} ()
94186_pushfwd_dof (:: Type{MU} , :: Type{<:Tuple{Any,Real}} , dof) where {MU} = dof
95187
96188@inline getdof (ν:: MU ) where {MU<: PushforwardMeasure } = getdof (ν. origin)
189+ @inline getdof (m:: _NonBijectivePusfwdMeasure ) = MeasureBase. NoDOF {typeof(m)} ()
97190
98191# Bypass `checked_arg`, would require potentially costly transformation:
99192@inline checked_arg (:: PushforwardMeasure , x) = x
@@ -102,47 +195,53 @@ _pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where {MU} = dof
102195@inline from_origin (ν:: PushforwardMeasure , x) = ν. f (x)
103196@inline to_origin (ν:: PushforwardMeasure , y) = ν. finv (y)
104197
198+ massof (m:: PushforwardMeasure ) = massof (transport_origin (m))
199+
105200function Base. rand (rng:: AbstractRNG , :: Type{T} , ν:: PushforwardMeasure ) where {T}
106- return ν. f (rand (rng, T, parent (ν) ))
201+ return ν. f (rand (rng, T, ν . origin ))
107202end
108203
109204# ##############################################################################
110205# pushfwd
111206
112- export pushfwd
113-
114207"""
115- pushfwd(f, μ, volcorr = WithVolCorr ())
208+ pushfwd(f, μ, style = AdaptRootMeasure ())
116209
117210Return the [pushforward
118211measure](https://en.wikipedia.org/wiki/Pushforward_measure) from `μ` the
119212[measurable function](https://en.wikipedia.org/wiki/Measurable_function) `f`.
120213
121214To manually specify an inverse, call
122- `pushfwd(InverseFunctions.setinverse(f, finv), μ, volcorr )`.
215+ `pushfwd(InverseFunctions.setinverse(f, finv), μ, style )`.
123216"""
124- function pushfwd (f, μ, volcorr:: TransformVolCorr = WithVolCorr ())
125- PushforwardMeasure (f, inverse (f), μ, volcorr)
126- end
127-
128- function pushfwd (f, μ:: PushforwardMeasure , volcorr:: TransformVolCorr = WithVolCorr ())
129- _pushfwd_of_pushfwd (f, μ, μ. volcorr, volcorr)
130- end
217+ function pushfwd end
218+ export pushfwd
131219
132- # Either both WithVolCorr or both NoVolCorr, so we can merge them
133- function _pushfwd_of_pushfwd (f, μ:: PushforwardMeasure , :: V , v:: V ) where {V}
134- pushfwd (fchain ((μ. f, f)), μ. origin, v)
220+ @inline pushfwd (f, μ) = _pushfwd_impl (f, μ, AdaptRootMeasure ())
221+ @inline pushfwd (f, μ, style:: AdaptRootMeasure ) = _pushfwd_impl (f, μ, style)
222+ @inline pushfwd (f, μ, style:: PushfwdRootMeasure ) = _pushfwd_impl (f, μ, style)
223+
224+ _pushfwd_impl (f, μ, style) = PushforwardMeasure (f, inverse (f), μ, style)
225+
226+ function _pushfwd_impl (
227+ f,
228+ μ:: PushforwardMeasure{F,I,M,S} ,
229+ style:: S ,
230+ ) where {F,I,M,S<: PushFwdStyle }
231+ orig_μ = μ. origin
232+ new_f = fcomp (f, μ. f)
233+ new_f_inv = fcomp (μ. finv, inverse (f))
234+ PushforwardMeasure (new_f, new_f_inv, orig_μ, style)
135235end
136236
137- function _pushfwd_of_pushfwd (f, μ:: PushforwardMeasure , _, v)
138- PushforwardMeasure (f, inverse (f), μ, v)
139- end
237+ _pushfwd_impl (:: typeof (identity), μ, :: AdaptRootMeasure ) = μ
238+ _pushfwd_impl (:: typeof (identity), μ, :: PushfwdRootMeasure ) = μ
140239
141240# ##############################################################################
142241# pullback
143242
144243"""
145- pullback (f, μ, volcorr = WithVolCorr ())
244+ pullbck (f, μ, style = AdaptRootMeasure ())
146245
147246A _pullback_ is a dual concept to a _pushforward_. While a pushforward needs a
148247map _from_ the support of a measure, a pullback requires a map _into_ the
@@ -154,8 +253,17 @@ in terms of the inverse function; the "forward" function is not used at all. In
154253some cases, we may be focusing on log-density (and not, for example, sampling).
155254
156255To manually specify an inverse, call
157- `pullback (InverseFunctions.setinverse(f, finv), μ, volcorr )`.
256+ `pullbck (InverseFunctions.setinverse(f, finv), μ, style )`.
158257"""
159- function pullback (f, μ, volcorr:: TransformVolCorr = WithVolCorr ())
160- pushfwd (setinverse (inverse (f), f), μ, volcorr)
258+ function pullback end
259+ export pullback
260+
261+ @inline pullbck (f, μ) = _pullback_impl (f, μ, AdaptRootMeasure ())
262+ @inline pullbck (f, μ, style:: AdaptRootMeasure ) = _pullback_impl (f, μ, style)
263+ @inline pullbck (f, μ, style:: PushfwdRootMeasure ) = _pullback_impl (f, μ, style)
264+
265+ function _pullback_impl (f, μ, style = AdaptRootMeasure ())
266+ pushfwd (setinverse (inverse (f), f), μ, style)
161267end
268+
269+ @deprecate pullback (f, μ, style:: PushFwdStyle = AdaptRootMeasure ()) pullbck (f, μ, style)
0 commit comments