Description
For some reason, I need to calculate a gradient of a gradient (a second-order gradient) through TensorOperations.jl with Zygote. Unfortunately, I encountered a BoundsError.
Here is the code to reproduce the issue:
using Zygote
using TensorOperations

function ff(x, y)
    # contract x with itself and y with itself, then take the full overlap
    @tensor x2[-1, -2] := x[1, 2; -1] * conj(x[1, 2; -2])
    @tensor y2[-1, -2] := y[-1, 1; 2] * conj(y[-2, 1; 2])
    return @tensor x2[1, 2] * y2[1, 2]
end

function gg(x0, y0, c)
    # g overlaps the first-order gradient of ff with c, so
    # differentiating g gives a gradient of a gradient
    function g(_y)
        g1 = Zygote.gradient(yy -> ff(x0, yy), _y)[1]
        @tensor gb = conj(g1[1, 2; 3]) * c[1, 2; 3]
        return gb
    end
    return Zygote.gradient(g, y0)
end

x0 = randn(2, 2, 2)
y0 = randn(2, 2, 2)
c = randn(2, 2, 2)
res = gg(x0, y0, c)

ERROR: BoundsError: attempt to access 2×2 Matrix{Float64} at index [6]
Stacktrace:
[1] throw_boundserror(A::Matrix{Float64}, I::Tuple{Int64})
@ Base ./essentials.jl:14
[2] getindex
@ ./essentials.jl:916 [inlined]
[3] rrule
@ ~/.julia/packages/ChainRules/14CDN/src/rulesets/Base/indexing.jl:63 [inlined]
[4] rrule
@ ~/.julia/packages/ChainRulesCore/Vsbj9/src/rules.jl:138 [inlined]
... (and so on) ...

After more testing, I traced the problem to a let ... end block in the package's ChainRulesCore extension. That block uses a return statement, which causes an early return and leads to the BoundsError. This is very confusing because the return statement is syntactically correct Julia. Furthermore, the first-order gradient is unaffected.
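For context, here is a minimal sketch of the semantics involved (a standalone illustration I wrote, not the package source). ChainRulesCore's @thunk wraps its body in a closure, and let does not introduce a function boundary, so a return inside the let exits that closure. At plain runtime both forms give the same value, which suggests the explicit return only misbehaves under Zygote's second-order transformation:

# Standalone illustration (not the package source): `let` is not a
# function boundary, so `return` exits the surrounding anonymous function.
make_closure_with_return() = () -> let
    x = 1 + 1
    return x      # returns from the anonymous function
end

make_closure_without_return() = () -> let
    x = 1 + 1
    x             # the last expression is the value of the `let` block
end

make_closure_with_return()()     # 2
make_closure_without_return()()  # 2; identical at plain runtime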
Anyway, once I removed that return statement by hand, the error went away. The code then almost worked, aside from a few small issues (more on those below).
Changes
File: TensorOperationsChainRulesCoreExt.jl
- Removed all return statements from let ... end blocks. For example:

dA = @thunk let
    # ....
    projectA(_dA)  # the `return` keyword is removed
end

- Added promote_contract as non-differentiable and fixed the rrule for tensorscalar
- Before:
# ...
# TODO: possibly use the non-inplace functions, to avoid depending on Base.copy
function ChainRulesCore.rrule(::typeof(tensorscalar), C)
    function tensorscalar_pullback(Δc)
        ΔC = TensorOperations.tensoralloc(typeof(C), TensorOperations.tensorstructure(C))
        return NoTangent(), fill!(ΔC, unthunk(Δc))
    end
    return tensorscalar(C), tensorscalar_pullback
end
# ...

- After:
# ...
@non_differentiable TensorOperations.promote_contract(args...)
# ...
function ChainRulesCore.rrule(::typeof(tensorscalar), C)
    projectC = ProjectTo(C)
    function tensorscalar_pullback(Δc)
        _Δc = unthunk(Δc)
        return NoTangent(), projectC(_Δc)
    end
    return tensorscalar(C), tensorscalar_pullback
end
# ....

The changes seem to resolve the gradient issues; the original reproducer now runs cleanly (see the sketch below). Is there a better way to do it?
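A quick sanity check (a sketch reusing ff, gg, and the inputs from the repro above; it assumes the patched extension is loaded):

# Re-run the original reproducer against the patched extension.
using Zygote
using TensorOperations

x0 = randn(2, 2, 2)
y0 = randn(2, 2, 2)
c = randn(2, 2, 2)

res = gg(x0, y0, c)  # previously threw BoundsError; now returns the second-order gradient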
Thanks for your help!