Skip to content

Commit ef95dd5

Browse files
committed
minimal change foldl -> mapfoldl_impl
1 parent 13ccc86 commit ef95dd5

File tree

2 files changed

+46
-31
lines changed

2 files changed

+46
-31
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -427,11 +427,12 @@ end
427427
# to carry intermediate results along creates arrays of tuples which could be avoided; using a
428428
# loop can be a few times faster. Note also that it does not return a gradient for `init`.
429429

430+
# Maybe that's a problem. Let's move the rule to `mapfoldr_impl(f, op, init, itr)`, where it's easier?
431+
430432
function rrule(
431-
config::RuleConfig{>:HasReverseMode}, ::typeof(foldl), op::G, x::Union{AbstractArray, Tuple};
432-
init=_InitialValue()
433+
config::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::G, init, x::Union{AbstractArray, Tuple};
433434
) where {G}
434-
list, start = if init === _InitialValue()
435+
list, start = if init === _INIT
435436
_drop1(x), first(x)
436437
else
437438
# Case with init keyword is simpler to understand first!
@@ -455,11 +456,12 @@ function rrule(
455456
end
456457
dop = sum(first, trio)
457458
dx = map(last, _reverse1(trio))
458-
if init === _InitialValue()
459+
if init === _INIT
459460
# `hobbits` is one short
460461
dx = _vcat1(trio[end][2], dx)
461462
end
462-
return (NoTangent(), dop, project(_reshape1(dx, axe)))
463+
d_init = @not_implemented "gradient for foldl does not at present include init, sorry"
464+
return (NoTangent(), NoTangent(), dop, d_init, project(_reshape1(dx, axe)))
463465
end
464466
return y, unfoldl
465467
end
@@ -484,7 +486,8 @@ _reverse1(x::Tuple) = reverse(x)
484486
_drop1(x::Tuple) = Base.tail(x)
485487
_zip2(x::Tuple{Vararg{Any,N}}, y::Tuple{Vararg{Any,N}}) where N = ntuple(i -> (x[i],y[i]), N)
486488

487-
struct _InitialValue end # Old versions don't have `Base._InitialValue`
489+
# struct _InitialValue end # Old versions don't have `Base._InitialValue`
490+
const _INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple()
488491

489492
_vcat1(x, ys::AbstractVector) = vcat(x, ys)
490493
_vcat1(x::AbstractArray, ys::AbstractVector) = vcat([x], ys)
@@ -505,15 +508,15 @@ _no_tuple_tangent(dx) = dx
505508

506509
function rrule(
507510
config::RuleConfig{>:HasReverseMode}, ::typeof(accumulate), op::G, x::Union{AbstractArray, Tuple};
508-
init=_InitialValue(), dims=nothing
511+
init=_INIT, dims=nothing
509512
) where {G}
510513
isnothing(dims) || dims == 1 && x isa Base.AbstractVecOrTuple || throw(
511514
"accumulate(op, x; dims) is not currently supported by ChainRules, sorry"
512515
# It's not supported by AD either, so no point calling back, and no regression:
513516
# gradient(x -> sum(accumulate(/, x, dims=1)), rand(3,4))
514517
# ERROR: Mutating arrays is not supported
515518
)
516-
list, start = if init === _InitialValue()
519+
list, start = if init === _INIT
517520
_drop1(x), first(x)
518521
else
519522
x, init
@@ -522,7 +525,7 @@ function rrule(
522525
c, back = rrule_via_ad(config, op, a, b)
523526
end
524527
y = map(first, hobbits)
525-
if init === _InitialValue()
528+
if init === _INIT
526529
# `hobbits` is one short, and first one doesn't invoke `op`
527530
y = _vcat1(first(x), y)
528531
end
@@ -543,7 +546,7 @@ function rrule(
543546
end
544547
dop = sum(first, trio)
545548
dx = map(last, _reverse1(trio))
546-
if init == _InitialValue()
549+
if init == _INIT
547550
# `hobbits` is one short, and the first one is weird
548551
dx = _vcat1(trio[end][2] + dy_plain[1], dx)
549552
end

test/rulesets/Base/mapreduce.jl

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -213,60 +213,72 @@ struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end
213213
end # prod
214214

215215
@testset "foldl(f, ::Array)" begin
216+
# `foldl(op, itr; init)` goes to `mapfoldr_impl(identity, op, init, itr)`. The rule is
217+
# now attached there, as this is the simplest way to handle `init` keyword.
218+
@eval using Base: mapfoldl_impl
219+
@eval _INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple()
220+
216221
# Simple
217-
y1, b1 = rrule(CFG, foldl, *, [1, 2, 3]; init=1)
222+
y1, b1 = rrule(CFG, mapfoldl_impl, identity, *, 1, [1, 2, 3])
218223
@test y1 == 6
219-
b1(7) == (NoTangent(), NoTangent(), [42, 21, 14])
224+
@test b1(7)[1:3] == (NoTangent(), NoTangent(), NoTangent())
225+
@test b1(7)[4] isa ChainRulesCore.NotImplemented
226+
@test b1(7)[5] == [42, 21, 14]
220227

221-
y2, b2 = rrule(CFG, foldl, *, [1 2; 0 4]) # without init, needs vcat
228+
y2, b2 = rrule(CFG, mapfoldl_impl, identity, *, _INIT, [1 2; 0 4]) # without init, needs vcat
222229
@test y2 == 0
223-
b2(8) == (NoTangent(), NoTangent(), [0 0; 64 0]) # matrix, needs reshape
230+
@test b2(8)[5] == [0 0; 64 0] # matrix, needs reshape
224231

225232
# Test execution order
226233
c5 = Counter()
227-
y5, b5 = rrule(CFG, foldl, c5, [5, 7, 11])
234+
y5, b5 = rrule(CFG, mapfoldl_impl, identity, c5, _INIT, [5, 7, 11])
228235
@test c5 == Counter(2)
229236
@test y5 == ((5 + 7)*1 + 11)*2 == foldl(Counter(), [5, 7, 11])
230-
@test b5(1) == (NoTangent(), NoTangent(), [12*32, 12*42, 22])
237+
@test b5(1)[5] == [12*32, 12*42, 22]
231238
@test c5 == Counter(42)
232239

233240
c6 = Counter()
234-
y6, b6 = rrule(CFG, foldl, c6, [5, 7, 11], init=3)
241+
y6, b6 = rrule(CFG, mapfoldl_impl, identity, c6, 3, [5, 7, 11])
235242
@test c6 == Counter(3)
236243
@test y6 == (((3 + 5)*1 + 7)*2 + 11)*3 == foldl(Counter(), [5, 7, 11], init=3)
237-
@test b6(1) == (NoTangent(), NoTangent(), [63*33*13, 43*13, 23])
244+
@test b6(1)[5] == [63*33*13, 43*13, 23]
238245
@test c6 == Counter(63)
239246

240247
# Test gradient of function
241-
y7, b7 = rrule(CFG, foldl, Multiplier(3), [5, 7, 11])
248+
y7, b7 = rrule(CFG, mapfoldl_impl, identity, Multiplier(3), _INIT, [5, 7, 11])
242249
@test y7 == foldl((x,y)->x*y*3, [5, 7, 11])
243-
@test b7(1) == (NoTangent(), Tangent{Multiplier{Int}}(x = 2310,), [693, 495, 315])
250+
b7_1 = b7(1)
251+
@test b7_1[3] == Tangent{Multiplier{Int}}(x = 2310,)
252+
@test b7_1[5] == [693, 495, 315]
244253

245-
y8, b8 = rrule(CFG, foldl, Multiplier(13), [5, 7, 11], init=3)
254+
y8, b8 = rrule(CFG, mapfoldl_impl, identity, Multiplier(13), 3, [5, 7, 11])
246255
@test y8 == 2_537_535 == foldl((x,y)->x*y*13, [5, 7, 11], init=3)
247-
@test b8(1) == (NoTangent(), Tangent{Multiplier{Int}}(x = 585585,), [507507, 362505, 230685])
256+
b8_1 = b8(1)
257+
@test b8_1[3] == Tangent{Multiplier{Int}}(x = 585585,)
258+
@test b8_1[5] == [507507, 362505, 230685]
248259
# To find these numbers:
249260
# ForwardDiff.derivative(z -> foldl((x,y)->x*y*z, [5,7,11], init=3), 13)
250261
# ForwardDiff.gradient(z -> foldl((x,y)->x*y*13, z, init=3), [5,7,11]) |> string
251262

252263
# Finite differencing
253-
test_rrule(foldl, /, 1 .+ rand(3,4))
254-
test_rrule(foldl, *, rand(ComplexF64,3,4); fkwargs=(; init=rand(ComplexF64)))
255-
test_rrule(foldl, +, rand(ComplexF64,7); fkwargs=(; init=rand(ComplexF64)))
256-
test_rrule(foldl, max, rand(3); fkwargs=(; init=999))
264+
test_rrule(mapfoldl_impl, identity, /, _INIT, 1 .+ rand(3,4))
265+
test_rrule(mapfoldl_impl, identity, *, rand(ComplexF64), rand(ComplexF64,3,4))
266+
test_rrule(mapfoldl_impl, identity, +, rand(ComplexF64), rand(ComplexF64,7))
267+
test_rrule(mapfoldl_impl, identity, max, 999, rand(3))
257268
end
258269
@testset "foldl(f, ::Tuple)" begin
259270
y1, b1 = rrule(CFG, foldl, *, (1,2,3); init=1)
271+
y1, b1 = rrule(CFG, mapfoldl_impl, identity, *, 1, (1,2,3))
260272
@test y1 == 6
261-
b1(7) == (NoTangent(), NoTangent(), Tangent{NTuple{3,Int}}(42, 21, 14))
273+
@test b1(7)[5] == Tangent{NTuple{3,Int}}(42, 21, 14)
262274

263-
y2, b2 = rrule(CFG, foldl, *, (1, 2, 0, 4))
275+
y2, b2 = rrule(CFG, mapfoldl_impl, identity, *, _INIT, (1, 2, 0, 4))
264276
@test y2 == 0
265-
b2(8) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(0, 0, 64, 0))
277+
@test b2(8)[5] == Tangent{NTuple{4,Int}}(0, 0, 64, 0)
266278

267279
# Finite differencing
268-
test_rrule(foldl, /, Tuple(1 .+ rand(5)))
269-
test_rrule(foldl, *, Tuple(rand(ComplexF64, 5)))
280+
test_rrule(mapfoldl_impl, identity, /, _INIT, Tuple(1 .+ rand(5)))
281+
test_rrule(mapfoldl_impl, identity, *, _INIT, Tuple(rand(ComplexF64, 5)))
270282
end
271283
end
272284

0 commit comments

Comments
 (0)