Thanks for the question!

This is just a guess, but as mentioned in the practical notes for `jax.remat`, using `scan` can defeat some of XLA's rematerialization optimizations and thus cause OOMs. Try applying `jax.remat` to the layer function and see how it affects memory use and performance. You can use the policies described in those docs to control more precisely which intermediates are saved and which are recomputed.
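In case it helps, here is a minimal sketch of what that suggestion might look like. The `layer` function, the shapes, and the choice of `dots_with_no_batch_dims_saveable` as the policy are all assumptions made for illustration, not taken from your model; `jax.checkpoint` is the same function as `jax.remat`.

```python
import jax
import jax.numpy as jnp


def layer(x, w):
    # Hypothetical layer (an assumption for this sketch): dense matmul + tanh.
    return jnp.tanh(x @ w)


# Recompute the layer's intermediates on the backward pass instead of storing them.
layer_remat = jax.checkpoint(layer)

# Same idea, but with a policy: matmul outputs without batch dimensions may be
# saved, everything else is recomputed. The policy choice here is illustrative.
layer_remat_policy = jax.checkpoint(
    layer, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable
)


def make_loss(layer_fn):
    def loss(stacked_w, x):
        # Scan over the leading axis of stacked_w: one weight matrix per layer.
        def step(carry, w):
            return layer_fn(carry, w), None

        out, _ = jax.lax.scan(step, x, stacked_w)
        return jnp.sum(out ** 2)

    return loss


key = jax.random.PRNGKey(0)
stacked_w = jax.random.normal(key, (8, 16, 16)) * 0.1  # 8 layers of 16x16 weights
x = jnp.ones((4, 16))

grads_plain = jax.grad(make_loss(layer))(stacked_w, x)
grads_remat = jax.grad(make_loss(layer_remat))(stacked_w, x)
grads_policy = jax.grad(make_loss(layer_remat_policy))(stacked_w, x)
```

All three variants should give the same gradients up to floating-point error; what changes is which values are kept alive between the forward and backward passes. Whether this actually avoids the OOM depends on the model, so it is worth comparing peak memory with and without the wrapper.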

Answer selected by AllanYangZhou