Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ function accum(x::RefValue, y::RefValue)
@assert x === y
return x
end
function accum(x::NamedTuple, ref::RefValue)
fieldnames(typeof(ref[])) ⊆ fieldnames(typeof(x)) || throw(ArgumentError("$(ref[]) keys from Ref must be a subset of $x keys"))
ref
end
function accum(ref::RefValue, x::NamedTuple)
fieldnames(typeof(x)) ⊆ fieldnames(typeof(ref[])) || throw(ArgumentError("$x keys must be a subset of $(ref[]) keys from Ref"))
ref
end

accum(x::NamedTuple, y::ChainRulesCore.Tangent) = accum(x, wrap_chainrules_output(y))
accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x), y)
Expand Down
16 changes: 16 additions & 0 deletions test/features_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,19 @@ let
@test back(1.) == ((1.0,),)
end

mutable struct MWEGetter{G, U}
idxs::G
u::U
end

u = ones(3)
idxs = [1, 2]
mwe1574 = MWEGetter(idxs, u)

function fn1574(mwe)
map(i -> mwe.u[i], mwe.idxs)
end

@testset "mutable struct, including Ref" begin
# Zygote's representation is Base.RefValue{Any}((value = 7.0,)), but the
# map to ChainRules types and back normalises to (value = 7.0,) same as struct:
Expand All @@ -506,6 +519,9 @@ end
# Broadcasting over Ref is handled specially. Tested elsewhere too.
@test gradient(x -> sum(sum, x .* [1,2,3]), Ref([4,5])) == ((x = [6.0, 6.0],),)
@test gradient(x -> sum(sum, Ref(x) .* [1,2,3]), [4,5]) == ([6.0, 6.0],)

# Broadcasting/ Mapping over Mutables with differring graphs for fields
@test gradient(x -> sum(fn1574(x)), mwe1574)[1] == (; idxs = [nothing, nothing], u = [1.0, 1.0, 0.0])
end

@testset "mutable accum_param bugs" begin
Expand Down
Loading