Skip to content

Commit e4b20da

Browse files
committed
minimal change foldl -> mapfoldl_impl
1 parent 7faaf5d commit e4b20da

File tree

2 files changed

+46
-31
lines changed

2 files changed

+46
-31
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -427,11 +427,12 @@ end
427427
# to carry intermediate results along creates arrays of tuples which could be avoided; using a
428428
# loop can be a few times faster. Note also that it does not return a gradient for `init`.
429429

430+
# Maybe that's a problem. Let's move the rule to `mapfoldr_impl(f, op, init, itr)`, where it's easier?
431+
430432
function rrule(
431-
config::RuleConfig{>:HasReverseMode}, ::typeof(foldl), op::G, x::Union{AbstractArray, Tuple};
432-
init=_InitialValue()
433+
config::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::G, init, x::Union{AbstractArray, Tuple};
433434
) where {G}
434-
list, start = if init === _InitialValue()
435+
list, start = if init === _INIT
435436
_drop1(x), first(x)
436437
else
437438
# Case with init keyword is simpler to understand first!
@@ -455,11 +456,12 @@ function rrule(
455456
end
456457
dop = sum(first, trio)
457458
dx = map(last, _reverse1(trio))
458-
if init === _InitialValue()
459+
if init === _INIT
459460
# `hobbits` is one short
460461
dx = _vcat1(trio[end][2], dx)
461462
end
462-
return (NoTangent(), dop, project(_reshape1(dx, axe)))
463+
d_init = @not_implemented "gradient for foldl does not at present include init, sorry"
464+
return (NoTangent(), NoTangent(), dop, d_init, project(_reshape1(dx, axe)))
463465
end
464466
return y, unfoldl
465467
end
@@ -484,7 +486,8 @@ _reverse1(x::Tuple) = reverse(x)
484486
_drop1(x::Tuple) = Base.tail(x)
485487
_zip2(x::Tuple{Vararg{Any,N}}, y::Tuple{Vararg{Any,N}}) where N = ntuple(i -> (x[i],y[i]), N)
486488

487-
struct _InitialValue end # Old versions don't have `Base._InitialValue`
489+
# struct _InitialValue end # Old versions don't have `Base._InitialValue`
490+
const _INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple()
488491

489492
_vcat1(x, ys::AbstractVector) = vcat(x, ys)
490493
_vcat1(x::AbstractArray, ys::AbstractVector) = vcat([x], ys)
@@ -505,15 +508,15 @@ _no_tuple_tangent(dx) = dx
505508

506509
function rrule(
507510
config::RuleConfig{>:HasReverseMode}, ::typeof(accumulate), op::G, x::Union{AbstractArray, Tuple};
508-
init=_InitialValue(), dims=nothing
511+
init=_INIT, dims=nothing
509512
) where {G}
510513
isnothing(dims) || dims == 1 && x isa Base.AbstractVecOrTuple || throw(
511514
"accumulate(op, x; dims) is not currently supported by ChainRules, sorry"
512515
# It's not supported by AD either, so no point calling back, and no regression:
513516
# gradient(x -> sum(accumulate(/, x, dims=1)), rand(3,4))
514517
# ERROR: Mutating arrays is not supported
515518
)
516-
list, start = if init === _InitialValue()
519+
list, start = if init === _INIT
517520
_drop1(x), first(x)
518521
else
519522
x, init
@@ -522,7 +525,7 @@ function rrule(
522525
c, back = rrule_via_ad(config, op, a, b)
523526
end
524527
y = map(first, hobbits)
525-
if init === _InitialValue()
528+
if init === _INIT
526529
# `hobbits` is one short, and first one doesn't invoke `op`
527530
y = _vcat1(first(x), y)
528531
end
@@ -543,7 +546,7 @@ function rrule(
543546
end
544547
dop = sum(first, trio)
545548
dx = map(last, _reverse1(trio))
546-
if init == _InitialValue()
549+
if init == _INIT
547550
# `hobbits` is one short, and the first one is weird
548551
dx = _vcat1(trio[end][2] + dy_plain[1], dx)
549552
end

test/rulesets/Base/mapreduce.jl

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -214,60 +214,72 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
214214
end # prod
215215

216216
@testset "foldl(f, ::Array)" begin
217+
# `foldl(op, itr; init)` goes to `mapfoldr_impl(identity, op, init, itr)`. The rule is
218+
# now attached there, as this is the simplest way to handle `init` keyword.
219+
@eval using Base: mapfoldl_impl
220+
@eval _INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple()
221+
217222
# Simple
218-
y1, b1 = rrule(CFG, foldl, *, [1, 2, 3]; init=1)
223+
y1, b1 = rrule(CFG, mapfoldl_impl, identity, *, 1, [1, 2, 3])
219224
@test y1 == 6
220-
b1(7) == (NoTangent(), NoTangent(), [42, 21, 14])
225+
@test b1(7)[1:3] == (NoTangent(), NoTangent(), NoTangent())
226+
@test b1(7)[4] isa ChainRulesCore.NotImplemented
227+
@test b1(7)[5] == [42, 21, 14]
221228

222-
y2, b2 = rrule(CFG, foldl, *, [1 2; 0 4]) # without init, needs vcat
229+
y2, b2 = rrule(CFG, mapfoldl_impl, identity, *, _INIT, [1 2; 0 4]) # without init, needs vcat
223230
@test y2 == 0
224-
b2(8) == (NoTangent(), NoTangent(), [0 0; 64 0]) # matrix, needs reshape
231+
@test b2(8)[5] == [0 0; 64 0] # matrix, needs reshape
225232

226233
# Test execution order
227234
c5 = Counter()
228-
y5, b5 = rrule(CFG, foldl, c5, [5, 7, 11])
235+
y5, b5 = rrule(CFG, mapfoldl_impl, identity, c5, _INIT, [5, 7, 11])
229236
@test c5 == Counter(2)
230237
@test y5 == ((5 + 7)*1 + 11)*2 == foldl(Counter(), [5, 7, 11])
231-
@test b5(1) == (NoTangent(), NoTangent(), [12*32, 12*42, 22])
238+
@test b5(1)[5] == [12*32, 12*42, 22]
232239
@test c5 == Counter(42)
233240

234241
c6 = Counter()
235-
y6, b6 = rrule(CFG, foldl, c6, [5, 7, 11], init=3)
242+
y6, b6 = rrule(CFG, mapfoldl_impl, identity, c6, 3, [5, 7, 11])
236243
@test c6 == Counter(3)
237244
@test y6 == (((3 + 5)*1 + 7)*2 + 11)*3 == foldl(Counter(), [5, 7, 11], init=3)
238-
@test b6(1) == (NoTangent(), NoTangent(), [63*33*13, 43*13, 23])
245+
@test b6(1)[5] == [63*33*13, 43*13, 23]
239246
@test c6 == Counter(63)
240247

241248
# Test gradient of function
242-
y7, b7 = rrule(CFG, foldl, Multiplier(3), [5, 7, 11])
249+
y7, b7 = rrule(CFG, mapfoldl_impl, identity, Multiplier(3), _INIT, [5, 7, 11])
243250
@test y7 == foldl((x,y)->x*y*3, [5, 7, 11])
244-
@test b7(1) == (NoTangent(), Tangent{Multiplier{Int}}(x = 2310,), [693, 495, 315])
251+
b7_1 = b7(1)
252+
@test b7_1[3] == Tangent{Multiplier{Int}}(x = 2310,)
253+
@test b7_1[5] == [693, 495, 315]
245254

246-
y8, b8 = rrule(CFG, foldl, Multiplier(13), [5, 7, 11], init=3)
255+
y8, b8 = rrule(CFG, mapfoldl_impl, identity, Multiplier(13), 3, [5, 7, 11])
247256
@test y8 == 2_537_535 == foldl((x,y)->x*y*13, [5, 7, 11], init=3)
248-
@test b8(1) == (NoTangent(), Tangent{Multiplier{Int}}(x = 585585,), [507507, 362505, 230685])
257+
b8_1 = b8(1)
258+
@test b8_1[3] == Tangent{Multiplier{Int}}(x = 585585,)
259+
@test b8_1[5] == [507507, 362505, 230685]
249260
# To find these numbers:
250261
# ForwardDiff.derivative(z -> foldl((x,y)->x*y*z, [5,7,11], init=3), 13)
251262
# ForwardDiff.gradient(z -> foldl((x,y)->x*y*13, z, init=3), [5,7,11]) |> string
252263

253264
# Finite differencing
254-
test_rrule(foldl, /, 1 .+ rand(3,4))
255-
test_rrule(foldl, *, rand(ComplexF64,3,4); fkwargs=(; init=rand(ComplexF64)))
256-
test_rrule(foldl, +, rand(ComplexF64,7); fkwargs=(; init=rand(ComplexF64)))
257-
test_rrule(foldl, max, rand(3); fkwargs=(; init=999))
265+
test_rrule(mapfoldl_impl, identity, /, _INIT, 1 .+ rand(3,4))
266+
test_rrule(mapfoldl_impl, identity, *, rand(ComplexF64), rand(ComplexF64,3,4))
267+
test_rrule(mapfoldl_impl, identity, +, rand(ComplexF64), rand(ComplexF64,7))
268+
test_rrule(mapfoldl_impl, identity, max, 999, rand(3))
258269
end
259270
@testset "foldl(f, ::Tuple)" begin
260271
y1, b1 = rrule(CFG, foldl, *, (1,2,3); init=1)
272+
y1, b1 = rrule(CFG, mapfoldl_impl, identity, *, 1, (1,2,3))
261273
@test y1 == 6
262-
b1(7) == (NoTangent(), NoTangent(), Tangent{NTuple{3,Int}}(42, 21, 14))
274+
@test b1(7)[5] == Tangent{NTuple{3,Int}}(42, 21, 14)
263275

264-
y2, b2 = rrule(CFG, foldl, *, (1, 2, 0, 4))
276+
y2, b2 = rrule(CFG, mapfoldl_impl, identity, *, _INIT, (1, 2, 0, 4))
265277
@test y2 == 0
266-
b2(8) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(0, 0, 64, 0))
278+
@test b2(8)[5] == Tangent{NTuple{4,Int}}(0, 0, 64, 0)
267279

268280
# Finite differencing
269-
test_rrule(foldl, /, Tuple(1 .+ rand(5)))
270-
test_rrule(foldl, *, Tuple(rand(ComplexF64, 5)))
281+
test_rrule(mapfoldl_impl, identity, /, _INIT, Tuple(1 .+ rand(5)))
282+
test_rrule(mapfoldl_impl, identity, *, _INIT, Tuple(rand(ComplexF64, 5)))
271283
end
272284
end
273285

0 commit comments

Comments
 (0)