diff --git a/src/lib/grad.jl b/src/lib/grad.jl index 92be71d34..50670bf43 100644 --- a/src/lib/grad.jl +++ b/src/lib/grad.jl @@ -178,15 +178,17 @@ julia> withjacobian(cumsum, [1,2,3]) ``` """ function withjacobian(f, args...) - y, back = pullback(_jvec∘f, args...) + y, back1 = pullback(f, args...) + yvec, back2 = pullback(_jvec, y) + back = dy -> back1(back2(dy)[1]) out = map(args) do x T = promote_type(eltype(x), eltype(y)) dx = x isa AbstractArray ? similar(x, T, length(y), length(x)) : x isa Number ? similar(y, T, length(y)) : nothing end - delta = _eyelike(y) - for k in LinearIndices(y) + delta = _eyelike(yvec) + for k in LinearIndices(yvec) grads = back(delta[:,k]) for (dx, grad) in zip(out, grads) dx isa AbstractArray || continue diff --git a/test/utils_tests.jl b/test/utils_tests.jl index 691455491..1341d9993 100644 --- a/test/utils_tests.jl +++ b/test/utils_tests.jl @@ -77,6 +77,11 @@ end j6 = jacobian((x,y) -> abs2.(x .* y), [1+im, 2], 3+4im) @test j6[1][1,:] ≈ g6[1] @test j6[2][1] ≈ g6[2] + + # https://github.com/FluxML/Zygote.jl/issues/1506 + y7, g7 = Zygote.withjacobian(identity, rand(2, 3)); + @test size(y7) == (2,3) + @test only(g7) == I end @testset "jacobian(loss, ::Params)" begin