14 changes: 11 additions & 3 deletions src/compiler/chainrules.jl
@@ -130,16 +130,24 @@ Convert `x` from the format Zygote uses internally to differentials types ChainR
end

"""
_project(x, dx)
project(x, dx)

Maps `_project` over the inputs `x` and gradients `dx`, using `ChainRulesCore.ProjectTo` to standardise each gradient for type & shape.
Also handles some Zygote-specific corrections, such as `x::Array, dx::Tuple`.
Safe to apply to arbitrary input.
"""
function project(x, dx)
map(_project, x, dx)
end

project(x, ::Union{AbstractZero,Nothing}) = nothing

@inline function _project(x, dx)
wrap_chainrules_output(ProjectTo(x)(wrap_chainrules_input(dx)))
end

_project(x, ::Union{AbstractZero,Nothing}) = NoTangent()

# Restore splatted arrays
_project(x::AbstractArray, dx::Tuple) = _project(x, reshape(collect(dx), axes(x)))

@@ -148,13 +156,13 @@ _project(x::AbstractArray, dx::Tuple) = _project(x, reshape(collect(dx), axes(x)
(::ChainRulesCore.ProjectTo)(::Nothing) = ChainRulesCore.NoTangent()

# CRC likes Tangent{<:Complex}, but Zygote makes Tangent{Any}
(project::ProjectTo{<:Complex})(dx::Tangent) = project(Complex(dx.re, dx.im))
(proj::ProjectTo{<:Complex})(dx::Tangent) = proj(Complex(dx.re, dx.im))

# CRC likes Tangent{AbstractArray}, but Zygote makes Tangent{Any}
# in particular this would hit https://github.com/JuliaDiff/ChainRulesCore.jl/blob/2ec2549b73b22bc08f554dae864fb650cfb9c3d7/src/projection.jl#L139
# if we were not losing track of the Primal in the Tangent
# This type piracy is just giving up that safety check.
(project::ProjectTo{AbstractArray})(dx::Tangent) = dx
(proj::ProjectTo{AbstractArray})(dx::Tangent) = dx

"""
ZBack{F}(back) <: Function
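The new user-facing entry point is `project(x, dx)`, which maps the internal `_project` over a tuple of inputs and the matching tuple of raw gradients. A minimal CPU sketch of the intended call pattern, mirroring the CUDA tests further down (the `Diagonal` input is only an illustrative stand-in for any structured array, and the exact raw-gradient types are an assumption here):

```julia
using Zygote, LinearAlgebra

x = Diagonal([1.0, 2.0])
gs = Zygote.gradient(x -> sum(exp.(x)), x)  # raw pullback output, no projection applied
ps = Zygote.project((x,), gs)               # standardise each gradient for type & shape
ps[1] isa Diagonal                          # ProjectTo should restore the Diagonal structure
```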
14 changes: 4 additions & 10 deletions src/compiler/interface.jl
@@ -42,7 +42,7 @@ function pullback(f, args...)
end

sensitivity(y::Number) = one(y)
sensitivity(y::Complex) = error("Output is complex, so the gradient is not defined.")
# sensitivity(y::Complex) = error("Output is complex, so the gradient is not defined.")
sensitivity(y::AbstractArray) = error("Output is an array, so the gradient is not defined. Perhaps you wanted jacobian.")
sensitivity(y) = error("Output should be scalar; gradients are not defined for output $(repr(y))")

@@ -57,7 +57,7 @@ the derivative (for scalar `x`) or the gradient.
See also [`withgradient`](@ref) to keep the value `f(args...)`,
and [`pullback`](@ref) for value and back-propagator.

```jldoctest; setup=:(using Zygote)
```jldoctest; setup = :(using Zygote)
julia> gradient(*, 2.0, 3.0, 5.0)
(15.0, 10.0, 6.0)

@@ -74,14 +74,9 @@ julia> gradient([7, 11], 0, 1) do x, y, d
function gradient(f, args...)
y, back = pullback(f, args...)
grad = back(sensitivity(y))
isnothing(grad) ? nothing : map(_project, args, grad)
end

# Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy!
Base.adjoint(f::Function) = x -> begin # still piracy! avoids projection for legacy reasons
y, back = pullback(f, x)
back(sensitivity(y))[1]
end
Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy!

"""
withgradient(f, args...)
@@ -101,8 +96,7 @@ true
function withgradient(f, args...)
y, back = pullback(f, args...)
grad = back(sensitivity(y))
results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad)
(val=y, grad=results)
(val = y, grad = grad)
end

# Param-style wrappers
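With projection removed from `gradient` and `withgradient`, both now return the raw pullback output, and standardisation becomes opt-in via `Zygote.project`. A sketch of the caller-side pattern (plain-scalar results are unchanged, since projection was a no-op for them; the final printed value is an assumption about the display):

```julia
julia> using Zygote

julia> gradient(*, 2.0, 3.0, 5.0)  # unchanged for plain scalars
(15.0, 10.0, 6.0)

julia> gs = gradient(x -> sum(abs2, x), [1.0, 2.0]);  # raw gradient, no projection

julia> Zygote.project(([1.0, 2.0],), gs)  # opt-in: standardise against the inputs
([2.0, 4.0],)
```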
2 changes: 1 addition & 1 deletion src/compiler/reverse.jl
@@ -248,7 +248,7 @@ branchfor(ir, (from,to)) =
get(filter(br -> br.block == to, branches(block(ir, from))), 1, nothing)

xaccum(ir) = nothing
xaccum(ir, x) = x
# xaccum(ir, x) = x
xaccum(ir, xs...) = push!(ir, xcall(Zygote, :accum, xs...))

function adjoint(pr::Primal)
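Dropping the single-argument fast path means every gradient contribution, including a lone one, is routed through `Zygote.accum` in the generated IR. By default that is a no-op, so this only creates a hook point; for example, the commented-out complex-to-real rule in src/lib/lib.jl below would then apply even when a variable has exactly one contribution. A tiny sketch of the identity default:

```julia
using Zygote

Zygote.accum(3.0) == 3.0             # single-argument accum is the identity by default
Zygote.accum([1.0], [2.0]) == [3.0]  # multi-argument accumulation is unchanged
```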
2 changes: 1 addition & 1 deletion src/lib/array.jl
@@ -38,7 +38,7 @@ end
dxv = view(dx, inds...)
dxv .= accum.(dxv, _droplike(dy, dxv))
end
return (_project(x, dx), map(_->nothing, inds)...)
return (dx, map(_ -> nothing, inds)...)
end

"""
17 changes: 9 additions & 8 deletions src/lib/broadcast.jl
@@ -45,19 +45,20 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr
Base.reducedim_initarray(A, region, nothing, Union{Nothing,eltype(A)})
end

trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
trim(x::Tuple, Δ) = NTuple{length(x)}(Δ)

function unbroadcast(x::AbstractArray, x̄)
N = ndims(x̄)
if length(x) == length(x̄)
_project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors
else
dims = ntuple(d -> size(x, d) == 1 ? d : ndims(x̄)+1, ndims(x̄))
_project(x, accum_sum(x̄; dims = dims))
end
size(x) == size(x̄) ? x̄ :
length(x) == length(x̄) ? trim(x, x̄) :
trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))
end

unbroadcast(x::Number, x̄) = accum_sum(x̄)
unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
unbroadcast(x::Tuple, x̄) = NTuple{length(x)}(length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1
unbroadcast(x::Base.RefValue, x̄) = (x = accum_sum(x̄),)
unbroadcast(x::Tuple, x̄) = trim(x, length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims = 2:ndims(x̄))) # case length(x) > 1

unbroadcast(x::AbstractArray, x̄::Nothing) = nothing

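The restored `unbroadcast` no longer projects: it only sums the broadcast gradient back down over the dimensions where the input was singleton, then `trim`s it to the input's number of dimensions. A standalone sketch of the shape logic, with plain `sum` standing in for the internal `accum_sum`:

```julia
x  = ones(1, 3)   # input that was broadcast against a (2, 3) array
x̄  = ones(2, 3)   # incoming gradient, in the broadcast shape

# sum over dims where x is singleton; ndims(x̄) + 1 marks "keep this dim"
dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄) + 1, Val(ndims(x̄)))  # (1, 3)
x̄sum = sum(x̄; dims = dims)  # 1×3, matching x

# trim reshapes the reduced gradient down to exactly ndims(x) dimensions
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
trim(x, x̄sum) == 2 .* ones(1, 3)  # true
```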
9 changes: 8 additions & 1 deletion src/lib/lib.jl
@@ -13,6 +13,7 @@ end

accum() = nothing
accum(x) = x
# accum(x::AbstractArray{<:Complex}) = real(x)

accum(x, y) =
x === nothing ? y :
@@ -22,8 +23,14 @@ accum(x, y) =
accum(x, y, zs...) = accum(accum(x, y), zs...)

accum(x::Tuple, ys::Tuple...) = accum.(x, ys...)
accum(x::AbstractArray, ys::AbstractArray...) = accum.(x, ys...)
# use promotion rules for T, S...; x needs to be the widest type
function accum(x::T, ys::AbstractArray...) where {T <: AbstractArray}
accum.(convert.(T, (x, ys...))...)
end

# function accum(x::AbstractArray, ys::AbstractArray{<:Complex}...)
# accum.(real(x), real.(ys)...)
# end
@generated function accum(x::NamedTuple, y::NamedTuple)
# assumes that y has no keys apart from those also in x
fieldnames(y) ⊆ fieldnames(x) || throw(ArgumentError("$y keys must be a subset of $x keys"))
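The new array method converts every incoming gradient to the type of `x` before accumulating, per the "x needs to be the widest type" comment, so a `Float32` gradient no longer narrows a `Float64` accumulator. A quick sketch of the conversion it relies on (values are hypothetical):

```julia
x = [1.0, 2.0]          # Vector{Float64} -- must be the widest type
y = Float32[0.5, 0.5]   # Vector{Float32} gradient contribution

xc, yc = convert.(typeof(x), (x, y))  # both Vector{Float64} after conversion
xc .+ yc                              # elementwise accum: [1.5, 2.5], still Float64
```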
46 changes: 35 additions & 11 deletions test/cuda.jl
@@ -39,15 +39,25 @@ end
g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1] # was Can't differentiate gc_preserve_end expression
@test_skip cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1] # was KernelException -- not fixed by PR #1018
@test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1]
end

@testset "Projection" begin
# Projection: eltype preservation:
@test gradient(x -> 2.3 * sum(x.^4), a_gpu)[1] isa CuArray{Float32}
@test_skip gradient(x -> sum(x .* 5.6), a_gpu)[1] isa CUDA.CuArray{Float32} # dot(x::CuArray{Float64}, y::CuArray{Float32}) fallback
# structure restoration:
@test gradient(x -> sum(sqrt.(x)), a_gpu')[1] isa Adjoint # previously a matrix
@test gradient(x -> sum(exp.(x)), Diagonal(a_gpu))[1] isa Diagonal
adj_gs = gradient(x -> sum(sqrt.(x)), a_gpu')
adj_p = Zygote.project((a_gpu',), adj_gs)
@test adj_p[1] isa Adjoint # previously a matrix

diag_gs = gradient(x -> sum(exp.(x)), Diagonal(a_gpu))
diag_p = Zygote.project((Diagonal(a_gpu),), diag_gs)
@test diag_p[1] isa Diagonal

# non-differentiables
@test gradient((x,y) -> sum(x.^2 .+ y'), a_gpu, a_gpu .> 0)[2] === nothing
gs = gradient((x,y) -> sum(x.^2 .+ y'), a_gpu, a_gpu .> 0)
p = Zygote.project((a_gpu, a_gpu .> 0), gs)
@test p[2] === nothing
end

@testset "sum(f, x)" begin
@@ -65,11 +75,18 @@ end
g2_gpu = gradient(f2, a_gpu)[1]
@test g2_gpu isa CuArray
@test g2_gpu |> collect ≈ g2
end

f3(x) = sum(y->y^3, x') # anonymous function
g3 = gradient(f3, a')[1]
g3_gpu = gradient(f3, a_gpu')[1]
@test g3_gpu isa Adjoint{Float32, <:CuArray{Float32, 1}} # preserves structure
@testset "Projection: sums" begin
a = Float32[-1.5, -9.0, 2.4, -1.3, 0.01]
a_gpu = a |> cu

f(x) = sum(y -> y ^ 3, x') # anonymous function
g3 = gradient(f, a')[1]
g3_gpu = gradient(f, a_gpu')[1]

g3_gpu_p = Zygote.project((a_gpu',), (g3,))
@test g3_gpu_p[1] isa Adjoint{Float32, <:CuArray{Float32, 1}} # preserves structure with projection
@test g3_gpu |> collect ≈ g3
end

@@ -133,10 +150,17 @@ end
grads = (cu(ones(Float32, 3)), 1.f0)
@test gradient((x,y) -> sum(vcat(x,y)), r, 5) == grads

@test gradient((x,y) -> sum(vcat(x,y)), r, Float64(5))[1] isa CUDA.CuArray{Float32}
@test gradient((x,y) -> sum(vcat(x,y)), r, Float64(5))[2] isa Float64 # projection
# Projection
@testset "Projection" begin
r = cu(rand(Float32, 3))

gs = gradient((x,y) -> sum(vcat(x,y)), r, Float64(5))
gs_p = Zygote.project((r, 5.), gs)
@test gs_p[1] isa CUDA.CuArray{Float32}
@test gs_p[2] isa Float64 # projection

@test_skip gradient((x,y) -> sum(vcat(x,y)), 5f0, r)[2] isa CUDA.CuArray{Float32} # wrong order
@test_skip gradient((x,y) -> sum(vcat(x,y)), 1f0, r, 2f0, r)[2] isa CUDA.CuArray{Float32}
@test_skip gradient((x,y) -> sum(vcat(x,y)), 5f0, r)[2] isa CUDA.CuArray{Float32} # wrong order
@test_skip gradient((x, y...) -> sum(vcat(x, y...)), 1f0, r, 2f0, r)[2] isa CUDA.CuArray{Float32}
end
end