Open
Labels
ChainRules (adjoint -> rrule, and further integration)
bug (Something isn't working)
second order (zygote over zygote, or otherwise)
Description
Package Version
Zygote v0.6.49
Julia Version
julia version 1.8.2
OS / Environment
Windows
Describe the bug
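Nested (second-order) gradients fail for some functions but work for others that are mathematically equivalent. In the examples below, g2, g3, and g4 all compute the same dot product, yet only g4, which indexes into the arrays, errors. g1, which also indexes, errors as well.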
Steps to Reproduce
using Zygote: gradient
g1(x,y) = (x+y)[1]
dxg1(x,y) = gradient(x -> g1(x,y), x)[1][1] #partial of g₁ wrt x₁
dxg1(ones(2), ones(2)) # 1.0, as expected
dxyg1(x,y) = gradient(y -> dxg1(x,y), y)[1][1] #partial of dxg1 wrt y₁
dxyg1(ones(2), ones(2)) # ERROR: Need an adjoint for constructor Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}. Gradient is of type Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}
g2(x,y) = transpose(x)*y
dxg2(x,y) = gradient(x -> g2(x,y), x)[1][1]
dxyg2(x,y) = gradient(y -> dxg2(x,y), y)[1][1]
dxyg2(ones(2), ones(2)) # 1.0, as expected
g3(x,y) = sum(x.*y)
dxg3(x,y) = gradient(x -> g3(x,y), x)[1][1]
dxyg3(x,y) = gradient(y -> dxg3(x,y), y)[1][1]
dxyg3(ones(2),ones(2)) # 1.0, as expected
g4(x,y) = x[1]*y[1] + x[2]*y[2]
dxg4(x,y) = gradient(x -> g4(x,y), x)[1][1] #partial wrt x
dxyg4(x,y) = gradient(y -> dxg4(x,y), y)[1][1] #mixed second derivative
dxyg4(ones(2),ones(2)) # ERROR: Need an adjoint for constructor Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}. Gradient is of type Vector{Float64}
Expected Results
I expected none of the calls above to error.
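For reference, the expected values can be worked out by hand. This is only a sketch of the underlying math, writing x = (x₁, x₂) and y = (y₁, y₂). It says nothing about how Zygote represents a zero gradient internally.

```latex
% g1 depends on x_1 and y_1 only additively, so the mixed partial vanishes.
g_1(x, y) = x_1 + y_1
\quad\Rightarrow\quad
\frac{\partial g_1}{\partial x_1} = 1,
\qquad
\frac{\partial^2 g_1}{\partial y_1 \, \partial x_1} = 0

% g2, g3 and g4 are three spellings of the same dot product.
g_2(x, y) = g_3(x, y) = g_4(x, y) = x_1 y_1 + x_2 y_2
\quad\Rightarrow\quad
\frac{\partial g_4}{\partial x_1} = y_1,
\qquad
\frac{\partial^2 g_4}{\partial y_1 \, \partial x_1} = 1
```

So dxyg4 should agree with the 1.0 returned by dxyg2 and dxyg3, and the mixed derivative of g1 is mathematically zero.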
Observed Results
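dxyg1 and dxyg4 throw the "Need an adjoint for constructor Zygote.OneElement" errors shown in the comments above. dxyg2 and dxyg3 return 1.0 as expected.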
Relevant log output
No response