diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 111278496..59bb11a39 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -28,6 +28,16 @@ function accum(x::RefValue, y::RefValue) @assert x === y return x end +function accum(x::NamedTuple, ref::RefValue) + # We do not actually do any accumulation here, because the ref will already have been mutated. + fieldnames(typeof(ref[])) ⊆ fieldnames(typeof(x)) || throw(ArgumentError("$(ref[]) keys from Ref must be a subset of $x keys")) + ref +end +function accum(ref::RefValue, x::NamedTuple) + # We do not actually do any accumulation here, because the ref will already have been mutated. + fieldnames(typeof(x)) ⊆ fieldnames(typeof(ref[])) || throw(ArgumentError("$x keys must be a subset of $(ref[]) keys from Ref")) + ref +end accum(x::NamedTuple, y::ChainRulesCore.Tangent) = accum(x, wrap_chainrules_output(y)) accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x), y) diff --git a/test/features_tests.jl b/test/features_tests.jl index a81a9d4be..12aca7ae3 100644 --- a/test/features_tests.jl +++ b/test/features_tests.jl @@ -480,6 +480,19 @@ let @test back(1.) == ((1.0,),) end +mutable struct MWEGetter{G, U} + idxs::G + u::U +end + +u = ones(3) +idxs = [1, 2] +mwe1574 = MWEGetter(idxs, u) + +function fn1574(mwe) + map(i -> mwe.u[i], mwe.idxs) +end + @testset "mutable struct, including Ref" begin # Zygote's representation is Base.RefValue{Any}((value = 7.0,)), but the # map to ChainRules types and back normalises to (value = 7.0,) same as struct: @@ -506,6 +519,9 @@ end # Broadcasting over Ref is handled specially. Tested elsewhere too. @test gradient(x -> sum(sum, x .* [1,2,3]), Ref([4,5])) == ((x = [6.0, 6.0],),) @test gradient(x -> sum(sum, Ref(x) .* [1,2,3]), [4,5]) == ([6.0, 6.0],) + + # Broadcasting/ Mapping over Mutables with differring graphs for fields + @test gradient(x -> sum(fn1574(x)), mwe1574)[1] == (; idxs = [nothing, nothing], u = [1.0, 1.0, 0.0]) end @testset "mutable accum_param bugs" begin