diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 4dda8559c..d60fa19e7 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -294,7 +294,8 @@ z2d(::Tuple{Vararg{Nothing}}, ::Tuple) = NoTangent() # collapse all-zero case z2d(dx, ::Any) = dx z2d(dx::AbstractArray{<:Number}, primal::AbstractArray) = dx z2d(dx::AbstractArray{<:AbstractArray{<:Number}}, primal::AbstractArray) = dx -z2d(dx::AbstractArray, primal::AbstractArray) = map(z2d, dx, primal) +z2d(dx::AbstractArray, primal::AbstractArray) = isempty(dx) ? dx : map(Zygote.z2d, dx, primal) + #= # As an optimisation, we can convert by `reinterpret` for bitstypes, e.g. arrays of tuples of numbers function z2d(dx::AbstractArray{S}, primal::AbstractArray{P}) where {S,P} diff --git a/test/chainrules_tests.jl b/test/chainrules_tests.jl index 2d1ce2569..85fc28c08 100644 --- a/test/chainrules_tests.jl +++ b/test/chainrules_tests.jl @@ -414,6 +414,9 @@ end @test z2d_compiled.d === z2d_fallback.d @test z2d_compiled.c.a === z2d_fallback.c.a @test z2d_compiled.c.b === z2d_fallback.c.b + + # empty dx => returns the dx + @test @inferred(Zygote.z2d(ones(1, 0), ones(16, 0))) == ones(1, 0) end @testset "ChainRules translation" begin