Skip to content

Commit 87b4ea4

Browse files
committed
change accumulate -> _accumulate!
1 parent e4b20da commit 87b4ea4

File tree

2 files changed

+57
-27
lines changed

2 files changed

+57
-27
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -506,34 +506,35 @@ _no_tuple_tangent(dx) = dx
506506

507507
# Like `foldl` this by definition works in order, so it makes sense to allow stateful `f`.
508508

509+
# Also like `foldl`, the version with a keyword `init` can't easily be given a gradient.
510+
# Move it down to: `_accumulate!(op, B, A::AbstractVector, dims::Nothing, init::Nothing)`
511+
509512
function rrule(
510-
config::RuleConfig{>:HasReverseMode}, ::typeof(accumulate), op::G, x::Union{AbstractArray, Tuple};
511-
init=_INIT, dims=nothing
513+
config::RuleConfig{>:HasReverseMode}, ::typeof(Base._accumulate!), op::G, y, x::AbstractVector, dims::Nothing, init,
512514
) where {G}
513-
isnothing(dims) || dims == 1 && x isa Base.AbstractVecOrTuple || throw(
514-
"accumulate(op, x; dims) is not currently supported by ChainRules, sorry"
515-
# It's not supported by AD either, so no point calling back, and no regression:
516-
# gradient(x -> sum(accumulate(/, x, dims=1)), rand(3,4))
517-
# ERROR: Mutating arrays is not supported
518-
)
519-
list, start = if init === _INIT
515+
516+
list, start = if init === nothing
520517
_drop1(x), first(x)
521518
else
522-
x, init
519+
x, something(init)
523520
end
524521
hobbits = accumulate(list; init = (start, nothing)) do (a, _), b
525522
c, back = rrule_via_ad(config, op, a, b)
526523
end
527-
y = map(first, hobbits)
528-
if init === _INIT
524+
# y = map(first, hobbits)
525+
if init === nothing
529526
# `hobbits` is one short, and first one doesn't invoke `op`
530-
y = _vcat1(first(x), y)
527+
# y = _vcat1(first(x), y)
528+
y[1] = first(x)
529+
map!(first, @view(y[2:end]), hobbits)
530+
else
531+
map!(first, y, hobbits)
531532
end
532533
axe = axes(x)
533534
project = ProjectTo(x)
534535
function decumulate(dy)
535536
dy_plain = _no_tuple_tangent(unthunk(dy))
536-
rev_list = if init === _InitialValue()
537+
rev_list = if init === nothing
537538
# Here we rely on `zip` to stop early. Being explicit with _reverse1(_drop1(...))
538539
# gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{"
539540
_zip2(_reverse1(hobbits), _reverse1(dy_plain))
@@ -546,11 +547,14 @@ function rrule(
546547
end
547548
dop = sum(first, trio)
548549
dx = map(last, _reverse1(trio))
549-
if init == _INIT
550+
if init == nothing
550551
# `hobbits` is one short, and the first one is weird
551552
dx = _vcat1(trio[end][2] + dy_plain[1], dx)
552553
end
553-
return (NoTangent(), dop, project(_reshape1(dx, axe)))
554+
dy = @not_implemented "no gradient for `B` in `accumulate!(f, B, A)`, the rule intends to support `accumulate` only"
555+
d_init_not = @not_implemented "gradient for accumulate does not at present include init, sorry"
556+
d_init = init === nothing ? NoTangent() : Tangent{typeof(init)}(; value = d_init_not)
557+
return (NoTangent(), dop, dy, project(_reshape1(dx, axe)), NoTangent(), d_init)
554558
end
555559
return _reshape1(y, axe), decumulate
556560
end

test/rulesets/Base/mapreduce.jl

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
217217
# `foldl(op, itr; init)` goes to `mapfoldl_impl(identity, op, init, itr)`. The rule is
218218
# now attached there, as this is the simplest way to handle `init` keyword.
219219
@eval using Base: mapfoldl_impl
220-
@eval _INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple()
220+
_INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple()
221221

222222
# Simple
223223
y1, b1 = rrule(CFG, mapfoldl_impl, identity, *, 1, [1, 2, 3])
@@ -337,36 +337,45 @@ end
337337
end # cumprod
338338

339339
@testset "accumulate(f, ::Array)" begin
340+
# `accumulate(f, A; init)` goes to `_accumulate!(op, B, A, dims::Nothing, init::Union{Nothing, Some})`.
341+
# The rule is now attached there, as this is the simplest way to handle `init` keyword.
342+
@eval using Base: _accumulate!
343+
340344
# Simple
341-
y1, b1 = rrule(CFG, accumulate, *, [1, 2, 3, 4]; init=1)
345+
y1, b1 = rrule(CFG, _accumulate!, *, [0, 0, 0, 0], [1, 2, 3, 4], nothing, Some(1))
342346
@test y1 == [1, 2, 6, 24]
343-
@test b1([1, 1, 1, 1]) == (NoTangent(), NoTangent(), [33, 16, 10, 6])
347+
@test b1([1, 1, 1, 1])[3] isa ChainRulesCore.NotImplemented
348+
@test b1([1, 1, 1, 1])[4] == [33, 16, 10, 6]
349+
@test b1([1, 1, 1, 1])[6] isa Tangent{Some{Int64}}
350+
@test b1([1, 1, 1, 1])[6].value isa ChainRulesCore.NotImplemented
344351

345352
y2, b2 = rrule(CFG, accumulate, /, [1 2; 3 4])
346353
@test y2 ≈ accumulate(/, [1 2; 3 4])
347354
@test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6
348355

349356
# Test execution order
350357
c3 = Counter()
351-
y3, b3 = rrule(CFG, accumulate, c3, [5, 7, 11]; init=3)
358+
y3, b3 = rrule(CFG, _accumulate!, c3, [0, 0, 0], [5, 7, 11], nothing, Some(3))
352359
@test c3 == Counter(3)
353360
@test y3 == [8, 30, 123] == accumulate(Counter(), [5, 7, 11]; init=3)
354-
@test b3([1, 1, 1]) == (NoTangent(), NoTangent(), [29169, 602, 23]) # the 23 is clear!
361+
@test b3([1, 1, 1])[4] == [29169, 602, 23] # the 23 is clear!
355362

356363
c4 = Counter()
357-
y4, b4 = rrule(CFG, accumulate, c4, [5, 7, 11])
364+
y4, b4 = rrule(CFG, _accumulate!, c4, [0, 0, 0], [5, 7, 11], nothing, nothing)
358365
@test c4 == Counter(2)
359366
@test y4 == [5, (5+7)*1, ((5+7)*1 + 11)*2] == accumulate(Counter(), [5, 7, 11])
360-
@test b4([1, 1, 1]) == (NoTangent(), NoTangent(), [417, 42*(1 + 12), 22])
367+
@test b4([1, 1, 1])[4] == [417, 42*(1 + 12), 22]
361368

362369
# Test gradient of function
363-
y7, b7 = rrule(CFG, accumulate, Multiplier(3), [5, 7, 11])
370+
y7, b7 = rrule(CFG, _accumulate!, Multiplier(3), [0, 0, 0], [5, 7, 11], nothing, nothing)
364371
@test y7 == accumulate((x,y)->x*y*3, [5, 7, 11])
365-
@test b7([1, 1, 1]) == (NoTangent(), Tangent{Multiplier{Int}}(x = 2345,), [715, 510, 315])
372+
@test b7([1, 1, 1])[2] == Tangent{Multiplier{Int}}(; x = 2345,)
373+
@test b7([1, 1, 1])[4] == [715, 510, 315]
366374

367-
y8, b8 = rrule(CFG, accumulate, Multiplier(13), [5, 7, 11], init=3)
375+
y8, b8 = rrule(CFG, _accumulate!, Multiplier(13), [0, 0, 0], [5, 7, 11], nothing, Some(3))
368376
@test y8 == [195, 17745, 2537535] == accumulate((x,y)->x*y*13, [5, 7, 11], init=3)
369-
@test b8([1, 1, 1]) == (NoTangent(), Tangent{Multiplier{Int}}(x = 588330,), [511095, 365040, 230685])
377+
@test b8([1, 1, 1])[2] == Tangent{Multiplier{Int}}(; x = 588330,)
378+
@test b8([1, 1, 1])[4] == [511095, 365040, 230685]
370379
# To find these numbers:
371380
# ForwardDiff.derivative(z -> sum(accumulate((x,y)->x*y*z, [5,7,11], init=3)), 13)
372381
# ForwardDiff.gradient(z -> sum(accumulate((x,y)->x*y*13, z, init=3)), [5,7,11]) |> string
@@ -385,5 +394,22 @@ end
385394
# Finite differencing
386395
test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand()))
387396
test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false)
397+
398+
test_rrule(_accumulate!, *, randn(5) ⊢ NoTangent(), randn(5), nothing, nothing)
399+
test_rrule(_accumulate!, /, randn(5) ⊢ NoTangent(), randn(5), nothing, Some(1 + rand()))
400+
# if VERSION >= v"1.5"
401+
# test_rrule(accumulate, /, 1 .+ rand(3, 4))
402+
# test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand()))
403+
# end
388404
end
405+
# VERSION >= v"1.5" && @testset "accumulate(f, ::Tuple)" begin
406+
# # Simple
407+
# y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1)
408+
# @test y1 == (1, 2, 6, 24)
409+
# @test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6))
410+
411+
# # Finite differencing
412+
# test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand()))
413+
# test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false)
414+
# end
389415
end

0 commit comments

Comments
 (0)