diff --git a/Project.toml b/Project.toml
index d2357cb26..20895dd9d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -57,7 +57,7 @@ Requires = "1.1"
 SpecialFunctions = "1.6, 2"
 Statistics = "1"
 Tracker = "0.2"
-ZygoteRules = "0.2.4"
+ZygoteRules = "0.2.5"
 julia = "1.6"
 
 [extras]
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index c09d6db31..80fd9b477 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -39,6 +39,52 @@ _pullback(f, args...) = _pullback(Context(), f, args...)
 tailmemaybe(::Nothing) = nothing
 tailmemaybe(x::Tuple) = Base.tail(x)
 
+"""
+    pullback(f, args...)
+    pullback(f, ::Params)
+
+Returns the value of the function `f` and a back-propagator function,
+which can be called to obtain a tuple containing `∂f/∂x` for each argument `x`,
+the derivative (for scalar `x`) or gradient.
+
+```julia
+y, back = pullback(f, args...)
+∇ = back(seed)
+```
+
+`back` must be called with a start value `seed` matching the output of `f(args...)`.
+If `f(args...)` returns a number, `seed` should be a number.
+If `f(args...)` returns an array, `seed` should be an equally-sized array.
+
+See also [`withgradient`](@ref) to obtain the value and gradients in one call,
+and [`gradient`](@ref) for obtaining just the gradients.
+
+```jldoctest; setup=:(using Zygote)
+julia> y, back = pullback(*, 2.0, 3.0, 5.0);
+
+julia> y
+30.0
+
+julia> back(1.0)
+(15.0, 10.0, 6.0)
+
+julia> back(2.0)
+(30.0, 20.0, 12.0)
+
+julia> y, back = pullback(x -> [x, x], 1.0);
+
+julia> y
+2-element Vector{Float64}:
+ 1.0
+ 1.0
+
+julia> back([1.0, 1.0])
+(2.0,)
+
+julia> back([2.0, nothing])
+(2.0,)
+```
+"""
 @inline pullback(f, args...) = pullback(f, Context(), args...)
 function pullback(f, cx::AContext, args...)
   y, back = _pullback(cx, f, args...)
@@ -67,11 +113,17 @@ sensitivity(y::Complex) = error("Output is complex, so the gradient is not defin
 sensitivity(y::AbstractArray) = error("Output is an array, so the gradient is not defined. Perhaps you wanted jacobian.")
 sensitivity(y) = error("Output should be scalar; gradients are not defined for output $(repr(y))")
 
+# Preserves output as tuple when gradients are collapsed.
+# `NTuple{N,Any}` rather than `NTuple{N}`, so that mixed-type argument tuples such as `(1, 2.0)` also match.
+_project_all(::NTuple{N,Any}, ::Nothing) where {N} = ntuple(_ -> nothing, N)
+_project_all(x::Tuple, dx::Tuple) = map(_project, x, dx)
+
 """
     gradient(f, args...)
 
 Returns a tuple containing `∂f/∂x` for each argument `x`,
 the derivative (for scalar `x`) or the gradient.
+If no gradient is defined, `∂f/∂x` will be `nothing`.
 
 `f(args...)` must be a real number, see [`jacobian`](@ref) for array output.
 
@@ -95,7 +147,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d
 function gradient(f, args...)
   y, back = pullback(f, args...)
   grad = back(sensitivity(y))
-  isnothing(grad) ? nothing : map(_project, args, grad)
+  return _project_all(args, grad)
 end
 
 # Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy!
@@ -109,7 +161,7 @@ end
     withgradient(f, ::Params)
 
 Returns both the value of the function and the [`gradient`](@ref),
-as a named tuple. 
+as a named tuple.
 
 ```jldoctest; setup=:(using Zygote)
 julia> y, ∇ = withgradient(/, 1, 2)
@@ -161,7 +213,7 @@ function withgradient(f, args...)
   else
     back(sensitivity(y))
   end
-  results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad)
+  results = _project_all(args, grad)
   (val=y, grad=results)
 end
 
@@ -304,7 +356,7 @@ end
     Grads(...)
 
 Dictionary-like container returned when taking gradients with
-respect to implicit parameters. For an array `W`, appearing 
+respect to implicit parameters. For an array `W`, appearing
 within `Params([W, A, B...])`, the gradient is `g[W]`.
 """
 struct Grads
@@ -321,7 +373,7 @@ const ADictOrGrads = Union{AbstractDict, Grads}
 
 # Dictionary interface.
 # Don't use the IdDict directly since it may contain some spurious pairs.
-Base.haskey(gs::Grads, x) = x ∈ gs.params 
+Base.haskey(gs::Grads, x) = x ∈ gs.params
 Base.keys(gs::Grads) = gs.params
 Base.values(gs::Grads) = (gs.grads[p] for p in gs.params)
 
@@ -381,7 +433,7 @@ broadcasted(f, a::Numeric, gs::Grads) = map(x -> f(a, x), gs)
 broadcasted(f, gs::Grads, a::Numeric) = map(x -> f(x, a), gs)
 
 function materialize!(gs1::Grads, gs2::Grads)
-  issetequal(gs1.params, gs2.params) || 
+  issetequal(gs1.params, gs2.params) ||
     throw(ArgumentError("Expected Grads objects with the same Params."))
   for p in gs1.params
     gs1[p] = gs2[p]
@@ -421,6 +473,9 @@ function pullback(f, ps::Params)
   end
 end
 
+# No conversion required here
+_project_all(_, dx::Grads) = dx
+
 # Code Reflection
 
 function code_ir(f, T)
diff --git a/test/lib/number.jl b/test/lib/number.jl
index ce0a64bef..77756387d 100644
--- a/test/lib/number.jl
+++ b/test/lib/number.jl
@@ -3,8 +3,8 @@
   @test gradient(floor, 1) === (0.0,)
   @test gradient(ceil, 1) === (0.0,)
   @test gradient(round, 1) === (0.0,)
-  @test gradient(hash, 1) === nothing
-  @test gradient(div, 1, 2) === nothing
+  @test gradient(hash, 1) === (nothing,)
+  @test gradient(div, 1, 2) === (nothing, nothing)
 end
 
 @testset "basics" begin
diff --git a/test/structures.jl b/test/structures.jl
index 5a951a621..cdba138c4 100644
--- a/test/structures.jl
+++ b/test/structures.jl
@@ -64,5 +64,5 @@ end
   end
 
   m, b = Zygote._pullback(Zygote.Context(), nameof, M)
-  @test b(m) == (nothing, nothing)
+  @test b(m) === nothing
 end