From d8963906c8f7d0b958cd029a9b868236b84ed735 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Sat, 24 May 2025 04:08:13 +0530 Subject: [PATCH 1/9] fix: accumulate mutable structs properly --- src/lib/lib.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 111278496..3e3b0e2ce 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -15,6 +15,7 @@ accum(x, y, zs...) = accum(accum(x, y), zs...) accum(x::Tuple, ys::Tuple...) = map(accum, x, ys...) accum(x::AbstractArray, ys::AbstractArray...) = Base.broadcast_preserving_zero_d(accum, x, ys...) +accum(::Tuple{}, ::NamedTuple{}) = () @generated function accum(x::NamedTuple, y::NamedTuple) # assumes that y has no keys apart from those also in x @@ -230,7 +231,7 @@ end else dx = grad_mut(__context__, x) dx[] = (; dx[]..., pair(Val(f), accum(getfield(dx[], f), Δ))...) - return (dx,nothing) + return (dx[],nothing) end end unwrap(val), back From 98f9817b2fe7b606371ddbd8f8fc19e0a0b30a2d Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Tue, 27 May 2025 01:27:18 +0530 Subject: [PATCH 2/9] chore: deref RefValue while accumulating --- src/lib/lib.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 3e3b0e2ce..961f2df9b 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -15,7 +15,6 @@ accum(x, y, zs...) = accum(accum(x, y), zs...) accum(x::Tuple, ys::Tuple...) = map(accum, x, ys...) accum(x::AbstractArray, ys::AbstractArray...) = Base.broadcast_preserving_zero_d(accum, x, ys...) -accum(::Tuple{}, ::NamedTuple{}) = () @generated function accum(x::NamedTuple, y::NamedTuple) # assumes that y has no keys apart from those also in x @@ -29,6 +28,8 @@ function accum(x::RefValue, y::RefValue) @assert x === y return x end +accum(x, ref::RefValue) = accum(x, ref[]) +accum(ref::RefValue, x) = accum(ref[], x) accum(x::NamedTuple, y::ChainRulesCore.Tangent) = accum(x, wrap_chainrules_output(y)) accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x), y) @@ -231,7 +232,7 @@ end else dx = grad_mut(__context__, x) dx[] = (; dx[]..., pair(Val(f), accum(getfield(dx[], f), Δ))...) - return (dx[],nothing) + return (dx,nothing) end end unwrap(val), back From 49d93b6c1f855ef83e50f2785c6ca2b79d594f47 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Tue, 27 May 2025 02:05:35 +0530 Subject: [PATCH 3/9] chore: restrict for NT only --- src/lib/lib.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 961f2df9b..96dc6487a 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -28,8 +28,8 @@ function accum(x::RefValue, y::RefValue) @assert x === y return x end -accum(x, ref::RefValue) = accum(x, ref[]) -accum(ref::RefValue, x) = accum(ref[], x) +accum(x::NamedTuple, ref::RefValue) = accum(x, ref[]) +accum(ref::RefValue, x::NamedTuple) = accum(ref[], x) accum(x::NamedTuple, y::ChainRulesCore.Tangent) = accum(x, wrap_chainrules_output(y)) accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x), y) From f5a727eeecdee3a2583dfd23c732267d4e4a7a53 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Tue, 27 May 2025 18:51:06 +0530 Subject: [PATCH 4/9] test: mapping with differing graphs Co-authored-by: Aayush https://github.com/AayushSabharwal --- test/features_tests.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/features_tests.jl b/test/features_tests.jl index a81a9d4be..12aca7ae3 100644 --- a/test/features_tests.jl +++ b/test/features_tests.jl @@ -480,6 +480,19 @@ let @test back(1.) == ((1.0,),) end +mutable struct MWEGetter{G, U} + idxs::G + u::U +end + +u = ones(3) +idxs = [1, 2] +mwe1574 = MWEGetter(idxs, u) + +function fn1574(mwe) + map(i -> mwe.u[i], mwe.idxs) +end + @testset "mutable struct, including Ref" begin # Zygote's representation is Base.RefValue{Any}((value = 7.0,)), but the # map to ChainRules types and back normalises to (value = 7.0,) same as struct: @@ -506,6 +519,9 @@ end # Broadcasting over Ref is handled specially. Tested elsewhere too. @test gradient(x -> sum(sum, x .* [1,2,3]), Ref([4,5])) == ((x = [6.0, 6.0],),) @test gradient(x -> sum(sum, Ref(x) .* [1,2,3]), [4,5]) == ([6.0, 6.0],) + + # Broadcasting/ Mapping over Mutables with differring graphs for fields + @test gradient(x -> sum(fn1574(x)), mwe1574)[1] == (; idxs = [nothing, nothing], u = [1.0, 1.0, 0.0]) end @testset "mutable accum_param bugs" begin From 531aa34a85b5bd2a5130e624c64ec987d2de3536 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Tue, 27 May 2025 18:52:17 +0530 Subject: [PATCH 5/9] fix: avoid double accumulation --- src/lib/lib.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 96dc6487a..5079f35cf 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -28,8 +28,8 @@ function accum(x::RefValue, y::RefValue) @assert x === y return x end -accum(x::NamedTuple, ref::RefValue) = accum(x, ref[]) -accum(ref::RefValue, x::NamedTuple) = accum(ref[], x) +accum(x::NamedTuple, ref::RefValue) = ref +accum(ref::RefValue, x::NamedTuple) = ref accum(x::NamedTuple, y::ChainRulesCore.Tangent) = accum(x, wrap_chainrules_output(y)) accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x), y) From c35c9fe50e87074b66d7fe876cabc95fb34f2082 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Tue, 27 May 2025 19:17:54 +0530 Subject: [PATCH 6/9] fix: check compatible trees --- src/lib/lib.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 5079f35cf..9d5e692e3 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -28,8 +28,14 @@ function accum(x::RefValue, y::RefValue) @assert x === y return x end -accum(x::NamedTuple, ref::RefValue) = ref -accum(ref::RefValue, x::NamedTuple) = ref +function accum(x::NamedTuple, ref::RefValue) + fieldnames(ref[]) ⊆ fieldnames(x) || throw(ArgumentError("$(ref[]) keys from Ref must be a subset of $x keys")) + ref +end +function accum(ref::RefValue, x::NamedTuple) + fieldnames(x) ⊆ fieldnames(ref[]) || throw(ArgumentError("$x keys from Ref must be a subset of $(ref[]) keys")) + ref +end accum(x::NamedTuple, y::ChainRulesCore.Tangent) = accum(x, wrap_chainrules_output(y)) accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x), y) From c732724c52c206ac209099e2ad37d6399770173d Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Tue, 27 May 2025 20:03:41 +0530 Subject: [PATCH 7/9] chore: correct fieldnames usage --- src/lib/lib.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 9d5e692e3..8f82d637d 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -29,11 +29,11 @@ function accum(x::RefValue, y::RefValue) return x end function accum(x::NamedTuple, ref::RefValue) - fieldnames(ref[]) ⊆ fieldnames(x) || throw(ArgumentError("$(ref[]) keys from Ref must be a subset of $x keys")) + fieldnames(typeof(ref[])) ⊆ fieldnames(typeof(x)) || throw(ArgumentError("$(ref[]) keys from Ref must be a subset of $x keys")) ref end function accum(ref::RefValue, x::NamedTuple) - fieldnames(x) ⊆ fieldnames(ref[]) || throw(ArgumentError("$x keys from Ref must be a subset of $(ref[]) keys")) + fieldnames(typeof(x)) ⊆ fieldnames(typeof(ref[])) || throw(ArgumentError("$x keys must be a subset of $(ref[]) keys from Ref")) ref end From c4094e712e21f706504e6b41a508b95fb0f4da5a Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 5 Jun 2025 09:30:16 +0200 Subject: [PATCH 8/9] Update src/lib/lib.jl Co-authored-by: Frames White --- src/lib/lib.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 8f82d637d..db37ccae8 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -33,6 +33,7 @@ function accum(x::NamedTuple, ref::RefValue) ref end function accum(ref::RefValue, x::NamedTuple) + # We do not actually do any accumulation here, because the ref will already have been mutated. fieldnames(typeof(x)) ⊆ fieldnames(typeof(ref[])) || throw(ArgumentError("$x keys must be a subset of $(ref[]) keys from Ref")) ref end From 999a335b88582794fd85ff457bf6be63eb986da7 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 5 Jun 2025 09:30:25 +0200 Subject: [PATCH 9/9] Update src/lib/lib.jl Co-authored-by: Frames White --- src/lib/lib.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lib/lib.jl b/src/lib/lib.jl index db37ccae8..59bb11a39 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -29,6 +29,7 @@ function accum(x::RefValue, y::RefValue) return x end function accum(x::NamedTuple, ref::RefValue) + # We do not actually do any accumulation here, because the ref will already have been mutated. fieldnames(typeof(ref[])) ⊆ fieldnames(typeof(x)) || throw(ArgumentError("$(ref[]) keys from Ref must be a subset of $x keys")) ref end