Skip to content

Commit b0a40c5

Browse files
committed
change accumulate -> _accumulate!
1 parent 9b6452a commit b0a40c5

File tree

2 files changed

+61
-47
lines changed

2 files changed

+61
-47
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -424,34 +424,35 @@ _no_tuple_tangent(dx) = dx
424424

425425
# Like `foldl` this by definition works in order, so it makes sense to allow stateful `f`.
426426

427+
# Also like `foldl`, the version with a keyword `init` can't easily be given a gradient.
428+
# Move it down to: `_accumulate!(op, B, A::AbstractVector, dims::Nothing, init::Union{Nothing, Some})`
429+
427430
function rrule(
428-
config::RuleConfig{>:HasReverseMode}, ::typeof(accumulate), op::G, x::Union{AbstractArray, Tuple};
429-
init=_INIT, dims=nothing
431+
config::RuleConfig{>:HasReverseMode}, ::typeof(Base._accumulate!), op::G, y, x::AbstractVector, dims::Nothing, init,
430432
) where {G}
431-
isnothing(dims) || dims == 1 && x isa Base.AbstractVecOrTuple || throw(
432-
"accumulate(op, x; dims) is not currently supported by ChainRules, sorry"
433-
# It's not supported by AD either, so no point calling back, and no regression:
434-
# gradient(x -> sum(accumulate(/, x, dims=1)), rand(3,4))
435-
# ERROR: Mutating arrays is not supported
436-
)
437-
list, start = if init === _INIT
433+
434+
list, start = if init === nothing
438435
_drop1(x), first(x)
439436
else
440-
x, init
437+
x, something(init)
441438
end
442439
hobbits = accumulate(list; init = (start, nothing)) do (a, _), b
443440
c, back = rrule_via_ad(config, op, a, b)
444441
end
445-
y = map(first, hobbits)
446-
if init === _INIT
442+
# y = map(first, hobbits)
443+
if init === nothing
447444
# `hobbits` is one short, and first one doesn't invoke `op`
448-
y = _vcat1(first(x), y)
445+
# y = _vcat1(first(x), y)
446+
y[1] = first(x)
447+
map!(first, @view(y[2:end]), hobbits)
448+
else
449+
map!(first, y, hobbits)
449450
end
450451
axe = axes(x)
451452
project = ProjectTo(x)
452453
function decumulate(dy)
453454
dy_plain = _no_tuple_tangent(unthunk(dy))
454-
rev_list = if init === _INIT
455+
rev_list = if init === nothing
455456
if VERSION >= v"1.6"
456457
# Here we rely on `zip` to stop early. Being explicit with _reverse1(_drop1(...))
457458
# gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{"
@@ -471,11 +472,14 @@ function rrule(
471472
end
472473
dop = sum(first, trio)
473474
dx = map(last, _reverse1(trio))
474-
if init == _INIT
475+
if init == nothing
475476
# `hobbits` is one short, and the first one is weird
476477
dx = _vcat1(trio[end][2] + dy_plain[1], dx)
477478
end
478-
return (NoTangent(), dop, project(_reshape1(dx, axe)))
479+
dy = @not_implemented "no gradient for `B` in `accumulate!(f, B, A)`, the rule intends to support `accumulate` only"
480+
d_init_not = @not_implemented "gradient for accumulate does not at present include init, sorry"
481+
d_init = init === nothing ? NoTangent() : Tangent{typeof(init)}(; value = d_init_not)
482+
return (NoTangent(), dop, dy, project(_reshape1(dx, axe)), NoTangent(), d_init)
479483
end
480484
return _reshape1(y, axe), decumulate
481485
end

test/rulesets/Base/mapreduce.jl

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
183183
# `foldl(op, itr; init)` goes to `mapfoldl_impl(identity, op, init, itr)`. The rule is
184184
# now attached there, as this is the simplest way to handle `init` keyword.
185185
@eval using Base: mapfoldl_impl
186-
@eval _INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple()
186+
_INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple()
187187

188188
# Simple
189189
y1, b1 = rrule(CFG, mapfoldl_impl, identity, *, 1, [1, 2, 3])
@@ -286,57 +286,67 @@ end
286286
end # cumprod
287287

288288
@testset "accumulate(f, ::Array)" begin
289+
# `accumulate(f, A; init)` goes to `_accumulate!(op, B, A, dims::Nothing, init::Union{Nothing, Some})`.
290+
# The rule is now attached there, as this is the simplest way to handle `init` keyword.
291+
@eval using Base: _accumulate!
292+
289293
# Simple
290-
y1, b1 = rrule(CFG, accumulate, *, [1, 2, 3, 4]; init=1)
294+
y1, b1 = rrule(CFG, _accumulate!, *, [0, 0, 0, 0], [1, 2, 3, 4], nothing, Some(1))
291295
@test y1 == [1, 2, 6, 24]
292-
@test b1([1, 1, 1, 1]) == (NoTangent(), NoTangent(), [33, 16, 10, 6])
296+
@test b1([1, 1, 1, 1])[3] isa ChainRulesCore.NotImplemented
297+
@test b1([1, 1, 1, 1])[4] == [33, 16, 10, 6]
298+
@test b1([1, 1, 1, 1])[6] isa Tangent{Some{Int64}}
299+
@test b1([1, 1, 1, 1])[6].value isa ChainRulesCore.NotImplemented
293300

294-
if VERSION >= v"1.5"
295-
y2, b2 = rrule(CFG, accumulate, /, [1 2; 3 4])
296-
@test y2 ≈ accumulate(/, [1 2; 3 4])
297-
@test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6
298-
end
301+
# if VERSION >= v"1.5"
302+
# y2, b2 = rrule(CFG, _accumulate!, /, [1 2; 3 4])
303+
# @test y2 ≈ accumulate(/, [1 2; 3 4])
304+
# @test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6
305+
# end
299306

300307
# Test execution order
301308
c3 = Counter()
302-
y3, b3 = rrule(CFG, accumulate, c3, [5, 7, 11]; init=3)
309+
y3, b3 = rrule(CFG, _accumulate!, c3, [0, 0, 0], [5, 7, 11], nothing, Some(3))
303310
@test c3 == Counter(3)
304311
@test y3 == [8, 30, 123] == accumulate(Counter(), [5, 7, 11]; init=3)
305-
@test b3([1, 1, 1]) == (NoTangent(), NoTangent(), [29169, 602, 23]) # the 23 is clear!
312+
@test b3([1, 1, 1])[4] == [29169, 602, 23] # the 23 is clear!
306313

307314
c4 = Counter()
308-
y4, b4 = rrule(CFG, accumulate, c4, [5, 7, 11])
315+
y4, b4 = rrule(CFG, _accumulate!, c4, [0, 0, 0], [5, 7, 11], nothing, nothing)
309316
@test c4 == Counter(2)
310317
@test y4 == [5, (5+7)*1, ((5+7)*1 + 11)*2] == accumulate(Counter(), [5, 7, 11])
311-
@test b4([1, 1, 1]) == (NoTangent(), NoTangent(), [417, 42*(1 + 12), 22])
318+
@test b4([1, 1, 1])[4] == [417, 42*(1 + 12), 22]
312319

313320
# Test gradient of function
314-
y7, b7 = rrule(CFG, accumulate, Multiplier(3), [5, 7, 11])
321+
y7, b7 = rrule(CFG, _accumulate!, Multiplier(3), [0, 0, 0], [5, 7, 11], nothing, nothing)
315322
@test y7 == accumulate((x,y)->x*y*3, [5, 7, 11])
316-
@test b7([1, 1, 1]) == (NoTangent(), Tangent{Multiplier{Int}}(x = 2345,), [715, 510, 315])
323+
@test b7([1, 1, 1])[2] == Tangent{Multiplier{Int}}(; x = 2345,)
324+
@test b7([1, 1, 1])[4] == [715, 510, 315]
317325

318-
y8, b8 = rrule(CFG, accumulate, Multiplier(13), [5, 7, 11], init=3)
326+
y8, b8 = rrule(CFG, _accumulate!, Multiplier(13), [0, 0, 0], [5, 7, 11], nothing, Some(3))
319327
@test y8 == [195, 17745, 2537535] == accumulate((x,y)->x*y*13, [5, 7, 11], init=3)
320-
@test b8([1, 1, 1]) == (NoTangent(), Tangent{Multiplier{Int}}(x = 588330,), [511095, 365040, 230685])
328+
@test b8([1, 1, 1])[2] == Tangent{Multiplier{Int}}(; x = 588330,)
329+
@test b8([1, 1, 1])[4] == [511095, 365040, 230685]
321330
# To find these numbers:
322331
# ForwardDiff.derivative(z -> sum(accumulate((x,y)->x*y*z, [5,7,11], init=3)), 13)
323332
# ForwardDiff.gradient(z -> sum(accumulate((x,y)->x*y*13, z, init=3)), [5,7,11]) |> string
324333

325334
# Finite differencing
326-
test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand()))
327-
if VERSION >= v"1.5"
328-
test_rrule(accumulate, /, 1 .+ rand(3, 4))
329-
test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand()))
330-
end
331-
end
332-
VERSION >= v"1.5" && @testset "accumulate(f, ::Tuple)" begin
333-
# Simple
334-
y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1)
335-
@test y1 == (1, 2, 6, 24)
336-
@test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6))
337-
338-
# Finite differencing
339-
test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand()))
340-
test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false)
335+
test_rrule(_accumulate!, *, randn(5) ⊢ NoTangent(), randn(5), nothing, nothing)
336+
test_rrule(_accumulate!, /, randn(5) ⊢ NoTangent(), randn(5), nothing, Some(1 + rand()))
337+
# if VERSION >= v"1.5"
338+
# test_rrule(accumulate, /, 1 .+ rand(3, 4))
339+
# test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand()))
340+
# end
341341
end
342+
# VERSION >= v"1.5" && @testset "accumulate(f, ::Tuple)" begin
343+
# # Simple
344+
# y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1)
345+
# @test y1 == (1, 2, 6, 24)
346+
# @test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6))
347+
348+
# # Finite differencing
349+
# test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand()))
350+
# test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false)
351+
# end
342352
end

0 commit comments

Comments
 (0)