diff --git a/Project.toml b/Project.toml index 82d2f59f0..5668e7c69 100644 --- a/Project.toml +++ b/Project.toml @@ -41,7 +41,7 @@ MacroTools = "0.5" NaNMath = "0.3, 1" Requires = "1.1" SpecialFunctions = "1.6, 2" -ZygoteRules = "0.2.1" +ZygoteRules = "0.2.3" julia = "1.6" [extras] diff --git a/src/Zygote.jl b/src/Zygote.jl index c651968ba..1b0c54d69 100644 --- a/src/Zygote.jl +++ b/src/Zygote.jl @@ -4,7 +4,7 @@ using LinearAlgebra, Statistics using LinearAlgebra: copytri!, AbstractTriangular import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback, - literal_getproperty, literal_getfield + literal_getproperty, literal_getfield, maybe_final, @adjoint_final using ChainRulesCore using ChainRules: ChainRules, rrule, unthunk, canonicalize diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 7c7de8655..d8c7152be 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -3,6 +3,7 @@ struct ZygoteRuleConfig{CTX<:AContext} <: RuleConfig{Union{HasReverseMode,NoForw end ZygoteRuleConfig() = ZygoteRuleConfig(Context()) +@inline only_once(::Type{<:AContext}) = false # can't directly use Context{true,true} as not defined yet _is_rrule_redispatcher(m::Method) = m.sig == Tuple{typeof(rrule), RuleConfig, Vararg} @@ -195,17 +196,26 @@ _project(x::AbstractArray, dx::Tuple) = _project(x, reshape(collect(dx), axes(x) (project::ProjectTo{AbstractArray})(dx::Tangent) = dx """ - ZBack{F}(back) <: Function + ZBack{Y,F}(y, back) <: Function Wrapper for a ChainRules pullback `back`, that causes it to follow Zygote conventions. (A functor here is used rather than a closure to avoid boxing issues); +Now captures the forward result to call `finalize(y)` when done, if `only_once` says this is safe. 
""" -struct ZBack{F} <: Function +struct ZBack{Y,F} <: Function + fwd::Y back::F end -@inline (s::ZBack)(dy) = wrap_chainrules_output(s.back(wrap_chainrules_input(dy))) +@inline (s::ZBack{Nothing})(dy) = wrap_chainrules_output(s.back(wrap_chainrules_input(dy))) +@inline function (s::ZBack)(dy) + ∇s = wrap_chainrules_output(s.back(wrap_chainrules_input(dy))) + maybe_final(s.fwd) + ∇s +end + # `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603 # though it might be worth keeping as a performance optimization (benchmarking pending) +@inline (s::ZBack{Nothing})(::Nothing) = nothing @inline (s::ZBack)(::Nothing) = nothing """ @@ -214,9 +224,11 @@ end Returns a the (primal) value of `f(args...)` and a pullback, by invoking `ChainRulesCore.rrule(f, args...)`. The pullback is appropriately wrapped up to follow Zygote conventions. """ -@inline function chain_rrule(config, f, args...) +# @inline function chain_rrule(config::ZygoteRuleConfig{Context{I,O}}, f::F, args...) where {I,O,F} +@inline function chain_rrule(config::ZygoteRuleConfig{C}, f::F, args...) where {C,F} y, back = rrule(config, f, args...) - return y, ZBack(back) + free = only_once(C) ? y : nothing + return y, ZBack(free, back) end @@ -226,10 +238,12 @@ end As per [`chain_rrule`](@ref) but with support for kwargs. `kwf` should be the kwfunc matching to `f`, and `kwargs` are a `NamedTuple` of keyword arguments. """ -@inline function chain_rrule_kw(config, kwf, kwargs, f, args...) +# @inline function chain_rrule_kw(config::ZygoteRuleConfig{Context{I,O}}, kwf, kwargs, f::F, args...) where {I,O,F} +@inline function chain_rrule_kw(config::ZygoteRuleConfig{C}, kwf, kwargs, f::F, args...) where {C,F} y, back = rrule(config, f, args...; kwargs...) + free = only_once(C) ? 
y : nothing function kw_zpullback(dy) - dxs = ZBack(back)(dy) + dxs = ZBack(free, back)(dy) if dxs === nothing # if dxs is nothing, then all partiaols are nothing # Zygote convention is a single nothing no mather how partials, if all are nothing return nothing @@ -240,7 +254,8 @@ As per [`chain_rrule`](@ref) but with support for kwargs. return y, kw_zpullback end -function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig, f_args...; kwargs...) +# function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig{Context{I,O}}, f_args...; kwargs...) where {I,O} +function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig{C}, f_args...; kwargs...) where {C} # first check whether there is an `rrule` which handles this directly direcct = rrule(config, f_args...; kwargs...) direcct === nothing || return direcct @@ -255,7 +270,12 @@ function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig, f_args...; kwargs _pullback(config.context, f_args...) end - ad_pullback(Δ) = zygote2differential(pb(wrap_chainrules_output(Δ)), f_args) + free = only_once(C) ? y : nothing + function ad_pullback(Δ) + ∇s = zygote2differential(pb(wrap_chainrules_output(Δ)), f_args) + maybe_final(free) + ∇s + end return y, ad_pullback end diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index f350069f4..a9a7d8736 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -7,14 +7,20 @@ import Base.Broadcast: broadcasted, materialize! # Internal container used to track accumulated gradients of mutable types (including params). # Type param I ∈ (true, false) indicates whether implicit params are in use. # By default, this should be false unless pullback(f, ::Params) is called. -mutable struct Context{I} <: AContext +# Type parameter O ∈ (true, false) indicates whether we know the reverse pass will be +# run at most once (e.g. within gradient), defaults to false (for pullback, and jacobian). 
+mutable struct Context{I,O} <: AContext cache::Union{IdDict{Any,Any},Nothing} end -Context() = Context{false}(nothing) +Context() = Context{false,false}(nothing) +Context{I}(cache=nothing) where {I} = Context{I,false}(cache) +Context{I,O}() where {I,O} = Context{I,O}(nothing) cache(cx::Context) = cx.cache === nothing ? (cx.cache = IdDict()) : cx.cache +@inline only_once(::Type{<:Context{<:Any,true}}) = true + struct Pullback{S,T} t::T end @@ -93,7 +99,9 @@ julia> gradient([7, 11], 0, 1) do x, y, d ``` """ function gradient(f, args...) - y, back = pullback(f, args...) + # Type parameters for Context are implicit=false, once=true + cx = Context{false,true}(nothing) + y, back = pullback(f, cx, args...) grad = back(sensitivity(y)) isnothing(grad) ? nothing : map(_project, args, grad) end @@ -104,6 +112,21 @@ Base.adjoint(f::Function) = x -> begin # still piracy! avoids projection for le back(sensitivity(y))[1] end +# This is inserted into @adjoint_final by ZygoteRules +@inline maybe_final(::Context{<:Any,true}, x) = maybe_final(x) +# The goal is to free CuArrays promptly. +@inline maybe_final(x::DenseArray) = finalize(x) + +# Without an @adjoint rule for this, some hessian tests fail: +# Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_finalize_th), Nothing +# And if it in fact finalises, then other 2nd derivative tests fail. So do nothing: +@adjoint maybe_final(x) = nothing, _ -> nothing +@adjoint maybe_final(::Context, x) = nothing, _ -> nothing + +# Probably just for testing: +maybe_final(x::Vector) = resize!(x, 0) +maybe_final(x::Array{<:AbstractFloat}) = fill!(x, NaN) + """ withgradient(f, args...) withgradient(f, ::Params) @@ -129,40 +152,16 @@ julia> res.grad[w] ``` """ function withgradient(f, args...) - y, back = pullback(f, args...) + # Type parameters for Context are implicit=false, once=true + cx = Context{false,true}() + y, back = pullback(f, cx, args...) grad = back(sensitivity(y)) results = isnothing(grad) ? 
map(_ -> nothing, args) : map(_project, args, grad) (val=y, grad=results) end -# Param-style wrappers - -""" - gradient(() -> loss(), ps::Params) -> Grads - -Gradient with implicit parameters. Takes a zero-argument function, -and returns a dictionary-like container, whose keys are arrays `x in ps`. - -See also [`withgradient`](@ref) to keep the value `loss()`. - -```jldoctest; setup=:(using Zygote) -julia> x = [1 2 3; 4 5 6]; y = [7, 8]; z = [1, 10, 100]; - -julia> g = gradient(Params([x, y])) do - sum(x .* y .* z') - end -Grads(...) - -julia> g[x] -2×3 Matrix{Float64}: - 7.0 70.0 700.0 - 8.0 80.0 800.0 -julia> haskey(g, z) # only x and y are parameters -false -``` -""" -gradient +# Param-style wrappers """ Params([A, B]) @@ -391,6 +390,58 @@ function pullback(f, ps::Params) end end +""" + gradient(() -> loss(), ps::Params) -> Grads + +Gradient with implicit parameters. Takes a zero-argument function, +and returns a dictionary-like container, whose keys are arrays `x in ps`. + +See also [`withgradient`](@ref) to keep the value `loss()`. + +```jldoctest; setup=:(using Zygote) +julia> x = [1 2 3; 4 5 6]; y = [7, 8]; z = [1, 10, 100]; + +julia> g = gradient(Params([x, y])) do + sum(x .* y .* z') + end +Grads(...) + +julia> g[x] +2×3 Matrix{Float64}: + 7.0 70.0 700.0 + 8.0 80.0 800.0 + +julia> haskey(g, z) # only x and y are parameters +false +``` +""" +function gradient(f, ps::Params) + y, back = pullback(f, ps) + back(sensitivity(y)) +end + +""" + withgradient(f, ps::Params) -> Grads + +Returns both the value of the function and the [`gradient`](@ref), +as a named tuple. 
+ +```jldoctest; setup=:(using Zygote) +julia> w = [3.0]; + +julia> res = withgradient(() -> sum(abs2, w), Params([w])) # implicit mode +(val = 9.0, grad = Grads(...)) + +julia> res.grad[w] +1-element Vector{Float64}: + 6.0 +``` +""" +function withgradient(f, ps::Params) + y, back = pullback(f, ps) + (val=y, grad=back(sensitivity(y))) +end + # Code Reflection function code_ir(f, T) diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl index bf3692a30..fb71ffd9e 100644 --- a/src/compiler/interface2.jl +++ b/src/compiler/interface2.jl @@ -17,6 +17,8 @@ end chain_rrule_f = :chain_rrule end + # Here ZygoteRuleConfig{Zygote.Context{false, true}} is passed to chain_rrule + hascr, cr_edge = has_chain_rrule(cr_T) hascr && return :($chain_rrule_f(ZygoteRuleConfig(ctx), f, args...)) diff --git a/src/lib/array.jl b/src/lib/array.jl index 4b8f90609..b94a426b0 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -191,7 +191,7 @@ for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap)] ys = map(first, ys_and_backs) arg_ax = map(_tryaxes, args) function map_back(Δ) - if Base.issingletontype(F) && length(args) == 1 + ∇s = if Base.issingletontype(F) && length(args) == 1 Δarg = $mapfunc(((_,pb), δ) -> last_or_nothing(pb(δ)), ys_and_backs, Δ) # No unzip needed (nothing, Δarg) elseif Base.issingletontype(F) @@ -207,6 +207,8 @@ for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap)] Δargs = map(_restore, Δf_and_args[2:end], arg_ax) (Δf, Δargs...) 
end + maybe_final(cx, ys_and_backs) + ∇s end map_back(::Nothing) = nothing return ys, map_back diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 58f7ecf99..4e94073a7 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -68,6 +68,17 @@ unbroadcast(x::Tuple{<:Any}, x̄::Nothing) = nothing unbroadcast(x::AbstractArray, x̄::Nothing) = nothing +# This variant is used only when we are certain x̄ is newly created +unbroadcast_final(::Context, x, x̄) = unbroadcast(x, x̄) +unbroadcast_final(::Context{false,true}, x, x̄::Nothing) = nothing +function unbroadcast_final(::Context{false,true}, x, x̄) + dx = unbroadcast(x, x̄) + if length(x) != length(x̄) + maybe_final(x̄) + end + dx +end + # Split Reverse Mode # ================== @@ -75,38 +86,38 @@ unbroadcast(x::AbstractArray, x̄::Nothing) = nothing # to do CSE, then broadcast-ify the expression so that the closure captures the # right arrays. -@adjoint broadcasted(::typeof(+), xs::Numeric...) = +@adjoint_final broadcasted(::typeof(+), xs::Numeric...) = broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...) 
-@adjoint broadcasted(::typeof(-), x::Numeric, y::Numeric) = x .- y, +@adjoint_final broadcasted(::typeof(-), x::Numeric, y::Numeric) = x .- y, Δ -> (nothing, unbroadcast(x, Δ), _minus(unbroadcast(y, Δ))) -@adjoint broadcasted(::typeof(-), x::Numeric) = .-x, +@adjoint_final broadcasted(::typeof(-), x::Numeric) = .-x, Δ -> (nothing, _minus(Δ)) _minus(Δ) = -Δ _minus(::Nothing) = nothing -@adjoint broadcasted(::typeof(*), x::Numeric, y::Numeric) = x.*y, - Δ -> (nothing, unbroadcast(x, Δ .* conj.(y)), unbroadcast(y, Δ .* conj.(x))) -@adjoint broadcasted(::typeof(*), x::Number, y::AbstractArray{<:Number}) = +@adjoint_final broadcasted(::typeof(*), x::Numeric, y::Numeric) = x.*y, + Δ -> (nothing, unbroadcast_final(__context__, x, Δ .* conj.(y)), unbroadcast_final(__context__, y, Δ .* conj.(x))) +@adjoint_final broadcasted(::typeof(*), x::Number, y::AbstractArray{<:Number}) = _pullback(__context__, *, x, y) # this uses dot(y,Δ) instead of sum(Δ .* conj.(y)) -@adjoint broadcasted(::typeof(*), x::AbstractArray{<:Number}, y::Number) = +@adjoint_final broadcasted(::typeof(*), x::AbstractArray{<:Number}, y::Number) = _pullback(__context__, *, x, y) -@adjoint function broadcasted(::typeof(/), x::Numeric, y::Numeric) +@adjoint_final function broadcasted(::typeof(/), x::Numeric, y::Numeric) res = x ./ y - res, Δ -> (nothing, unbroadcast(x, Δ ./ conj.(y)), unbroadcast(y, .-Δ .* conj.(res ./ y))) + res, Δ -> (nothing, unbroadcast_final(__context__, x, Δ ./ conj.(y)), unbroadcast_final(__context__, y, .-Δ .* conj.(res ./ y))) end -@adjoint broadcasted(::typeof(/), x::AbstractArray{<:Number}, y::Number) = +@adjoint_final broadcasted(::typeof(/), x::AbstractArray{<:Number}, y::Number) = _pullback(__context__, /, x, y) -@adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, exp::Val{p}) where p +@adjoint_final function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, exp::Val{p}) where p y = Base.literal_pow.(^, x, exp) y, ȳ -> (nothing, 
nothing, ȳ .* p .* conj.(x .^ (p - 1)), nothing) end @adjoint broadcasted(::typeof(identity), x::Numeric) = x, Δ -> (nothing, Δ) -@adjoint function broadcasted(::typeof(tanh), x::Numeric) +@adjoint_final function broadcasted(::typeof(tanh), x::Numeric) y = tanh.(x) y, ȳ -> (nothing, ȳ .* conj.(1 .- y.^2)) end @@ -120,6 +131,7 @@ end @adjoint broadcasted(::typeof(imag), x::Numeric) = imag.(x), z̄ -> (nothing, im .* real.(z̄)) +# These cannot use @adjoint_final, as they may return their input @adjoint function broadcasted(::typeof(+), a::AbstractArray{<:Number}, b::Bool) y = b === false ? a : a .+ b y, Δ -> (nothing, Δ, nothing) @@ -152,7 +164,7 @@ end end end -@adjoint broadcasted(::Type{T}, x::Numeric) where {T<:Number} = +@adjoint_final broadcasted(::Type{T}, x::Numeric) where {T<:Number} = T.(x), ȳ -> (nothing, _project(x, ȳ),) # General Fallback @@ -185,13 +197,13 @@ _dual_safearg(x::Ref{<:Numeric{<:Real}}) = true _dual_safearg(x::Union{Type,Val,Symbol}) = true # non-differentiable types _dual_safearg(x) = false -@adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F} +@adjoint_final function broadcasted(::AbstractArrayStyle, f::F, args...) where {F} T = Broadcast.combine_eltypes(f, args) # Avoid generic broadcasting in two easy cases: if T == Bool return (f.(args...), _ -> nothing) elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving() - return broadcast_forward(f, args...) + return broadcast_forward(__context__, f, args...) end len = inclen(args) y∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...) @@ -201,6 +213,8 @@ _dual_safearg(x) = false dxs = ntuple(len) do i collapse_nothings(map(StaticGetter{i}(), dxs_zip)) end + maybe_final(__context__, y∂b) + maybe_final(__context__, dxs_zip) (nothing, accum_sum(dxs[1]), map(unbroadcast, args, Base.tail(dxs))...) 
end return y, ∇broadcasted @@ -245,15 +259,16 @@ function dual_function(f::F) where F end end -@inline function broadcast_forward(f, args::Vararg{Any,N}) where N +@inline function broadcast_forward(cx::Context, f, args::Vararg{Any,N}) where N valN = Val(N) out = dual_function(f).(args...) eltype(out) <: Dual || return (out, _ -> nothing) y = broadcast(x -> x.value, out) function bc_fwd_back(ȳ) dargs = ntuple(valN) do i - unbroadcast(args[i], broadcast((y1, o1) -> y1 * o1.partials[i], ȳ, out)) + unbroadcast_final(cx, args[i], broadcast((y1, o1) -> y1 * o1.partials[i], ȳ, out)) end + maybe_final(cx, out) # finalize for CUDA, when not inside jacobian (nothing, nothing, dargs...) # nothings for broadcasted & f end return y, bc_fwd_back @@ -264,20 +279,20 @@ using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve # Ordinary broadcasting calls broadcast_forward anyway when certain its' safe, # so perhaps this can be deleted? Possible edge case here: # https://github.com/FluxML/Zygote.jl/pull/1018#issuecomment-873629415 - @adjoint broadcasted(::AbstractGPUArrayStyle, f, args...) = - broadcast_forward(f, args...) + @adjoint_final broadcasted(::AbstractGPUArrayStyle, f, args...) = + broadcast_forward(__context__, f, args...) - @adjoint (::Type{T})(xs::Array) where {T <: AbstractGPUArray} = + @adjoint_final (::Type{T})(xs::Array) where {T <: AbstractGPUArray} = T(xs), Δ -> (convert(Array, Δ), ) - @adjoint function sum(xs::AbstractGPUArray; dims = :) + @adjoint_final function sum(xs::AbstractGPUArray; dims = :) placeholder = similar(xs) sum(xs, dims = dims), Δ -> (placeholder .= Δ,) end # Make sure sum(f, ::CuArray) uses broadcase through forward-mode defined above # Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible - @adjoint function sum(f, xs::AbstractGPUArray; kws...) + @adjoint_final function sum(f, xs::AbstractGPUArray; kws...) 
@assert !haskey(kws, :init) # TODO add init support (julia 1.6) return pullback((f, xs) -> sum(f.(xs); kws...), __context__, f, xs) end diff --git a/test/compiler.jl b/test/compiler.jl index c5ddf1f38..a288d7f10 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -31,11 +31,11 @@ y, back = pullback(badly, 2) bt = try back(1) catch e stacktrace(catch_backtrace()) end @test trace_contains(bt, nothing, "compiler.jl", 20) -if VERSION >= v"1.6-" - @test_broken trace_contains(bt, :badly, "compiler.jl", 24) -else +# if VERSION >= v"1.6-" +# @test_broken trace_contains(bt, :badly, "compiler.jl", 24) +# else @test trace_contains(bt, :badly, "compiler.jl", 24) -end +# end # Type inference checks diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 540d85e92..f1fed4277 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -270,8 +270,11 @@ end @test gradtest(dot, randn(rng, 10, 3), randn(rng, 10, 3)) end -@test gradtest(kron, rand(5), rand(3)) -@test gradtest(kron, rand(5), rand(3), rand(8)) +if VERSION < v"1.9-" # kron(::Vector...) no longer reshapes, needs a rule: + # https://github.com/JuliaDiff/ChainRules.jl/issues/684 + @test gradtest(kron, rand(5), rand(3)) + @test gradtest(kron, rand(5), rand(3), rand(8)) +end @test gradtest(kron, rand(5,1), rand(3,1)) @test gradtest(kron, rand(5,1), rand(3,1), rand(8,1)) @test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))