Why does JAX show different training dynamics than PyTorch? #6325
-
Hi all, I noticed that JAX, used together with dm-haiku, shows different training dynamics than PyTorch, even when using the same architecture, optimizer, hyperparameters, initialization scheme, seeds, dataloaders, etc. Specifically, JAX appears to converge faster than PyTorch and reaches (comparably) higher accuracy after 100 epochs. The difference is fairly significant and appears systematic across datasets and training runs. You can see the described behavior in the following two minimal implementations: In the implementations above, the following factors were considered to ensure the comparison is fair:
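One way to make the "same initialization" factor verifiable rather than assumed is to dump one framework's initial weights to disk and load the exact same arrays into the other. A minimal numpy-only sketch (the parameter names and the `init_weights.npz` filename are illustrative, not from the original implementations):

```python
import numpy as np

# Hypothetical initial parameters; in practice these would be exported
# from the PyTorch model's state_dict as numpy arrays.
rng = np.random.default_rng(0)
init = {
    "conv1_w": rng.normal(size=(3, 3, 1, 8)),
    "fc1_w": rng.normal(size=(128, 10)),
}

# Round-trip through an .npz file, as a JAX/haiku script would do
# when loading the shared initialization.
np.savez("init_weights.npz", **init)
loaded = dict(np.load("init_weights.npz"))

match = all(np.array_equal(init[k], loaded[k]) for k in init)
print(match)
```

With both models reading from the same file, any remaining divergence must come from the forward/backward computation itself, not from initialization.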
I also took a closer look into the Jax and PyTorch repos and found that:
After accounting for all of the above, I'm really at a loss as to what could be causing the fairly different training dynamics, especially since it makes it difficult to reproduce results for models originally implemented in PyTorch. I would really appreciate it if you could let me know whether I missed something, and/or whether you have an idea of what could be causing the different training dynamics. Thank you!
Replies: 5 comments 7 replies
-
Thanks for the question! Unfortunately I don't have any hypotheses to offer. You could try generating the exact same initializations, and on one minibatch try breaking things down layer-by-layer to find where the first numerical divergence occurs.
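The layer-by-layer check suggested above can be sketched as a small utility that scans activations dumped from both frameworks (as numpy arrays) and reports the first layer where they disagree. The layer names and the toy activation values here are purely illustrative:

```python
import numpy as np

def first_divergence(acts_a, acts_b, names, atol=1e-5):
    """Return (layer_name, max_abs_diff) for the first layer whose
    activations differ beyond atol, or None if all layers match."""
    for name, a, b in zip(names, acts_a, acts_b):
        if not np.allclose(a, b, atol=atol):
            return name, float(np.max(np.abs(a - b)))
    return None

# Toy example: two "runs" that agree on the first layer
# but diverge at the second.
run_a = [np.zeros((2, 3)), np.ones((2, 3))]
run_b = [np.zeros((2, 3)), np.ones((2, 3)) + 1e-3]
print(first_divergence(run_a, run_b, ["conv1", "fc1"]))
```

In practice one would hook or print the intermediate activations of both models on the same minibatch, convert them to numpy, and feed them through a check like this to localize the first mismatch.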
-
@tomhennigan @mtthss @inoryy
-
+1 to what @mattjj suggested, starting from an identical initialisation is the way to go. I've been working on porting some torchvision checkpoints to Flax, and have some simple utilities to help check numerical equivalence of intermediate activations. Perhaps they'll be useful to you to debug your problem. Also, I noticed your L2 regularisation is implemented slightly differently between the two implementations: in PyTorch you use the sum of squares, whereas in JAX you use […]. I'd also explicitly use a padding of […]. Good luck!
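The reply above cuts off before naming the JAX-side convention, so the exact difference is unknown; still, the general point can be illustrated. Common L2-penalty implementations differ in whether they sum or average the squared weights, and with the same coefficient those yield different effective regularisation strengths (all values below are made up for illustration):

```python
import numpy as np

# Hypothetical parameters and coefficient.
params = [np.ones(10), np.ones(20)]
lam = 1e-4

# PyTorch-style weight decay penalises the sum of squares ...
l2_sum = lam * sum(np.sum(p ** 2) for p in params)

# ... while a per-parameter mean of squares (another common convention)
# scales the penalty down by each parameter's size.
l2_mean = lam * sum(np.mean(p ** 2) for p in params)

print(l2_sum, l2_mean)  # same lam, very different penalty magnitudes
```

If the two implementations use different conventions, matching the nominal weight-decay coefficient does not match the actual regularisation strength, which alone can change convergence speed.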
-
Hi all! I have made some interesting discoveries based on the kind suggestions from @mattjj @tomhennigan @n2cholas. I used the same initialization for JAX and PyTorch in the original custom CNN model, and discovered that the values start to diverge at the first forward pass. I then stripped my network down to a single CNN layer and a single FC layer to trace the divergence, and found that the forward pass at initialization is exactly the same. Check the notebooks below for details. However, interestingly, if the CNN and FC layers are combined, the forward-pass values start to diverge. Check here. It is so strange that the CNN layer and the FC layer each show the same pattern as PyTorch separately, but divergence occurs when they are combined. I would appreciate it if you could let me know whether I missed something, and/or whether you have an idea of what could be causing this divergence. PS: I am new to this area, and I feel so happy that my question is being valued and that the replies are so nice and helpful. I really appreciate it! :)
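One plausible cause consistent with this exact symptom (each layer matches in isolation, but conv followed by FC diverges) is layout: PyTorch convolutions produce NCHW activations while Haiku defaults to NHWC, so flattening before the FC layer feeds the features in a different order. A numpy sketch of the effect, with made-up tensor sizes:

```python
import numpy as np

# Same data in the two common activation layouts.
x_nchw = np.arange(24).reshape(1, 2, 3, 4)      # (N, C, H, W), PyTorch-style
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))     # (N, H, W, C), Haiku-style

# Flattening each layout orders the features differently, so an FC layer
# with identical weights sees permuted inputs and produces different outputs.
flat_nchw = x_nchw.reshape(1, -1)
flat_nhwc = x_nhwc.reshape(1, -1)
print(np.array_equal(flat_nchw, flat_nhwc))

# Transposing back to NCHW before flattening restores agreement.
flat_fixed = np.transpose(x_nhwc, (0, 3, 1, 2)).reshape(1, -1)
print(np.array_equal(flat_nchw, flat_fixed))
```

This would explain why the conv-only and FC-only tests pass: the mismatch only materialises at the reshape between them.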
-
Hi all! The training dynamics are now similar, thanks to a very useful comment from @n2cholas. The details can be found in this notebook here. Thank you all for your replies! Cheers!