Skip to content

Commit 9b6452a

Browse files
committed
minimal change foldl -> mapfoldl_impl
1 parent 6aed351 commit 9b6452a

File tree

2 files changed

+47
-33
lines changed

2 files changed

+47
-33
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -339,11 +339,12 @@ end
339339
# to carry intermediate results along creates arrays of tuples which could be avoided; using a
340340
# loop can be a few times faster. Note also that it does not return a gradient for `init`.
341341

342+
# Maybe that's a problem. Let's move the rule to `mapfoldr_impl(f, op, init, itr)`, where it's easier?
343+
342344
function rrule(
343-
config::RuleConfig{>:HasReverseMode}, ::typeof(foldl), op::G, x::Union{AbstractArray, Tuple};
344-
init=_InitialValue()
345+
config::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::G, init, x::Union{AbstractArray, Tuple};
345346
) where {G}
346-
list, start = if init === _InitialValue()
347+
list, start = if init === _INIT
347348
_drop1(x), first(x)
348349
else
349350
# Case with init keyword is simpler to understand first!
@@ -367,11 +368,12 @@ function rrule(
367368
end
368369
dop = sum(first, trio)
369370
dx = map(last, _reverse1(trio))
370-
if init === _InitialValue()
371+
if init === _INIT
371372
# `hobbits` is one short
372373
dx = _vcat1(trio[end][2], dx)
373374
end
374-
return (NoTangent(), dop, project(_reshape1(dx, axe)))
375+
d_init = @not_implemented "gradient for foldl does not at present include init, sorry"
376+
return (NoTangent(), NoTangent(), dop, d_init, project(_reshape1(dx, axe)))
375377
end
376378
return y, unfoldl
377379
end
@@ -402,7 +404,8 @@ _reverse1(x::Tuple) = reverse(x)
402404
_drop1(x::Tuple) = Base.tail(x)
403405
_zip2(x::Tuple{Vararg{Any,N}}, y::Tuple{Vararg{Any,N}}) where N = ntuple(i -> (x[i],y[i]), N)
404406

405-
struct _InitialValue end # Old versions don't have `Base._InitialValue`
407+
# struct _InitialValue end # Old versions don't have `Base._InitialValue`
408+
const _INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple()
406409

407410
_vcat1(x, ys::AbstractVector) = vcat(x, ys)
408411
_vcat1(x::AbstractArray, ys::AbstractVector) = vcat([x], ys)
@@ -423,15 +426,15 @@ _no_tuple_tangent(dx) = dx
423426

424427
function rrule(
425428
config::RuleConfig{>:HasReverseMode}, ::typeof(accumulate), op::G, x::Union{AbstractArray, Tuple};
426-
init=_InitialValue(), dims=nothing
429+
init=_INIT, dims=nothing
427430
) where {G}
428431
isnothing(dims) || dims == 1 && x isa Base.AbstractVecOrTuple || throw(
429432
"accumulate(op, x; dims) is not currently supported by ChainRules, sorry"
430433
# It's not supported by AD either, so no point calling back, and no regression:
431434
# gradient(x -> sum(accumulate(/, x, dims=1)), rand(3,4))
432435
# ERROR: Mutating arrays is not supported
433436
)
434-
list, start = if init === _InitialValue()
437+
list, start = if init === _INIT
435438
_drop1(x), first(x)
436439
else
437440
x, init
@@ -440,15 +443,15 @@ function rrule(
440443
c, back = rrule_via_ad(config, op, a, b)
441444
end
442445
y = map(first, hobbits)
443-
if init === _InitialValue()
446+
if init === _INIT
444447
# `hobbits` is one short, and first one doesn't invoke `op`
445448
y = _vcat1(first(x), y)
446449
end
447450
axe = axes(x)
448451
project = ProjectTo(x)
449452
function decumulate(dy)
450453
dy_plain = _no_tuple_tangent(unthunk(dy))
451-
rev_list = if init === _InitialValue()
454+
rev_list = if init === _INIT
452455
if VERSION >= v"1.6"
453456
# Here we rely on `zip` to stop early. Begin explicit with _reverse1(_drop1(...))
454457
# gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{"
@@ -468,7 +471,7 @@ function rrule(
468471
end
469472
dop = sum(first, trio)
470473
dx = map(last, _reverse1(trio))
471-
if init == _InitialValue()
474+
if init == _INIT
472475
# `hobbits` is one short, and the first one is weird
473476
dx = _vcat1(trio[end][2] + dy_plain[1], dx)
474477
end

test/rulesets/Base/mapreduce.jl

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -180,60 +180,71 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
180180
end # prod
181181

182182
@testset "foldl(f, ::Array)" begin
183+
# `foldl(op, itr; init)` goes to `mapfoldr_impl(identity, op, init, itr)`. The rule is
184+
# now attached there, as this is the simplest way to handle `init` keyword.
185+
@eval using Base: mapfoldl_impl
186+
@eval _INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple()
187+
183188
# Simple
184-
y1, b1 = rrule(CFG, foldl, *, [1, 2, 3]; init=1)
189+
y1, b1 = rrule(CFG, mapfoldl_impl, identity, *, 1, [1, 2, 3])
185190
@test y1 == 6
186-
b1(7) == (NoTangent(), NoTangent(), [42, 21, 14])
191+
@test b1(7)[1:3] == (NoTangent(), NoTangent(), NoTangent())
192+
@test b1(7)[4] isa ChainRulesCore.NotImplemented
193+
@test b1(7)[5] == [42, 21, 14]
187194

188-
y2, b2 = rrule(CFG, foldl, *, [1 2; 0 4]) # without init, needs vcat
195+
y2, b2 = rrule(CFG, mapfoldl_impl, identity, *, _INIT, [1 2; 0 4]) # without init, needs vcat
189196
@test y2 == 0
190-
b2(8) == (NoTangent(), NoTangent(), [0 0; 64 0]) # matrix, needs reshape
197+
@test b2(8)[5] == [0 0; 64 0] # matrix, needs reshape
191198

192199
# Test execution order
193200
c5 = Counter()
194-
y5, b5 = rrule(CFG, foldl, c5, [5, 7, 11])
201+
y5, b5 = rrule(CFG, mapfoldl_impl, identity, c5, _INIT, [5, 7, 11])
195202
@test c5 == Counter(2)
196203
@test y5 == ((5 + 7)*1 + 11)*2 == foldl(Counter(), [5, 7, 11])
197-
@test b5(1) == (NoTangent(), NoTangent(), [12*32, 12*42, 22])
204+
@test b5(1)[5] == [12*32, 12*42, 22]
198205
@test c5 == Counter(42)
199206

200207
c6 = Counter()
201-
y6, b6 = rrule(CFG, foldl, c6, [5, 7, 11], init=3)
208+
y6, b6 = rrule(CFG, mapfoldl_impl, identity, c6, 3, [5, 7, 11])
202209
@test c6 == Counter(3)
203210
@test y6 == (((3 + 5)*1 + 7)*2 + 11)*3 == foldl(Counter(), [5, 7, 11], init=3)
204-
@test b6(1) == (NoTangent(), NoTangent(), [63*33*13, 43*13, 23])
211+
@test b6(1)[5] == [63*33*13, 43*13, 23]
205212
@test c6 == Counter(63)
206213

207214
# Test gradient of function
208-
y7, b7 = rrule(CFG, foldl, Multiplier(3), [5, 7, 11])
215+
y7, b7 = rrule(CFG, mapfoldl_impl, identity, Multiplier(3), _INIT, [5, 7, 11])
209216
@test y7 == foldl((x,y)->x*y*3, [5, 7, 11])
210-
@test b7(1) == (NoTangent(), Tangent{Multiplier{Int}}(x = 2310,), [693, 495, 315])
217+
b7_1 = b7(1)
218+
@test b7_1[3] == Tangent{Multiplier{Int}}(x = 2310,)
219+
@test b7_1[5] == [693, 495, 315]
211220

212-
y8, b8 = rrule(CFG, foldl, Multiplier(13), [5, 7, 11], init=3)
221+
y8, b8 = rrule(CFG, mapfoldl_impl, identity, Multiplier(13), 3, [5, 7, 11])
213222
@test y8 == 2_537_535 == foldl((x,y)->x*y*13, [5, 7, 11], init=3)
214-
@test b8(1) == (NoTangent(), Tangent{Multiplier{Int}}(x = 585585,), [507507, 362505, 230685])
223+
b8_1 = b8(1)
224+
@test b8_1[3] == Tangent{Multiplier{Int}}(x = 585585,)
225+
@test b8_1[5] == [507507, 362505, 230685]
215226
# To find these numbers:
216227
# ForwardDiff.derivative(z -> foldl((x,y)->x*y*z, [5,7,11], init=3), 13)
217228
# ForwardDiff.gradient(z -> foldl((x,y)->x*y*13, z, init=3), [5,7,11]) |> string
218229

219230
# Finite differencing
220-
test_rrule(foldl, /, 1 .+ rand(3,4))
221-
test_rrule(foldl, *, rand(ComplexF64,3,4); fkwargs=(; init=rand(ComplexF64)))
222-
test_rrule(foldl, +, rand(ComplexF64,7); fkwargs=(; init=rand(ComplexF64)))
223-
test_rrule(foldl, max, rand(3); fkwargs=(; init=999))
231+
test_rrule(mapfoldl_impl, identity, /, _INIT, 1 .+ rand(3,4))
232+
test_rrule(mapfoldl_impl, identity, *, rand(ComplexF64), rand(ComplexF64,3,4))
233+
test_rrule(mapfoldl_impl, identity, +, rand(ComplexF64), rand(ComplexF64,7))
234+
test_rrule(mapfoldl_impl, identity, max, 999, rand(3))
224235
end
225236
VERSION >= v"1.5" && @testset "foldl(f, ::Tuple)" begin
226-
y1, b1 = rrule(CFG, foldl, *, (1,2,3); init=1)
237+
y1, b1 = rrule(CFG, mapfoldl_impl, identity, *, 1, (1,2,3))
227238
@test y1 == 6
228-
b1(7) == (NoTangent(), NoTangent(), Tangent{NTuple{3,Int}}(42, 21, 14))
239+
@test b1(7)[5] == Tangent{NTuple{3,Int}}(42, 21, 14)
229240

230-
y2, b2 = rrule(CFG, foldl, *, (1, 2, 0, 4))
241+
y2, b2 = rrule(CFG, mapfoldl_impl, identity, *, _INIT, (1, 2, 0, 4))
231242
@test y2 == 0
232-
b2(8) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(0, 0, 64, 0))
243+
@test b2(8)[5] == Tangent{NTuple{4,Int}}(0, 0, 64, 0)
233244

234245
# Finite differencing
235-
test_rrule(foldl, /, Tuple(1 .+ rand(5)))
236-
test_rrule(foldl, *, Tuple(rand(ComplexF64, 5)))
246+
test_rrule(mapfoldl_impl, identity, /, _INIT, Tuple(1 .+ rand(5)))
247+
test_rrule(mapfoldl_impl, identity, *, _INIT, Tuple(rand(ComplexF64, 5)))
237248
end
238249
end
239250

0 commit comments

Comments
 (0)