Skip to content

Commit d65c1eb

Browse files
committed
Add curried pushfwd and pullbck
1 parent 75c1384 commit d65c1eb

File tree

3 files changed

+68
-12
lines changed

3 files changed

+68
-12
lines changed

src/combinators/transformedmeasure.jl

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -222,25 +222,52 @@ To manually specify an inverse, call
222222
function pushfwd end
223223
export pushfwd
224224

225-
@inline pushfwd(f, μ) = _pushfwd_impl(f, μ, AdaptRootMeasure())
226-
@inline pushfwd(f, μ, style::AdaptRootMeasure) = _pushfwd_impl(f, μ, style)
227-
@inline pushfwd(f, μ, style::PushfwdRootMeasure) = _pushfwd_impl(f, μ, style)
225+
@inline pushfwd(f, μ) = _pushfwd_impl1(f, μ, AdaptRootMeasure())
226+
@inline pushfwd(f, μ, style::AdaptRootMeasure) = _pushfwd_impl1(f, μ, style)
227+
@inline pushfwd(f, μ, style::PushfwdRootMeasure) = _pushfwd_impl1(f, μ, style)
228228

229-
_pushfwd_impl(f, μ, style) = PushforwardMeasure(f, inverse(f), μ, style)
229+
_pushfwd_impl1(f, μ, style::PushFwdStyle) = _pushfwd_impl2(f, inverse(f), μ, style)
230+
_pushfwd_impl1(::typeof(identity), μ, ::AdaptRootMeasure) = μ
231+
_pushfwd_impl1(::typeof(identity), μ, ::PushfwdRootMeasure) = μ
230232

231-
function _pushfwd_impl(
233+
_pushfwd_impl2(f, finv, μ, style::PushFwdStyle) = PushforwardMeasure(f, finv, μ, style)
234+
235+
function _pushfwd_impl2(
232236
f,
237+
finv,
233238
μ::PushforwardMeasure{F,I,M,S},
234239
style::S,
235240
) where {F,I,M,S<:PushFwdStyle}
236241
orig_μ = μ.origin
237242
new_f = fcomp(f, μ.f)
238-
new_f_inv = fcomp.finv, inverse(f))
243+
new_f_inv = fcomp.finv, finv)
239244
PushforwardMeasure(new_f, new_f_inv, orig_μ, style)
240245
end
241246

242-
_pushfwd_impl(::typeof(identity), μ, ::AdaptRootMeasure) = μ
243-
_pushfwd_impl(::typeof(identity), μ, ::PushfwdRootMeasure) = μ
247+
struct _CurriedPushfwd{F,I,S<:PushFwdStyle} <: Function
248+
f::F
249+
finv::I
250+
style::S
251+
252+
function _CurriedPushfwd{F,I,S}(f::F, finv::I, style::S) where {F,I,S<:PushFwdStyle}
253+
new{F,I,S}(f, finv, style)
254+
end
255+
256+
function _CurriedPushfwd(f, finv, style::S) where {S<:PushFwdStyle}
257+
new{Core.Typeof(f),Core.Typeof(finv),S}(f, finv, style)
258+
end
259+
end
260+
261+
@inline (cf::_CurriedPushfwd{F,FI})(μ) where {F,FI} =
262+
_pushfwd_impl2(cf.f, cf.finv, μ, cf.style)
263+
264+
@inline pushfwd(f) = _curried_pushfwd_impl(f, AdaptRootMeasure())
265+
@inline pushfwd(f, style::AdaptRootMeasure) = _curried_pushfwd_impl(f, style)
266+
@inline pushfwd(f, style::PushfwdRootMeasure) = _curried_pushfwd_impl(f, style)
267+
268+
_curried_pushfwd_impl(f, style::PushFwdStyle) = _CurriedPushfwd(f, inverse(f), style)
269+
@inline _curried_pushfwd_impl(::typeof(identity), ::AdaptRootMeasure) = identity
270+
@inline _curried_pushfwd_impl(::typeof(identity), ::PushfwdRootMeasure) = identity
244271

