The solution in my case (since I don't need gradients) was simply to set experimental_custom_gradients=False in the tf.saved_model.SaveOptions passed to tf.saved_model.save.

If a JAX maintainer wants to investigate this and needs a minimal-ish reproduction, here it is: https://colab.research.google.com/gist/josephrocca/2ae1657ab909c6c827351f72ce6a6311/jax2tf-dall-e-mini.ipynb

Changing experimental_custom_gradients to False in the above-linked notebook fixes it.
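
For anyone who wants to see the shape of that change without opening the notebook, here is a minimal sketch. It assumes TensorFlow 2.x is installed. The `Doubler` module and the `save_for_inference` helper are hypothetical stand-ins I made up for illustration; in the notebook the saved object wraps the jax2tf-converted model instead. The only part that matters is the `SaveOptions` argument.

```python
import os
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical stand-in for the converted model (not from the notebook)."""

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return 2.0 * x


def save_for_inference(module, export_dir):
    """Save `module` without serializing custom gradients."""
    # experimental_custom_gradients is a field of SaveOptions, which is
    # handed to tf.saved_model.save through its `options` argument.
    options = tf.saved_model.SaveOptions(experimental_custom_gradients=False)
    tf.saved_model.save(module, export_dir, options=options)
    return export_dir


# Save into a fresh temporary directory and load it back for inference.
export_dir = save_for_inference(
    Doubler(), os.path.join(tempfile.mkdtemp(), "model")
)
restored = tf.saved_model.load(export_dir)
```

Note that with this option off, any custom gradients attached to the function are not stored in the SavedModel, so this is only appropriate when the exported model is used for inference and never differentiated after loading.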

Answer selected by josephrocca