Iterated gradients error #1327

@Vilin97

Package Version

Zygote v0.6.49

Julia Version

julia version 1.8.2

OS / Environment

Windows

Describe the bug

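Nesting `gradient` calls (Zygote over Zygote) to compute a mixed second derivative errors for some functions but works for mathematically equivalent ones. See the examples below.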

Steps to Reproduce

using Zygote: gradient

g1(x,y) = (x+y)[1]
dxg1(x,y) = gradient(x -> g1(x,y), x)[1][1] # partial of g1 w.r.t. x[1]
dxg1(ones(2), ones(2)) # 1.0, as expected
dxyg1(x,y) = gradient(y -> dxg1(x,y), y)[1][1] # partial of dxg1 w.r.t. y[1]
dxyg1(ones(2), ones(2)) # ERROR: Need an adjoint for constructor Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}. Gradient is of type Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}

g2(x,y) = transpose(x)*y
dxg2(x,y) = gradient(x -> g2(x,y), x)[1][1] # partial of g2 w.r.t. x[1]
dxyg2(x,y) = gradient(y -> dxg2(x,y), y)[1][1] # partial of dxg2 w.r.t. y[1]
dxyg2(ones(2), ones(2)) # 1.0, as expected

g3(x,y) = sum(x.*y)
dxg3(x,y) = gradient(x -> g3(x,y), x)[1][1] # partial of g3 w.r.t. x[1]
dxyg3(x,y) = gradient(y -> dxg3(x,y), y)[1][1] # partial of dxg3 w.r.t. y[1]
dxyg3(ones(2),ones(2)) # 1.0, as expected

g4(x,y) = x[1]*y[1] + x[2]*y[2]
dxg4(x,y) = gradient(x -> g4(x,y), x)[1][1] # partial of g4 w.r.t. x[1]
dxyg4(x,y) = gradient(y -> dxg4(x,y), y)[1][1] # mixed second derivative w.r.t. x[1] and y[1]
dxyg4(ones(2),ones(2)) # ERROR: Need an adjoint for constructor Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}. Gradient is of type Vector{Float64}

Expected Results

I expected all four mixed second derivatives to evaluate without error, as `dxyg2` and `dxyg3` do.
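
For reference, the values the four examples should produce can be worked out by hand. This is a sketch from the function definitions above, not output from Zygote; the component notation $x_1, y_1$ for `x[1]`, `y[1]` is an assumption of this note.

```latex
% g1 is linear in x and y separately, so its mixed partial vanishes.
\begin{aligned}
g_1(x, y) &= x_1 + y_1, &
\frac{\partial g_1}{\partial x_1} &= 1, &
\frac{\partial^2 g_1}{\partial y_1\,\partial x_1} &= 0, \\
% g2, g3 and g4 are three spellings of the same dot product.
g_2 = g_3 = g_4 &= x_1 y_1 + x_2 y_2, &
\frac{\partial g_k}{\partial x_1} &= y_1, &
\frac{\partial^2 g_k}{\partial y_1\,\partial x_1} &= 1
\quad (k = 2, 3, 4).
\end{aligned}
```

Since `g2`, `g3` and `g4` are the same function mathematically, the 1.0 returned for `g2` and `g3` is also the value expected for `g4`. For `g1` the true mixed partial is 0, not 1.0.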

Observed Results

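`dxyg1` and `dxyg4` throw a `Need an adjoint for constructor Zygote.OneElement` error, while `dxyg2` and `dxyg3` return 1.0.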

Relevant log output

No response

Metadata

    Labels

    ChainRules: adjoint -> rrule, and further integration
    bug: Something isn't working
    second order: zygote over zygote, or otherwise