245272
###############################################################################
246273
# pullback
@@ -267,8 +294,16 @@ export pullbck
267294
@inline pullbck(f, μ, style::AdaptRootMeasure) = _pullback_impl(f, μ, style)
268295
@inline pullbck(f, μ, style::PushfwdRootMeasure) = _pullback_impl(f, μ, style)
269296

270-
function _pullback_impl(f, μ, style = AdaptRootMeasure())
271-
pushfwd(inverse(f), μ, style)
272-
end
297+
_pullback_impl(f, μ, style::PushFwdStyle) = _pushfwd_impl2(inverse(f), f, μ, style)
298+
_pullback_impl(::typeof(identity), μ, ::AdaptRootMeasure) = μ
299+
_pullback_impl(::typeof(identity), μ, ::PushfwdRootMeasure) = μ
300+
301+
@inline pullbck(f) = _curried_pullbck_impl(f, AdaptRootMeasure())
302+
@inline pullbck(f, style::AdaptRootMeasure) = _curried_pullbck_impl(f, style)
303+
@inline pullbck(f, style::PushfwdRootMeasure) = _curried_pullbck_impl(f, style)
304+
305+
_curried_pullbck_impl(f, style::PushFwdStyle) = _CurriedPushfwd(inverse(f), f, style)
306+
@inline _curried_pullbck_impl(::typeof(identity), ::AdaptRootMeasure) = identity
307+
@inline _curried_pullbck_impl(::typeof(identity), ::PushfwdRootMeasure) = identity
273308

274309
@deprecate pullback(f, μ, style::PushFwdStyle = AdaptRootMeasure()) pullbck(f, μ, style)

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
88
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
99
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
10+
FunctionChains = "8e6b2b91-af83-483e-ba35-d00930e4cf9b"
1011
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
1112
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
1213
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

test/combinators/transformedmeasure.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import InverseFunctions: inverse, FunctionWithInverse, setinverse
1212
using IrrationalConstants: invsqrt2, sqrt2
1313
import ChangesOfVariables: with_logabsdet_jacobian
1414
using MeasureBase.Interface: transport_to, test_transport
15+
using FunctionChains: fchain
1516

1617
Φ(z) = erfc(-z * invsqrt2) / 2
1718
Φinv(p) = -erfcinv(2 * p) * sqrt2
@@ -106,7 +107,7 @@ using ChangesOfVariables
106107
# Test basic pushforward construction
107108
μ = StdNormal()
108109
f = exp
109-
ν = pushfwd(f, μ)
110+
ν = @inferred pushfwd(f, μ)
110111

111112
@test ν isa PushforwardMeasure
112113

@@ -167,10 +168,29 @@ using ChangesOfVariables
167168
# Test pullback
168169
pb = pullbck(f, ν)
169170
@test pb isa PushforwardMeasure
171+
@test pb.origin === μ
172+
@test pb.f === fchain(exp, log)
170173
@test logdensityof(pb, y) logdensityof(μ, y)
171174

172175
# Test deprecated pullback
173176
@test_deprecated pullback(f, μ)
177+
178+
# Test identity specializations
179+
for stylearg in [(), (AdaptRootMeasure(),), (PushfwdRootMeasure(),)]
180+
@test @inferred(pushfwd(identity, μ, stylearg...)) === μ
181+
@test @inferred(pullbck(identity, ν, stylearg...)) === ν
182+
end
183+
184+
# Test curried pushfwd and pullback
185+
for stylearg in [(), (AdaptRootMeasure(),), (PushfwdRootMeasure(),)]
186+
@test @inferred(pushfwd(f, stylearg...)(μ)) === pushfwd(f, μ, stylearg...)
187+
@test @inferred(pullbck(f, stylearg...)(ν)) === pullbck(f, ν, stylearg...)
188+
189+
@test @inferred(pushfwd(identity, stylearg...)) === identity
190+
@test @inferred(pullbck(identity, stylearg...)) === identity
191+
end
192+
193+
@test pushfwd(identity) === identity
174194
end
175195

176196
@testset "PushFwdStyle types" begin

0 commit comments

Comments
 (0)