Skip to content

Commit 9af7a64

Browse files
committed
simple rule for mapfoldl
1 parent 4c6433a commit 9af7a64

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1414
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1515

1616
[compat]
17-
ChainRulesCore = "1.12"
17+
ChainRulesCore = "1.15.3"
1818
ChainRulesTestUtils = "1.5"
1919
Compat = "3.42.0, 4"
2020
FiniteDifferences = "0.12.20"

src/rulesets/Base/mapreduce.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,25 @@ end
417417
end
418418

419419
#####
420+
##### `mapfoldl(f, g, ::Tuple)`
421+
#####
422+
423+
# For tuples there should be no harm in handling `map` first.
424+
# This will also catch `mapreduce`.
425+
426+
function rrule(
427+
cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), f::F, op::G, init, x::Tuple;
428+
) where {F,G}
429+
y, backmap = rrule(cfg, map, f, x)
430+
z, backred = rrule(cfg, Base.mapfoldl_impl, identity, op, init, y)
431+
function mapfoldl_pullback_tuple(dz)
432+
_, _, dop, dinit, dy = backred(dz)
433+
_, df, dx = backmap(dy)
434+
return (NoTangent(), df, dop, dinit, dx)
435+
end
436+
return z, mapfoldl_pullback_tuple
437+
end
438+
420439
#####
421440
##### `foldl(f, ::Tuple)`
422441
#####

test/rulesets/Base/mapreduce.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,11 @@ const _INIT = Base._InitialValue()
303303
test_rrule(mapfoldl_impl, identity, /, _INIT, Tuple(1 .+ rand(5)))
304304
test_rrule(mapfoldl_impl, identity, *, 1+rand(), Tuple(rand(ComplexF64, 5)))
305305
end
306+
@testset "mapfoldl(f, g, ::Tuple)" begin
307+
test_rrule(mapfoldl_impl, cbrt, /, _INIT, Tuple(1 .+ rand(5)), check_inferred=false)
308+
test_rrule(mapfoldl_impl, abs2, *, 1+rand(), Tuple(rand(ComplexF64, 5)), check_inferred=false)
309+
# TODO make the `map(f, ::Tuple)` rule infer better!
310+
end
306311
end
307312

308313
@testset "Accumulations" begin

0 commit comments

Comments
 (0)