@@ -417,17 +417,73 @@ end
417
417
end
418
418
419
419
# ####
420
- # #### `foldl`
420
+ # ####
421
+ # #### `foldl(f, ::Tuple)`
421
422
# ####
422
423
423
424
# `foldl` guarantees to execute `f` in order, left to right. So it makes sense even when
424
- # this `f` is stateful, in which case the gradient must be calculated in the reverse order.
425
+ # this `f` is stateful, in which case the gradient must be calculated in the reverse order.
426
+
427
+ # The rule is attached to `Base.mapfoldl_impl` because this gets the `init` keyword as an argument,
428
+ # which is handled below. For tuples, `reduce` also comes here.
429
+
430
+ function rrule (
431
+ config:: RuleConfig{>:HasReverseMode} ,
432
+ :: typeof (Base. mapfoldl_impl),
433
+ :: typeof (identity),
434
+ op:: G ,
435
+ init:: Base._InitialValue ,
436
+ x:: Tuple ;
437
+ ) where {G}
438
+ hobbits = accumulate (Base. tail (x); init= (first (x), nothing )) do (a, _), b
439
+ # Here `a` is what we would normally cary forward, and `_` ignores
440
+ # the previous iteration's pullback function (needed later),
441
+ # while `b` is the fresh input from `list` as usual.
442
+ c, back = rrule_via_ad (config, op, a, b)
443
+ # We don't really need to store every `c`, last one is `foldl` output.
444
+ # (The name, BTW, is because "there and back again" is the subtitle of Tolkien's book.)
445
+ end
446
+ y = first (last (hobbits))
447
+ project = ProjectTo (x)
448
+ function foldl_pullback_tuple (dy)
449
+ trio = accumulate (_reverse1 (hobbits); init= (0 , dy, 0 )) do (_, dc, _), (_, back)
450
+ ds, da, db = back (dc)
451
+ # Don't need to store every `da`, need one for the next iteration + the last.
452
+ end
453
+ dop = sum (first, trio)
454
+ dx = (trio[end ][2 ], reverse (map (last, trio))... )
455
+ return (NoTangent (), NoTangent (), ProjectTo (op)(dop), NoTangent (), project (dx))
456
+ end
457
+ return y, foldl_pullback_tuple
458
+ end
459
+
460
+ function rrule (
461
+ config:: RuleConfig{>:HasReverseMode} ,
462
+ :: typeof (Base. mapfoldl_impl),
463
+ :: typeof (identity),
464
+ op:: G ,
465
+ init,
466
+ x:: Tuple ;
467
+ ) where {G}
468
+ # Treat `init` by simply appending it to the `x`:
469
+ y, back = rrule (config, Base. mapfoldl_impl, identity, op, Base. _InitialValue (), (init, x... ))
470
+ project_x = ProjectTo (x)
471
+ project_in = ProjectTo (init)
472
+ function foldl_pullback_tuple_init (dy)
473
+ _, _, dop, _, dxplus = back (dy)
474
+ return (NoTangent (), NoTangent (), dop, project_in (first (dxplus)), project_x (Base. tail (dxplus)))
475
+ end
476
+ return y, foldl_pullback_tuple_init
477
+ end
425
478
426
- # The implementation aims to be efficient for both tuples and arrays, although using accumulate
427
- # to carry intermediate results along creates arrays of tuples which could be avoided; using a
428
- # loop can be a few times faster. Note also that it does not return a gradient for `init`.
479
+ # ####
480
+ # #### `foldl(f, ::Array)`
481
+ # ####
429
482
430
- # Maybe that's a problem. Let's move the rule to `mapfoldr_impl(f, op, init, itr)`, where it's easier?
483
+ # The implementation was originally for both tuples and arrays, although using accumulate
484
+ # to carry intermediate results along creates arrays of tuples which could be avoided.
485
+ # Using a loop can be a few times faster, this should be replaced.
486
+ # Note also that it does not return a gradient for `init`.
431
487
432
488
function rrule (
433
489
config:: RuleConfig{>:HasReverseMode} , :: typeof (Base. mapfoldl_impl), :: typeof (identity), op:: G , init, x:: Union{AbstractArray, Tuple} ;
@@ -486,8 +542,7 @@ _reverse1(x::Tuple) = reverse(x)
486
542
_drop1 (x:: Tuple ) = Base. tail (x)
487
543
_zip2 (x:: Tuple{Vararg{Any,N}} , y:: Tuple{Vararg{Any,N}} ) where N = ntuple (i -> (x[i],y[i]), N)
488
544
489
- # struct _InitialValue end # Old versions don't have `Base._InitialValue`
490
- const _INIT = VERSION >= v " 1.5" ? Base. _InitialValue () : NamedTuple ()
545
+ const _INIT = Base. _InitialValue ()
491
546
492
547
_vcat1 (x, ys:: AbstractVector ) = vcat (x, ys)
493
548
_vcat1 (x:: AbstractArray , ys:: AbstractVector ) = vcat ([x], ys)
0 commit comments