Skip to content

Commit 0fb2009

Browse files
committed
Add collection utils
1 parent 0f89a57 commit 0fb2009

File tree

3 files changed

+72
-0
lines changed

3 files changed

+72
-0
lines changed

ext/MeasureBaseChainRulesCoreExt.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,53 @@ using ChainRulesCore: NoTangent, ZeroTangent
77
import ChainRulesCore
88

99

10+
# = collection utils =========================================================
11+
12+
using MeasureBase: _dropfront, _dropback, _rev_cumsum, _exp_cumsum_log
13+
14+
function ChainRulesCore.rrule(::typeof(_pushfront), v::AbstractVector, x)
15+
result = _pushfront(v, x)
16+
function _pushfront_pullback(thunked_ΔΩ)
17+
ΔΩ = ChainRulesCore.unthunk(thunked_ΔΩ)
18+
(NoTangent(), ΔΩ[firstindex(ΔΩ)+1:lastindex(ΔΩ)], ΔΩ[firstindex(ΔΩ)])
19+
end
20+
return result, _pushfront_pullback
21+
end
22+
23+
24+
function ChainRulesCore.rrule(::typeof(_pushback), v::AbstractVector, x)
25+
result = _pushback(v, x)
26+
function _pushback_pullback(thunked_ΔΩ)
27+
ΔΩ = ChainRulesCore.unthunk(thunked_ΔΩ)
28+
(NoTangent(), ΔΩ[firstindex(ΔΩ):lastindex(ΔΩ)-1], ΔΩ[lastindex(ΔΩ)])
29+
end
30+
return result, _pushback_pullback
31+
end
32+
33+
34+
function ChainRulesCore.rrule(::typeof(_rev_cumsum), xs::AbstractVector)
35+
result = _rev_cumsum(xs)
36+
function _rev_cumsum_pullback(ΔΩ)
37+
∂xs = ChainRulesCore.@thunk cumsum(ChainRulesCore.unthunk(ΔΩ))
38+
(NoTangent(), ∂xs)
39+
end
40+
return result, _rev_cumsum_pullback
41+
end
42+
43+
44+
function ChainRulesCore.rrule(::typeof(_exp_cumsum_log), xs::AbstractVector)
45+
result = _exp_cumsum_log(xs)
46+
function _exp_cumsum_log_pullback(ΔΩ)
47+
∂xs = inv.(xs) .* _rev_cumsum(exp.(cumsum(log.(xs))) .* ChainRulesCore.unthunk(ΔΩ))
48+
(NoTangent(), ∂xs)
49+
end
50+
return result, _exp_cumsum_log_pullback
51+
end
52+
53+
54+
# = measure functions ========================================================
55+
56+
1057
@inline function ChainRulesCore.rrule(::typeof(_checksupport), cond, result)
1158
y = _checksupport(cond, result)
1259
function _checksupport_pullback(ȳ)

src/MeasureBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ using Compat
144144
using IrrationalConstants
145145

146146
include("static.jl")
147+
include("collection_utils.jl")
147148
include("smf.jl")
148149
include("getdof.jl")
149150
include("transport.jl")

src/collection_utils.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
function _pushfront(v::AbstractVector, x)
2+
T = promote_type(eltype(v), typeof(x))
3+
r = similar(v, T, length(eachindex(v)) + 1)
4+
r[firstindex(r)] = x
5+
r[firstindex(r)+1:lastindex(r)] = v
6+
r
7+
end
8+
9+
function _pushback(v::AbstractVector, x)
10+
T = promote_type(eltype(v), typeof(x))
11+
r = similar(v, T, length(eachindex(v)) + 1)
12+
r[lastindex(r)] = x
13+
r[firstindex(r):lastindex(r)-1] = v
14+
r
15+
end
16+
17+
_dropfront(v::AbstractVector) = v[firstindex(v)+1:lastindex(v)]
18+
19+
_dropback(v::AbstractVector) = v[firstindex(v):lastindex(v)-1]
20+
21+
_rev_cumsum(xs::AbstractVector) = reverse(cumsum(reverse(xs)))
22+
23+
# Equivalent to `cumprod(xs)``:
24+
_exp_cumsum_log(xs::AbstractVector) = exp.(cumsum(log.(xs)))

0 commit comments

Comments
 (0)