Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using IRTools
using MacroTools
using MacroTools: @forward

import ChainRules
import Distributed: pmap, CachingPool, workers
export Params, withgradient, gradient, withjacobian, jacobian, hessian, diaghessian, pullback, pushforward, @code_adjoint
export rrule_via_ad
Expand Down
27 changes: 16 additions & 11 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
using Base: RefValue
using Base: ismutabletype
using Base: RefValue, ismutabletype, tail

"""
    AccThunk{E} <: AbstractThunk

Thunk that stores the operands of a pending accumulation instead of adding
them right away. `thunks` holds the operands in the order they were
accumulated; entries may be thunks or already-materialised values.
"""
struct AccThunk{E} <: AbstractThunk
    thunks::Vector{E}
end

# Operand(s) to splice into a new `AccThunk`: an existing `AccThunk` contributes
# its stored vector, anything else is passed through unchanged.
function _maybe_thunks(acc::AccThunk)
    return acc.thunks
end

function _maybe_thunks(other)
    return other
end

# Materialise an `AccThunk`: unthunk every stored operand and combine the results
# with `accum`. `mapreduce` applies `unthunk` as it folds, so no temporary vector
# holding every unthunked operand is built first; the reduction order is the same
# as `reduce(accum, unthunk.(acc.thunks))`.
ChainRules.unthunk(acc::AccThunk) = mapreduce(unthunk, accum, acc.thunks)

# Interfaces

# Base cases of gradient accumulation: with no contributions the result is
# `nothing`; a single contribution is returned unchanged.
function accum()
    return nothing
end

function accum(x)
    return x
end

"""
    accum(x, y)

Combine two gradient contributions. `nothing` acts as the neutral element:
if either argument is `nothing` the other one is returned as-is, otherwise
the result is `x + y`.
"""
function accum(x, y)
    # `===` keeps the `nothing` checks cheap and inferable.
    x === nothing && return y
    y === nothing && return x
    return x + y
end

# Three or more contributions: combine the first two, then recurse on the rest.
# Recursing (rather than folding pairwise) lets dispatch pick whichever `accum`
# method is most specific for the remaining arguments.
function accum(x, y, zs...)
    head = accum(x, y)
    return accum(head, zs...)
end

Expand All @@ -35,9 +42,9 @@ accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x
# `nothing` combined with a thunk returns the thunk itself, still unevaluated.
accum(::Nothing, y::AbstractThunk) = y
accum(x::AbstractThunk, ::Nothing) = x

# Accumulating with a thunk does not evaluate anything: the operands are collected
# into a single flat `AccThunk`, which combines them with `accum` only when it is
# itself unthunked. An operand that is already an `AccThunk` contributes its stored
# operands (via `_maybe_thunks`) rather than being nested; a non-thunk operand is
# wrapped in a one-element vector so `vcat` appends it as a single entry.
#
# Each signature is defined exactly once. Keeping the earlier eager-`unthunk`
# definitions next to these would overwrite methods with identical signatures.
accum(x::AbstractThunk, y::AbstractThunk) = AccThunk(vcat(_maybe_thunks(x), _maybe_thunks(y)))
accum(x::AbstractThunk, y) = AccThunk(vcat(_maybe_thunks(x), [y]))
accum(x, y::AbstractThunk) = AccThunk(vcat([x], _maybe_thunks(y)))

# Core functions
# Adjoint for `deepcopy`: the forward pass returns `deepcopy(x)`, and the pullback
# hands the incoming cotangent `ȳ` back unchanged as the 1-tuple of input gradients.
# NOTE(review): `@_adjoint_keepthunks` is defined elsewhere; presumably it leaves
# thunked cotangents unevaluated before calling the pullback — confirm.
@_adjoint_keepthunks deepcopy(x) = deepcopy(x), ȳ -> (ȳ,)
Expand All @@ -55,7 +62,7 @@ accum_param(::Context{false}, _, Δ) = Δ
isbitstype(x) && return :(Δ)
quote
if haskey(cache(cx), x)
cache(cx)[x] = accum(cache(cx)[x],Δ)
cache(cx)[x] = accum(cache(cx)[x], Δ)
return
else
return Δ
Expand Down Expand Up @@ -96,8 +103,6 @@ end

# Tuples

using Base: tail

# Adjoint for `tuple`: the forward result is the argument tuple `xs` itself, and the
# pullback is `identity`, so the cotangent is returned unchanged as the gradients of
# the arguments.
@_adjoint_keepthunks tuple(xs...) = xs, identity

@_adjoint_keepthunks function literal_getindex(xs::NTuple{N,Any}, ::Val{i}) where {N,i}
Expand Down
12 changes: 12 additions & 0 deletions test/features_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -882,4 +882,16 @@ end
@test g1[1] ≈ g2[1] ≈ g3[1]
end

# https://github.com/FluxML/Flux.jl/issues/2585
@testset "Nested thunks" begin
    # 512 input vectors, so the gradient w.r.t. `W` accumulates 512 contributions.
    inputs = [ones(Float32, 10) for _ in 1:512]
    W = ones(Float32, 10, 10)
    grads = gradient(W) do W
        sum((W * xi)[1] for xi in inputs)
    end
    dW = first(grads)
    # Only the first row of `W` enters the loss; every input is all ones, so each
    # entry of that row collects a gradient of 1 per input, and the rest stay zero.
    @test all(==(512), dW[1, :])
    @test all(==(0), dW[2:end, :])
end

end
Loading