Skip to content

Commit 48647d9

Browse files
DhairyaLGandhiCarloLucibellooxinabox
authored
Accumulate mutable structs with Ref (#1574)
* fix: accumulate mutable structs properly * chore: deref RefValue while accumulating * chore: restrict for NT only * test: mapping with differing graphs Co-authored-by: Aayush https://github.com/AayushSabharwal * fix: avoid double accumulation * fix: check compatible trees * chore: correct fieldnames usage * Update src/lib/lib.jl Co-authored-by: Frames White <[email protected]> * Update src/lib/lib.jl Co-authored-by: Frames White <[email protected]> --------- Co-authored-by: Carlo Lucibello <[email protected]> Co-authored-by: Frames White <[email protected]>
1 parent 6dc5d2f commit 48647d9

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

src/lib/lib.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,16 @@ function accum(x::RefValue, y::RefValue)
2828
@assert x === y
2929
return x
3030
end
31+
function accum(x::NamedTuple, ref::RefValue)
32+
# We do not actually do any accumulation here, because the ref will already have been mutated.
33+
fieldnames(typeof(ref[])) fieldnames(typeof(x)) || throw(ArgumentError("$(ref[]) keys from Ref must be a subset of $x keys"))
34+
ref
35+
end
36+
function accum(ref::RefValue, x::NamedTuple)
37+
# We do not actually do any accumulation here, because the ref will already have been mutated.
38+
fieldnames(typeof(x)) fieldnames(typeof(ref[])) || throw(ArgumentError("$x keys must be a subset of $(ref[]) keys from Ref"))
39+
ref
40+
end
3141

3242
accum(x::NamedTuple, y::ChainRulesCore.Tangent) = accum(x, wrap_chainrules_output(y))
3343
accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x), y)

test/features_tests.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,19 @@ let
480480
@test back(1.) == ((1.0,),)
481481
end
482482

483+
mutable struct MWEGetter{G, U}
484+
idxs::G
485+
u::U
486+
end
487+
488+
u = ones(3)
489+
idxs = [1, 2]
490+
mwe1574 = MWEGetter(idxs, u)
491+
492+
function fn1574(mwe)
493+
map(i -> mwe.u[i], mwe.idxs)
494+
end
495+
483496
@testset "mutable struct, including Ref" begin
484497
# Zygote's representation is Base.RefValue{Any}((value = 7.0,)), but the
485498
# map to ChainRules types and back normalises to (value = 7.0,) same as struct:
@@ -506,6 +519,9 @@ end
506519
# Broadcasting over Ref is handled specially. Tested elsewhere too.
507520
@test gradient(x -> sum(sum, x .* [1,2,3]), Ref([4,5])) == ((x = [6.0, 6.0],),)
508521
@test gradient(x -> sum(sum, Ref(x) .* [1,2,3]), [4,5]) == ([6.0, 6.0],)
522+
523+
# Broadcasting/ Mapping over Mutables with differring graphs for fields
524+
@test gradient(x -> sum(fn1574(x)), mwe1574)[1] == (; idxs = [nothing, nothing], u = [1.0, 1.0, 0.0])
509525
end
510526

511527
@testset "mutable accum_param bugs" begin

0 commit comments

Comments
 (0)