keras and jax: same neural net, same weights but different outputs ? #9279
Answered by jcpeterson on May 24, 2022
Answer selected by j-bac
`opt_state = opt_update(0, grads, opt_state)` should be `opt_state = opt_update(index, grads, opt_state)`, where `index` is a counter for gradient updates. Adam updates depend on the index.
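To illustrate the point, here is a minimal sketch, assuming the optimizer comes from `jax.example_libraries.optimizers` (the thread's `opt_update` signature matches that module, but the original training code is not shown here). The toy loss, the data, and the `train` helper are hypothetical and exist only for the comparison. The sketch trains the same model twice, once passing a constant `0` and once passing the running step counter, and measures how far the resulting weights drift apart.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

# Hypothetical toy regression problem, used only for illustration.
def loss(params, x, y):
    w, b = params
    pred = x @ w + b
    return jnp.mean((pred - y) ** 2)

x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])
init_params = (jnp.zeros(2), jnp.array(0.0))

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2)

def train(num_steps, use_counter):
    """Run Adam for num_steps, passing either the step counter or a constant 0."""
    opt_state = opt_init(init_params)
    for index in range(num_steps):
        grads = jax.grad(loss)(get_params(opt_state), x, y)
        # Adam's bias correction uses the step number, so it must be passed in.
        step = index if use_counter else 0
        opt_state = opt_update(step, grads, opt_state)
    return get_params(opt_state)

params_fixed = train(20, use_counter=False)    # the buggy call: always step 0
params_counted = train(20, use_counter=True)   # the corrected call

# Largest absolute difference between the two weight vectors.
w_diff = float(jnp.max(jnp.abs(params_fixed[0] - params_counted[0])))
print("max weight difference after 20 steps:", w_diff)
```

The two runs agree after the first update, since both pass step `0` there, and diverge from the second update onward. That is consistent with a mismatch against Keras, which tracks the iteration count internally.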