I just started messing around with the library and followed the tutorial here
(I am running on plain CPU, just FYI)
The vanilla version did not quite work out: the gradient seemed unstable and close to zero. Essentially I did not get any "training" effect; the parameters stayed constant for all epochs.
At this point I started questioning the model, specifically the final part seemed odd:
```python
def predict(params, image):
    # per-example predictions
    activations = image
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)

    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)  # <== what is this supposed to accomplish?
```
Because of this, I started changing the model slightly. I tried mirroring some of PyTorch's starter examples, since they are already pretty close to the JAX example; in particular, this led me to switch to a similar cross-entropy-style loss.
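Roughly, the cross-entropy loss I switched to looked something like this (a sketch rather than my exact code; I'm assuming raw logits and one-hot targets here):

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def cross_entropy_loss(logits, one_hot_targets):
    # Normalize logits into log-probabilities, then take the negative
    # mean log-probability assigned to the correct class.
    log_probs = logits - logsumexp(logits, axis=1, keepdims=True)
    return -jnp.mean(jnp.sum(log_probs * one_hot_targets, axis=1))
```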
Eventually this seemed to help to some degree. The gradient was a little bit more stable now and I could actually see training "progress" (something like 10% accuracy or so).
Since this was not even remotely close to the expected ~90% accuracy, I simplified further to a plain softmax(logits) prediction with a quadratic loss function:
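Something along these lines (again a sketch, assuming one-hot targets; `jax.nn.softmax` stands in for whatever softmax helper the tutorial uses):

```python
import jax
import jax.numpy as jnp

def quadratic_loss(logits, one_hot_targets):
    # Squash logits into probabilities, then use plain mean squared error
    # against the one-hot labels.
    probs = jax.nn.softmax(logits, axis=1)
    return jnp.mean(jnp.sum((probs - one_hot_targets) ** 2, axis=1))
```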
Finally this resulted in the expected training progress, hitting ~90% accuracy shortly after the start of training.
So what is going on here? What gotcha should I avoid here? Are the log evaluations numerically unstable? Is there something obvious that I am missing? Thanks for any help!