From 6fa20387515c5d6bcbe7de234245f11bc039ae91 Mon Sep 17 00:00:00 2001
From: Anton Smirnov
Date: Thu, 13 Mar 2025 01:39:09 +0200
Subject: [PATCH] Add AccThunk

---
 src/Zygote.jl          |  1 +
 src/lib/lib.jl         | 27 ++++++++++++++++-----------
 test/features_tests.jl | 12 ++++++++++++
 3 files changed, 29 insertions(+), 11 deletions(-)

diff --git a/src/Zygote.jl b/src/Zygote.jl
index 7f5ffa9c2..4f2ebd5b3 100644
--- a/src/Zygote.jl
+++ b/src/Zygote.jl
@@ -13,6 +13,7 @@ using IRTools
 using MacroTools
 using MacroTools: @forward
 
+import ChainRules
 import Distributed: pmap, CachingPool, workers
 export Params, withgradient, gradient, withjacobian, jacobian, hessian, diaghessian, pullback, pushforward, @code_adjoint
 export rrule_via_ad
diff --git a/src/lib/lib.jl b/src/lib/lib.jl
index 111278496..e9a393b8c 100644
--- a/src/lib/lib.jl
+++ b/src/lib/lib.jl
@@ -1,5 +1,13 @@
-using Base: RefValue
-using Base: ismutabletype
+using Base: RefValue, ismutabletype, tail
+
+struct AccThunk{T} <: AbstractThunk
+    thunks::Vector{T}
+end
+
+_maybe_thunks(x) = x
+_maybe_thunks(x::AccThunk) = x.thunks
+
+ChainRules.unthunk(acc::AccThunk) = reduce(accum, unthunk.(acc.thunks))
 
 # Interfaces
 
@@ -7,9 +15,8 @@ accum() = nothing
 accum(x) = x
 
 accum(x, y) =
-  x === nothing ? y :
-  y === nothing ? x :
-  x + y
+  x === nothing ?
+  y : (y === nothing ? x : x + y)
 
 accum(x, y, zs...) = accum(accum(x, y), zs...)
 
@@ -35,9 +42,9 @@ accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x
 
 accum(x::Nothing, y::AbstractThunk) = y
 accum(x::AbstractThunk, y::Nothing) = x
-accum(x, y::AbstractThunk) = accum(x, unthunk(y))
-accum(x::AbstractThunk, y) = accum(unthunk(x), y)
-accum(x::AbstractThunk, y::AbstractThunk) = accum(unthunk(x), unthunk(y))
+accum(x::AbstractThunk, y::AbstractThunk) = AccThunk(vcat(_maybe_thunks(x), _maybe_thunks(y)))
+accum(x::AbstractThunk, y) = AccThunk(vcat(_maybe_thunks(x), [y]))
+accum(x, y::AbstractThunk) = AccThunk(vcat([x], _maybe_thunks(y)))
 
 # Core functions
 @_adjoint_keepthunks deepcopy(x) = deepcopy(x), ȳ -> (ȳ,)
@@ -55,7 +62,7 @@ accum_param(::Context{false}, _, Δ) = Δ
   isbitstype(x) && return :(Δ)
   quote
     if haskey(cache(cx), x)
-      cache(cx)[x] = accum(cache(cx)[x],Δ)
+      cache(cx)[x] = accum(cache(cx)[x], Δ)
       return
     else
       return Δ
@@ -96,8 +103,6 @@ end
 
 # Tuples
 
-using Base: tail
-
 @_adjoint_keepthunks tuple(xs...) = xs, identity
 
 @_adjoint_keepthunks function literal_getindex(xs::NTuple{N,Any}, ::Val{i}) where {N,i}
diff --git a/test/features_tests.jl b/test/features_tests.jl
index a81a9d4be..3b264e9da 100644
--- a/test/features_tests.jl
+++ b/test/features_tests.jl
@@ -882,4 +882,16 @@ end
   @test g1[1] ≈ g2[1] ≈ g3[1]
 end
 
+# https://github.com/FluxML/Flux.jl/issues/2585
+@testset "Nested thunks" begin
+    W = ones(Float32, 10, 10)
+    x = [ones(Float32, 10) for i in 1:512]
+    gs = gradient(W) do W
+        sum((W * xi)[1] for xi in x)
+    end
+    dW = gs[1]
+    @test all(dW[1, :] .== 512)
+    @test all(dW[2:end, :] .== 0)
+end
+
 end