diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index aaec3951f..ab9acea44 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -117,27 +117,16 @@ for T_outer in (:Tuple, :NamedTuple) end """ - wrap_chainrules_input(x) - -Convert `x` from the format Zygote uses internally to differentials types ChainRules uses. -""" -@inline wrap_chainrules_input(x) = x -@inline wrap_chainrules_input(::Nothing) = ChainRules.ZeroTangent() -@inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple}) - xp = map(wrap_chainrules_input, xs) - ChainRules.Tangent{Any, typeof(xp)}(xp) -end - -""" - ZBack{F}(back) <: Function + ZBack{F,P}(back::F, primals::P) <: Function Wrapper for a ChainRules pullback `back`, that causes it to follow Zygote conventions. (A functor here is used rather than a closure to avoid boxing issues); """ -struct ZBack{F} <: Function +struct ZBack{F,P} <: Function back::F + primals::P end -@inline (s::ZBack)(dy) = wrap_chainrules_output(s.back(wrap_chainrules_input(dy))) +@inline (s::ZBack)(dy) = wrap_chainrules_output(s.back(zygote2differential(dy, s.primals))) # `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603 # though it might be worth keeping as a performance optimization (benchmarking pending) @inline (s::ZBack)(::Nothing) = nothing @@ -150,7 +139,7 @@ The pullback is appropriately wrapped up to follow Zygote conventions. """ @inline function chain_rrule(config, f, args...) y, back = rrule(config, f, args...) - return y, ZBack(back) + return y, ZBack(back, args) end @@ -163,7 +152,7 @@ As per [`chain_rrule`](@ref) but with support for kwargs. @inline function chain_rrule_kw(config, kwf, kwargs, f, args...) y, back = rrule(config, f, args...; kwargs...) 
function kw_zpullback(dy) - dxs = ZBack(back)(dy) + dxs = ZBack(back, args)(dy) if dxs === nothing # if dxs is nothing, then all partials are nothing # Zygote convention is a single nothing no matter how many partials, if all are nothing return nothing