As discussed in another thread, JIT-compiling the code resolves the issue. With JIT enabled, the JAX version actually runs slightly faster than the PyTorch one.
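The original code isn't shown in this thread, so the following is only a minimal sketch of what "JIT-compiling the code" usually looks like in JAX. It assumes a hypothetical two-layer MLP (`mlp_forward`, `init_params` and `time_fn` are illustrative names, not from the original discussion). It also shows two details that matter when comparing JAX against PyTorch. First, discard a warm-up call, because the first call to a `jax.jit`-wrapped function includes tracing and XLA compilation. Second, call `block_until_ready()`, because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp


def mlp_forward(params, x):
    """Hypothetical two-layer MLP standing in for the thread's (unshown) model code."""
    w1, b1, w2, b2 = params
    h = jnp.tanh(x @ w1 + b1)
    return h @ w2 + b2


# jax.jit traces the function once per input shape/dtype and compiles it with XLA;
# later calls with matching shapes reuse the compiled version.
mlp_forward_jit = jax.jit(mlp_forward)


def init_params(key, d_in=64, d_hidden=128, d_out=10):
    """Hypothetical random initialisation; params are a tuple (a pytree jit accepts)."""
    k1, k2 = jax.random.split(key)
    w1 = jax.random.normal(k1, (d_in, d_hidden)) * 0.1
    w2 = jax.random.normal(k2, (d_hidden, d_out)) * 0.1
    return (w1, jnp.zeros(d_hidden), w2, jnp.zeros(d_out))


def time_fn(fn, *args, n_iter=100):
    """Average wall-clock seconds per call, excluding the first (warm-up) call."""
    # Warm-up: for the jitted function this triggers compilation,
    # which should not be counted in the steady-state timing.
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(n_iter):
        # JAX dispatches asynchronously; wait for each result before continuing.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / n_iter


if __name__ == "__main__":
    params = init_params(jax.random.PRNGKey(0))
    x = jnp.ones((256, 64))
    print("eager:", time_fn(mlp_forward, params, x))
    print("jit:  ", time_fn(mlp_forward_jit, params, x))
```

Actual speedups depend on the model, input sizes and hardware. Without the warm-up call and `block_until_ready()`, a benchmark can either include compilation time or stop the clock before the computation has finished, and both distort a JAX vs. PyTorch comparison.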

Answer selected by dfdx
Category: Q&A