Skip to content

Commit fff84b5

Browse files
committed
separate rule for foldl(::Tuple)
1 parent 87b4ea4 commit fff84b5

File tree

2 files changed

+89
-12
lines changed

2 files changed

+89
-12
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 63 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -417,17 +417,73 @@ end
417417
end
418418

419419
#####
420-
##### `foldl`
420+
#####
421+
##### `foldl(f, ::Tuple)`
421422
#####
422423

423424
# `foldl` guarantees to execute `f` in order, left to right. So it makes sense even when
424-
# this `f` is stateful, in which case the gradient must be calculated in the reverse order.
425+
# this `f` is stateful, in which case the gradient must be calculated in the reverse order.
426+
427+
# The rule is attached to `Base.mapfoldl_impl` because this gets the `init` keyword as an argument,
428+
# which is handled below. For tuples, `reduce` also comes here.
429+
430+
function rrule(
431+
config::RuleConfig{>:HasReverseMode},
432+
::typeof(Base.mapfoldl_impl),
433+
::typeof(identity),
434+
op::G,
435+
init::Base._InitialValue,
436+
x::Tuple;
437+
) where {G}
438+
hobbits = accumulate(Base.tail(x); init=(first(x), nothing)) do (a, _), b
439+
# Here `a` is what we would normally cary forward, and `_` ignores
440+
# the previous iteration's pullback function (needed later),
441+
# while `b` is the fresh input from `list` as usual.
442+
c, back = rrule_via_ad(config, op, a, b)
443+
# We don't really need to store every `c`, last one is `foldl` output.
444+
# (The name, BTW, is because "there and back again" is the subtitle of Tolkien's book.)
445+
end
446+
y = first(last(hobbits))
447+
project = ProjectTo(x)
448+
function foldl_pullback_tuple(dy)
449+
trio = accumulate(_reverse1(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
450+
ds, da, db = back(dc)
451+
# Don't need to store every `da`, need one for the next iteration + the last.
452+
end
453+
dop = sum(first, trio)
454+
dx = (trio[end][2], reverse(map(last, trio))...)
455+
return (NoTangent(), NoTangent(), ProjectTo(op)(dop), NoTangent(), project(dx))
456+
end
457+
return y, foldl_pullback_tuple
458+
end
459+
460+
function rrule(
461+
config::RuleConfig{>:HasReverseMode},
462+
::typeof(Base.mapfoldl_impl),
463+
::typeof(identity),
464+
op::G,
465+
init,
466+
x::Tuple;
467+
) where {G}
468+
# Treat `init` by simply appending it to the `x`:
469+
y, back = rrule(config, Base.mapfoldl_impl, identity, op, Base._InitialValue(), (init, x...))
470+
project_x = ProjectTo(x)
471+
project_in = ProjectTo(init)
472+
function foldl_pullback_tuple_init(dy)
473+
_, _, dop, _, dxplus = back(dy)
474+
return (NoTangent(), NoTangent(), dop, project_in(first(dxplus)), project_x(Base.tail(dxplus)))
475+
end
476+
return y, foldl_pullback_tuple_init
477+
end
425478

426-
# The implementation aims to be efficient for both tuples and arrays, although using accumulate
427-
# to carry intermediate results along creates arrays of tuples which could be avoided; using a
428-
# loop can be a few times faster. Note also that it does not return a gradient for `init`.
479+
#####
480+
##### `foldl(f, ::Array)`
481+
#####
429482

430-
# Maybe that's a problem. Let's move the rule to `mapfoldr_impl(f, op, init, itr)`, where it's easier?
483+
# The implementation was originally for both tuples and arrays, although using accumulate
484+
# to carry intermediate results along creates arrays of tuples which could be avoided.
485+
# Using a loop can be a few times faster, this should be replaced.
486+
# Note also that it does not return a gradient for `init`.
431487

432488
function rrule(
433489
config::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::G, init, x::Union{AbstractArray, Tuple};
@@ -486,8 +542,7 @@ _reverse1(x::Tuple) = reverse(x)
486542
_drop1(x::Tuple) = Base.tail(x)
487543
_zip2(x::Tuple{Vararg{Any,N}}, y::Tuple{Vararg{Any,N}}) where N = ntuple(i -> (x[i],y[i]), N)
488544

489-
# struct _InitialValue end # Old versions don't have `Base._InitialValue`
490-
const _INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple()
545+
const _INIT = Base._InitialValue()
491546

492547
_vcat1(x, ys::AbstractVector) = vcat(x, ys)
493548
_vcat1(x::AbstractArray, ys::AbstractVector) = vcat([x], ys)

test/rulesets/Base/mapreduce.jl

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end
44

55
const CFG = ChainRulesTestUtils.ADviaRuleConfig()
66

7+
using Base: mapfoldl_impl, _accumulate! # for foldl & accumulate rules
8+
const _INIT = Base._InitialValue()
9+
710
@testset "Reductions" begin
811
@testset "sum(::Tuple)" begin
912
test_frule(sum, Tuple(rand(5)))
@@ -216,8 +219,6 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
216219
@testset "foldl(f, ::Array)" begin
217220
# `foldl(op, itr; init)` goes to `mapfoldr_impl(identity, op, init, itr)`. The rule is
218221
# now attached there, as this is the simplest way to handle `init` keyword.
219-
@eval using Base: mapfoldl_impl
220-
_INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple()
221222

222223
# Simple
223224
y1, b1 = rrule(CFG, mapfoldl_impl, identity, *, 1, [1, 2, 3])
@@ -268,18 +269,39 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
268269
test_rrule(mapfoldl_impl, identity, max, 999, rand(3))
269270
end
270271
@testset "foldl(f, ::Tuple)" begin
271-
y1, b1 = rrule(CFG, foldl, *, (1,2,3); init=1)
272272
y1, b1 = rrule(CFG, mapfoldl_impl, identity, *, 1, (1,2,3))
273273
@test y1 == 6
274274
@test b1(7)[5] == Tangent{NTuple{3,Int}}(42, 21, 14)
275275

276276
y2, b2 = rrule(CFG, mapfoldl_impl, identity, *, _INIT, (1, 2, 0, 4))
277277
@test y2 == 0
278278
@test b2(8)[5] == Tangent{NTuple{4,Int}}(0, 0, 64, 0)
279+
280+
# Test execution order
281+
c5 = Counter()
282+
y5, b5 = rrule(CFG, mapfoldl_impl, identity, c5, _INIT, (5, 7, 11))
283+
@test c5 == Counter(2)
284+
@test y5 == ((5 + 7)*1 + 11)*2 == foldl(Counter(), (5, 7, 11))
285+
@test collect(b5(1)[5]) == [12*32, 12*42, 22]
286+
@test c5 == Counter(42)
287+
288+
c6 = Counter()
289+
y6, b6 = rrule(CFG, mapfoldl_impl, identity, c6, 3, (5, 7, 11))
290+
@test c6 == Counter(3)
291+
@test y6 == (((3 + 5)*1 + 7)*2 + 11)*3 == foldl(Counter(), (5, 7, 11), init=3)
292+
@test collect(b6(1)[5]) == [63*33*13, 43*13, 23]
293+
@test c6 == Counter(63)
294+
295+
# Test gradient of function
296+
y7, b7 = rrule(CFG, mapfoldl_impl, identity, Multiplier(3), _INIT, (5, 7, 11))
297+
@test y7 == foldl((x,y)->x*y*3, (5, 7, 11))
298+
b7_1 = b7(1)
299+
@test b7_1[3] == Tangent{Multiplier{Int}}(x = 2310,)
300+
@test collect(b7_1[5]) == [693, 495, 315]
279301

280302
# Finite differencing
281303
test_rrule(mapfoldl_impl, identity, /, _INIT, Tuple(1 .+ rand(5)))
282-
test_rrule(mapfoldl_impl, identity, *, _INIT, Tuple(rand(ComplexF64, 5)))
304+
test_rrule(mapfoldl_impl, identity, *, 1+rand(), Tuple(rand(ComplexF64, 5)))
283305
end
284306
end
285307

0 commit comments

Comments
 (0)