Description
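This is a weird issue, and I have no idea what causes it. When a function broadcasts a call to `ifelse`, `Zygote.gradient` returns 0.0 instead of 1.0 for the derivative with respect to `a`. The derivative with respect to `d` is still correct.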
MWE:
```julia
using Zygote, Test

# Returns `a` for the first index and `d` for every other index.
function generate_i(i, a, d)
    ifelse(i == 1, a, d)
end

function f_std(a, d)
    N = 3
    x = ones(N)
    v = @. generate_i(1:N, a, d)  # v == [a, d, d]
    av = v .* x
    return sum(av)                # == a + 2d
end

a = 0.2
d = 0.5

# The following should be (1.0, 2.0), but it returns (0.0, 2.0)
@show res_std = Zygote.gradient(f_std, a, d)

# Central finite differences as a reference
dx = 1e-4
df_da = (f_std(a + dx, d) - f_std(a - dx, d)) / (2 * dx)
df_dd = (f_std(a, d + dx) - f_std(a, d - dx)) / (2 * dx)
@show res_fd = (df_da, df_dd)
@show analytical_res = (1.0, 2.0)

@test sum(abs.(res_std .- res_fd)) < 1e-2
@test all(res_std .≈ analytical_res)
```

Zygote version: v0.7.10
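For reference, you can derive the expected gradient by hand. With `N = 3` and `x` all ones, `generate_i` picks `a` only for `i = 1` and `d` for `i = 2, 3`. The function is therefore linear in both arguments:

```math
v = (a,\ d,\ d), \qquad
f_{\text{std}}(a, d) = \sum_{i=1}^{3} v_i x_i = a + 2d,
\qquad
\frac{\partial f_{\text{std}}}{\partial a} = 1, \quad
\frac{\partial f_{\text{std}}}{\partial d} = 2.
```

At `a = 0.2`, `d = 0.5` this gives `f_std(a, d) = 1.2` and a gradient of `(1.0, 2.0)`, which agrees with the finite-difference check. Zygote's `(0.0, 2.0)` is therefore correct for `d` but drops the contribution of `a` entirely.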