using Zygote
using CUDA
f(A, I) = sum(A[I])
A = rand(4)
I = [1, 3, 1]
# CPU - everything is OK
Zygote.gradient(f, A, I)
# => ([2.0, 0.0, 1.0, 0.0], nothing)
# GPU - dA[1] is incorrect
Zygote.gradient(f, cu(A), cu(I))
# => (Float32[1.0, 0.0, 1.0, 0.0], nothing)
I believe the CPU version comes from ChainRules.jl, which correctly accumulates the derivatives for the repeated index into dA[1], but I'm not sure what code is used for the CUDA version.
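To make the expected behaviour concrete, here is a minimal CPU-only sketch of the two ways a pullback for `A[I]` could write into `dA`. Both function names are hypothetical, and this is not the actual ChainRules.jl or CUDA code; it only assumes that the correct rule accumulates with `+=` and shows that a plain indexed assignment happens to reproduce the GPU result above.

```julia
# Hypothetical helper (not a ChainRules.jl function): pullback of y = A[I]
# that accumulates, so repeated indices add up their contributions.
function getindex_pullback(A::AbstractVector, I::AbstractVector{<:Integer}, dy::AbstractVector)
    dA = zeros(eltype(dy), size(A))
    for (k, i) in enumerate(I)
        dA[i] += dy[k]
    end
    return dA
end

# Hypothetical counterpart that assigns instead of accumulating:
# for a repeated index only one contribution survives.
function getindex_pullback_overwrite(A::AbstractVector, I::AbstractVector{<:Integer}, dy::AbstractVector)
    dA = zeros(eltype(dy), size(A))
    dA[I] = dy
    return dA
end

A = rand(4)
I = [1, 3, 1]
dy = ones(3)    # gradient of sum(A[I]) with respect to A[I]

getindex_pullback(A, I, dy)            # [2.0, 0.0, 1.0, 0.0]
getindex_pullback_overwrite(A, I, dy)  # [1.0, 0.0, 1.0, 0.0]
```

The second result matches what I get on the GPU, which makes me suspect that the CUDA path scatters the gradient without accumulating over duplicate indices, but I have not confirmed this.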
Here is how I came across this issue and how I am trying to resolve it in Yota.