@@ -7,6 +7,53 @@ using ChainRulesCore: NoTangent, ZeroTangent
7
7
import ChainRulesCore
8
8
9
9
10
+ # = collection utils =========================================================
11
+
12
+ using MeasureBase: _dropfront, _dropback, _rev_cumsum, _exp_cumsum_log
13
+
14
+ function ChainRulesCore. rrule (:: typeof (_pushfront), v:: AbstractVector , x)
15
+ result = _pushfront (v, x)
16
+ function _pushfront_pullback (thunked_ΔΩ)
17
+ ΔΩ = ChainRulesCore. unthunk (thunked_ΔΩ)
18
+ (NoTangent (), ΔΩ[firstindex (ΔΩ)+ 1 : lastindex (ΔΩ)], ΔΩ[firstindex (ΔΩ)])
19
+ end
20
+ return result, _pushfront_pullback
21
+ end
22
+
23
+
24
+ function ChainRulesCore. rrule (:: typeof (_pushback), v:: AbstractVector , x)
25
+ result = _pushback (v, x)
26
+ function _pushback_pullback (thunked_ΔΩ)
27
+ ΔΩ = ChainRulesCore. unthunk (thunked_ΔΩ)
28
+ (NoTangent (), ΔΩ[firstindex (ΔΩ): lastindex (ΔΩ)- 1 ], ΔΩ[lastindex (ΔΩ)])
29
+ end
30
+ return result, _pushback_pullback
31
+ end
32
+
33
+
34
+ function ChainRulesCore. rrule (:: typeof (_rev_cumsum), xs:: AbstractVector )
35
+ result = _rev_cumsum (xs)
36
+ function _rev_cumsum_pullback (ΔΩ)
37
+ ∂xs = ChainRulesCore. @thunk cumsum (ChainRulesCore. unthunk (ΔΩ))
38
+ (NoTangent (), ∂xs)
39
+ end
40
+ return result, _rev_cumsum_pullback
41
+ end
42
+
43
+
44
+ function ChainRulesCore. rrule (:: typeof (_exp_cumsum_log), xs:: AbstractVector )
45
+ result = _exp_cumsum_log (xs)
46
+ function _exp_cumsum_log_pullback (ΔΩ)
47
+ ∂xs = inv .(xs) .* _rev_cumsum (exp .(cumsum (log .(xs))) .* ChainRulesCore. unthunk (ΔΩ))
48
+ (NoTangent (), ∂xs)
49
+ end
50
+ return result, _exp_cumsum_log_pullback
51
+ end
52
+
53
+
54
+ # = measure functions ========================================================
55
+
56
+
10
57
@inline function ChainRulesCore. rrule (:: typeof (_checksupport), cond, result)
11
58
y = _checksupport (cond, result)
12
59
function _checksupport_pullback (ȳ)
0 commit comments