Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ MacroTools = "0.5"
NaNMath = "0.3, 1"
Requires = "1.1"
SpecialFunctions = "1.6, 2"
ZygoteRules = "0.2.1"
ZygoteRules = "0.2.3"
julia = "1.6"

[extras]
Expand Down
2 changes: 1 addition & 1 deletion src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using LinearAlgebra, Statistics
using LinearAlgebra: copytri!, AbstractTriangular

import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
literal_getproperty, literal_getfield
literal_getproperty, literal_getfield, maybe_final, @adjoint_final

using ChainRulesCore
using ChainRules: ChainRules, rrule, unthunk, canonicalize
Expand Down
38 changes: 29 additions & 9 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ struct ZygoteRuleConfig{CTX<:AContext} <: RuleConfig{Union{HasReverseMode,NoForw
end
ZygoteRuleConfig() = ZygoteRuleConfig(Context())

@inline only_once(::Type{<:AContext}) = false # can't directly use Context{true,true} as not defined yet

_is_rrule_redispatcher(m::Method) = m.sig == Tuple{typeof(rrule), RuleConfig, Vararg}

Expand Down Expand Up @@ -195,17 +196,26 @@ _project(x::AbstractArray, dx::Tuple) = _project(x, reshape(collect(dx), axes(x)
(project::ProjectTo{AbstractArray})(dx::Tangent) = dx

"""
ZBack{F}(back) <: Function
ZBack{Y,F}(y, back) <: Function

Wrapper for a ChainRules pullback `back`, that causes it to follow Zygote conventions.
(A functor here is used rather than a closure to avoid boxing issues);
Now captures the forward result to call `finalize(y)` when done, if `only_once` says this is safe.
"""
struct ZBack{F} <: Function
struct ZBack{Y,F} <: Function
fwd::Y
back::F
end
@inline (s::ZBack)(dy) = wrap_chainrules_output(s.back(wrap_chainrules_input(dy)))
@inline (s::ZBack{Nothing})(dy) = wrap_chainrules_output(s.back(wrap_chainrules_input(dy)))
@inline function (s::ZBack)(dy)
∇s = wrap_chainrules_output(s.back(wrap_chainrules_input(dy)))
maybe_final(y)
∇s
end

# `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603
# though it might be worth keeping as a performance optimization (benchmarking pending)
@inline (s::ZBack{Nothing})(::Nothing) = nothing
@inline (s::ZBack)(::Nothing) = nothing

"""
Expand All @@ -214,9 +224,11 @@ end
Returns a the (primal) value of `f(args...)` and a pullback, by invoking `ChainRulesCore.rrule(f, args...)`.
The pullback is appropriately wrapped up to follow Zygote conventions.
"""
@inline function chain_rrule(config, f, args...)
# @inline function chain_rrule(config::ZygoteRuleConfig{Context{I,O}}, f::F, args...) where {I,O,F}
@inline function chain_rrule(config::ZygoteRuleConfig{C}, f::F, args...) where {C,F}
y, back = rrule(config, f, args...)
return y, ZBack(back)
free = only_once(C) ? y : nothing
return y, ZBack(free, back)
end


Expand All @@ -226,10 +238,12 @@ end
As per [`chain_rrule`](@ref) but with support for kwargs.
`kwf` should be the kwfunc matching to `f`, and `kwargs` are a `NamedTuple` of keyword arguments.
"""
@inline function chain_rrule_kw(config, kwf, kwargs, f, args...)
# @inline function chain_rrule_kw(config::ZygoteRuleConfig{Context{I,O}}, kwf, kwargs, f::F, args...) where {I,O,F}
@inline function chain_rrule_kw(config::ZygoteRuleConfig{C}, kwf, kwargs, f::F, args...) where {C,F}
y, back = rrule(config, f, args...; kwargs...)
free = only_once(C) ? y : nothing
function kw_zpullback(dy)
dxs = ZBack(back)(dy)
dxs = ZBack(free, back)(dy)
if dxs === nothing # if dxs is nothing, then all partiaols are nothing
# Zygote convention is a single nothing no mather how partials, if all are nothing
return nothing
Expand All @@ -240,7 +254,8 @@ As per [`chain_rrule`](@ref) but with support for kwargs.
return y, kw_zpullback
end

function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig, f_args...; kwargs...)
# function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig{Context{I,O}}, f_args...; kwargs...) where {I,O}
function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig{C}, f_args...; kwargs...) where {C}
# first check whether there is an `rrule` which handles this directly
direcct = rrule(config, f_args...; kwargs...)
direcct === nothing || return direcct
Expand All @@ -255,7 +270,12 @@ function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig, f_args...; kwargs
_pullback(config.context, f_args...)
end

ad_pullback(Δ) = zygote2differential(pb(wrap_chainrules_output(Δ)), f_args)
free = only_once(C) ? y : nothing
function ad_pullback(Δ)
∇s = zygote2differential(pb(wrap_chainrules_output(Δ)), f_args)
maybe_final(free)
∇s
end
return y, ad_pullback
end

Expand Down
113 changes: 82 additions & 31 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,20 @@ import Base.Broadcast: broadcasted, materialize!
# Internal container used to track accumulated gradients of mutable types (including params).
# Type param I ∈ (true, false) indicates whether implicit params are in use.
# By default, this should be false unless pullback(f, ::Params) is called.
mutable struct Context{I} <: AContext
# Type parameter O ∈ (true, false) indecates whether we know the reverse pass will be
# run at most once (e.g. within gradient), defaults to false (for pullback, and jacobain).
mutable struct Context{I,O} <: AContext
cache::Union{IdDict{Any,Any},Nothing}
end

Context() = Context{false}(nothing)
Context() = Context{false,false}(nothing)
Context{I}(cache=nothing) where {I} = Context{I,false}(cache)
Context{I,O}() where {I,O} = Context{I,O}(nothing)

cache(cx::Context) = cx.cache === nothing ? (cx.cache = IdDict()) : cx.cache

@inline only_once(::Type{<:Context{true,true}}) = true

struct Pullback{S,T}
t::T
end
Expand Down Expand Up @@ -93,7 +99,9 @@ julia> gradient([7, 11], 0, 1) do x, y, d
```
"""
function gradient(f, args...)
y, back = pullback(f, args...)
# Type parameters for Context are implicit=false, once=true
cx = Context{false,true}(nothing)
y, back = pullback(f, cx, args...)
grad = back(sensitivity(y))
isnothing(grad) ? nothing : map(_project, args, grad)
end
Expand All @@ -104,6 +112,21 @@ Base.adjoint(f::Function) = x -> begin # still piracy! avoids projection for le
back(sensitivity(y))[1]
end

# This is inserted into @adjoint_final by ZygoteRules
@inline maybe_final(::Context{false,true}, x) = maybe_final(x)
# The goal is to free CuArrays promptly.
@inline maybe_final(x::DenseArray) = finalize(x)

# Without an @adjoint rule for this, some hessian tests fail:
# Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_finalize_th), Nothing
# And if it in fact finalises, then other 2nd derivative tests fail. So do nothing:
@adjoint maybe_final(x) = nothing, _ -> nothing
@adjoint maybe_final(::Context{false,true}, x) = nothing, _ -> nothing

# Probably just for testing:
maybe_final(x::Vector) = resize!(x, 0)
maybe_final(x::Array{<:AbstractFloat}) = fill!(x, NaN)

"""
withgradient(f, args...)
withgradient(f, ::Params)
Expand All @@ -129,40 +152,16 @@ julia> res.grad[w]
```
"""
function withgradient(f, args...)
y, back = pullback(f, args...)
# Type parameters for Context are implicit=false, once=true
cx = Context{false,true}()
y, back = pullback(f, cx, args...)
grad = back(sensitivity(y))
results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad)
(val=y, grad=results)
end

# Param-style wrappers

"""
gradient(() -> loss(), ps::Params) -> Grads

Gradient with implicit parameters. Takes a zero-argument function,
and returns a dictionary-like container, whose keys are arrays `x in ps`.

See also [`withgradient`](@ref) to keep the value `loss()`.

```jldoctest; setup=:(using Zygote)
julia> x = [1 2 3; 4 5 6]; y = [7, 8]; z = [1, 10, 100];

julia> g = gradient(Params([x, y])) do
sum(x .* y .* z')
end
Grads(...)

julia> g[x]
2×3 Matrix{Float64}:
7.0 70.0 700.0
8.0 80.0 800.0

julia> haskey(g, z) # only x and y are parameters
false
```
"""
gradient
# Param-style wrappers

"""
Params([A, B])
Expand Down Expand Up @@ -391,6 +390,58 @@ function pullback(f, ps::Params)
end
end

"""
gradient(() -> loss(), ps::Params) -> Grads

Gradient with implicit parameters. Takes a zero-argument function,
and returns a dictionary-like container, whose keys are arrays `x in ps`.

See also [`withgradient`](@ref) to keep the value `loss()`.

```jldoctest; setup=:(using Zygote)
julia> x = [1 2 3; 4 5 6]; y = [7, 8]; z = [1, 10, 100];

julia> g = gradient(Params([x, y])) do
sum(x .* y .* z')
end
Grads(...)

julia> g[x]
2×3 Matrix{Float64}:
7.0 70.0 700.0
8.0 80.0 800.0

julia> haskey(g, z) # only x and y are parameters
false
```
"""
function gradient(f, ps::Params)
y, back = pullback(f, ps)
back(sensitivity(y))
end

"""
withgradient(f, ps::Params) -> Grads

Returns both the value of the function and the [`gradient`](@ref),
as a named tuple.

```jldoctest; setup=:(using Zygote)
julia> w = [3.0];

julia> res = withgradient(() -> sum(abs2, w), Params([w])) # implicit mode
(val = 9.0, grad = Grads(...))

julia> res.grad[w]
1-element Vector{Float64}:
6.0
```
"""
function withgradient(f, ps::Params)
y, back = pullback(f, ps)
(val=y, grad=back(sensitivity(y)))
end

# Code Reflection

function code_ir(f, T)
Expand Down
2 changes: 2 additions & 0 deletions src/compiler/interface2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ end
chain_rrule_f = :chain_rrule
end

# Here ZygoteRuleConfig{Zygote.Context{false, true}} is passed to chain_rrule

hascr, cr_edge = has_chain_rrule(cr_T)
hascr && return :($chain_rrule_f(ZygoteRuleConfig(ctx), f, args...))

Expand Down
4 changes: 3 additions & 1 deletion src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap)]
ys = map(first, ys_and_backs)
arg_ax = map(_tryaxes, args)
function map_back(Δ)
if Base.issingletontype(F) && length(args) == 1
∇s = if Base.issingletontype(F) && length(args) == 1
Δarg = $mapfunc(((_,pb), δ) -> last_or_nothing(pb(δ)), ys_and_backs, Δ) # No unzip needed
(nothing, Δarg)
elseif Base.issingletontype(F)
Expand All @@ -207,6 +207,8 @@ for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap)]
Δargs = map(_restore, Δf_and_args[2:end], arg_ax)
(Δf, Δargs...)
end
maybe_final(cx, ys_and_backs)
∇s
end
map_back(::Nothing) = nothing
return ys, map_back
Expand Down
Loading