Replies: 2 comments
-
Additional experiments: I was suspicious that the global-norm clipping was in the wrong place, but removing it doesn't help either. Here is how I plotted the gradient norm:
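(A minimal, self-contained sketch of one way to log the global gradient norm with optax; the loss and optimizer here are placeholders, not the original snippet:)

```python
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, batch):
    # Placeholder loss; in the real setup this is the converted model's loss.
    x, y = batch
    return jnp.mean((x @ params["w"] - y) ** 2)

tx = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3))

@jax.jit
def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Global L2 norm of the whole gradient pytree, the same quantity that
    # optax.clip_by_global_norm clips against; log/plot this per step.
    grad_norm = optax.global_norm(grads)
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss, grad_norm
```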
-
I found the issue: we assumed the weights were initialized deterministically, but they actually aren't.
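(A quick way to check this on the Keras side, sketched with a placeholder `build_model`; unseeded Keras initializers produce different weights on every construction:)

```python
import numpy as np
import tensorflow as tf

def build_model():
    # Placeholder for the real Keras model.
    return tf.keras.Sequential([tf.keras.layers.Dense(8, input_shape=(4,))])

tf.keras.utils.set_random_seed(0)  # seeds the Python, NumPy and TF RNGs
m1 = build_model()
tf.keras.utils.set_random_seed(0)
m2 = build_model()
# True only because of the explicit seeding above; without it the two
# initializations differ from run to run.
print(all(np.allclose(a, b) for a, b in zip(m1.get_weights(), m2.get_weights())))
```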
-
I am migrating a TensorFlow model to JAX, seeing odd behavior when trying to match the training curves, and looking for some help.
Observation:
Implementation details:
I used tf2jax to convert the Keras-implemented model and metrics; the optimizer uses optax.
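Roughly, the conversion follows the tf2jax README; `model` and the input shape below are placeholders:

```python
import numpy as np
import tensorflow as tf
import tf2jax

@tf.function
def forward(x):
    # `model` is a placeholder for the Keras model being migrated.
    return model(x, training=True)

# Convert the tf.function into a pure JAX function plus a parameter pytree.
jax_func, jax_params = tf2jax.convert(forward, np.zeros((1, 32), np.float32))

# The converted function is functional: it returns outputs and the
# (possibly updated) variables.
outputs, updated_params = jax_func(jax_params, np.zeros((1, 32), np.float32))
```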
TF train step:
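(Not the original code; a minimal sketch of a typical Keras train step with global-norm clipping, with placeholder `model` and `optimizer`:)

```python
import tensorflow as tf

loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def tf_train_step(model, optimizer, x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = loss_obj(y, logits)
    grads = tape.gradient(loss, model.trainable_variables)
    grads, _ = tf.clip_by_global_norm(grads, 1.0)  # global-norm clipping
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```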
JAX train step:
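(Again a sketch, not the original code; `apply_fn` stands in for the tf2jax-converted forward function:)

```python
import jax
import optax

# Mirror the TF setup: global-norm clipping chained in front of the optimizer.
tx = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3))

def loss_fn(params, x, y):
    # `apply_fn` is the converted forward pass; tf2jax functions return
    # (outputs, updated_variables).
    logits, _ = apply_fn(params, x)
    return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

@jax.jit
def jax_train_step(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
```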
I also tried changing loss_fn so it does not return a scaled total loss, but that doesn't seem to help or change anything.