@@ -222,25 +222,52 @@ To manually specify an inverse, call
222
222
function pushfwd end
223
223
export pushfwd
224
224
225
- @inline pushfwd (f, μ) = _pushfwd_impl (f, μ, AdaptRootMeasure ())
226
- @inline pushfwd (f, μ, style:: AdaptRootMeasure ) = _pushfwd_impl (f, μ, style)
227
- @inline pushfwd (f, μ, style:: PushfwdRootMeasure ) = _pushfwd_impl (f, μ, style)
225
+ @inline pushfwd (f, μ) = _pushfwd_impl1 (f, μ, AdaptRootMeasure ())
226
+ @inline pushfwd (f, μ, style:: AdaptRootMeasure ) = _pushfwd_impl1 (f, μ, style)
227
+ @inline pushfwd (f, μ, style:: PushfwdRootMeasure ) = _pushfwd_impl1 (f, μ, style)
228
228
229
- _pushfwd_impl (f, μ, style) = PushforwardMeasure (f, inverse (f), μ, style)
229
+ _pushfwd_impl1 (f, μ, style:: PushFwdStyle ) = _pushfwd_impl2 (f, inverse (f), μ, style)
230
+ _pushfwd_impl1 (:: typeof (identity), μ, :: AdaptRootMeasure ) = μ
231
+ _pushfwd_impl1 (:: typeof (identity), μ, :: PushfwdRootMeasure ) = μ
230
232
231
- function _pushfwd_impl (
233
+ _pushfwd_impl2 (f, finv, μ, style:: PushFwdStyle ) = PushforwardMeasure (f, finv, μ, style)
234
+
235
+ function _pushfwd_impl2 (
232
236
f,
237
+ finv,
233
238
μ:: PushforwardMeasure{F,I,M,S} ,
234
239
style:: S ,
235
240
) where {F,I,M,S<: PushFwdStyle }
236
241
orig_μ = μ. origin
237
242
new_f = fcomp (f, μ. f)
238
- new_f_inv = fcomp (μ. finv, inverse (f) )
243
+ new_f_inv = fcomp (μ. finv, finv )
239
244
PushforwardMeasure (new_f, new_f_inv, orig_μ, style)
240
245
end
241
246
242
- _pushfwd_impl (:: typeof (identity), μ, :: AdaptRootMeasure ) = μ
243
- _pushfwd_impl (:: typeof (identity), μ, :: PushfwdRootMeasure ) = μ
247
+ struct _CurriedPushfwd{F,I,S<: PushFwdStyle } <: Function
248
+ f:: F
249
+ finv:: I
250
+ style:: S
251
+
252
+ function _CurriedPushfwd {F,I,S} (f:: F , finv:: I , style:: S ) where {F,I,S<: PushFwdStyle }
253
+ new {F,I,S} (f, finv, style)
254
+ end
255
+
256
+ function _CurriedPushfwd (f, finv, style:: S ) where {S<: PushFwdStyle }
257
+ new {Core.Typeof(f),Core.Typeof(finv),S} (f, finv, style)
258
+ end
259
+ end
260
+
261
+ @inline (cf:: _CurriedPushfwd{F,FI} )(μ) where {F,FI} =
262
+ _pushfwd_impl2 (cf. f, cf. finv, μ, cf. style)
263
+
264
+ @inline pushfwd (f) = _curried_pushfwd_impl (f, AdaptRootMeasure ())
265
+ @inline pushfwd (f, style:: AdaptRootMeasure ) = _curried_pushfwd_impl (f, style)
266
+ @inline pushfwd (f, style:: PushfwdRootMeasure ) = _curried_pushfwd_impl (f, style)
267
+
268
+ _curried_pushfwd_impl (f, style:: PushFwdStyle ) = _CurriedPushfwd (f, inverse (f), style)
269
+ @inline _curried_pushfwd_impl (:: typeof (identity), :: AdaptRootMeasure ) = identity
270
+ @inline _curried_pushfwd_impl (:: typeof (identity), :: PushfwdRootMeasure ) = identity
244
271
245
272
# ##############################################################################
246
273
# pullback
@@ -267,8 +294,16 @@ export pullbck
267
294
@inline pullbck (f, μ, style:: AdaptRootMeasure ) = _pullback_impl (f, μ, style)
268
295
@inline pullbck (f, μ, style:: PushfwdRootMeasure ) = _pullback_impl (f, μ, style)
269
296
270
- function _pullback_impl (f, μ, style = AdaptRootMeasure ())
271
- pushfwd (inverse (f), μ, style)
272
- end
297
+ _pullback_impl (f, μ, style:: PushFwdStyle ) = _pushfwd_impl2 (inverse (f), f, μ, style)
298
+ _pullback_impl (:: typeof (identity), μ, :: AdaptRootMeasure ) = μ
299
+ _pullback_impl (:: typeof (identity), μ, :: PushfwdRootMeasure ) = μ
300
+
301
+ @inline pullbck (f) = _curried_pullbck_impl (f, AdaptRootMeasure ())
302
+ @inline pullbck (f, style:: AdaptRootMeasure ) = _curried_pullbck_impl (f, style)
303
+ @inline pullbck (f, style:: PushfwdRootMeasure ) = _curried_pullbck_impl (f, style)
304
+
305
+ _curried_pullbck_impl (f, style:: PushFwdStyle ) = _CurriedPushfwd (inverse (f), f, style)
306
+ @inline _curried_pullbck_impl (:: typeof (identity), :: AdaptRootMeasure ) = identity
307
+ @inline _curried_pullbck_impl (:: typeof (identity), :: PushfwdRootMeasure ) = identity
273
308
274
309
@deprecate pullback (f, μ, style:: PushFwdStyle = AdaptRootMeasure ()) pullbck (f, μ, style)
0 commit comments