Skip to content

Commit d6f6a92

Browse files
committed
simple rule for mapfoldl
1 parent 2b5877c commit d6f6a92

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,25 @@ end
417417
end
418418

419419
#####
420+
##### `mapfoldl(f, g, ::Tuple)`
421+
#####
422+
423+
# For tuples there should be no harm in handling `map` first.
424+
# This will also catch `mapreduce`.
425+
426+
function rrule(
427+
cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), f::F, op::G, init, x::Tuple;
428+
) where {F,G}
429+
y, backmap = rrule(cfg, map, f, x)
430+
z, backred = rrule(cfg, Base.mapfoldl_impl, identity, op, init, y)
431+
function mapfoldl_pullback_tuple(dz)
432+
_, _, dop, dinit, dy = backred(dz)
433+
_, df, dx = backmap(dy)
434+
return (NoTangent(), df, dop, dinit, dx)
435+
end
436+
return z, mapfoldl_pullback_tuple
437+
end
438+
420439
#####
421440
##### `foldl(f, ::Tuple)`
422441
#####

test/rulesets/Base/mapreduce.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,11 @@ const _INIT = Base._InitialValue()
304304
test_rrule(mapfoldl_impl, identity, /, _INIT, Tuple(1 .+ rand(5)))
305305
test_rrule(mapfoldl_impl, identity, *, 1+rand(), Tuple(rand(ComplexF64, 5)))
306306
end
307+
@testset "mapfoldl(f, g, ::Tuple)" begin
308+
test_rrule(mapfoldl_impl, cbrt, /, _INIT, Tuple(1 .+ rand(5)), check_inferred=false)
309+
test_rrule(mapfoldl_impl, abs2, *, 1+rand(), Tuple(rand(ComplexF64, 5)), check_inferred=false)
310+
# TODO make the `map(f, ::Tuple)` rule infer better!
311+
end
307312
end
308313

309314
@testset "Accumulations" begin

0 commit comments

Comments
 (0)