[jax2tf] DecodeError: Error parsing message with type tensorflow.GraphDef when tf.saved_model.save-ing jax2tf-ed function
#11309
I've solved this issue. I'm creating this question and answer in case it's helpful for others searching for a fix to the same error. I'm not sure whether it's possible to make this error more informative. Perhaps I should file a feature request on the TensorFlow repo, but I'm not sure how "entangled" this is with JAX-related stuff.

Here's the sort of conversion code that I'm using, per the …:

```python
import tensorflow as tf
from jax.experimental import jax2tf

# f is the JAX function being exported (defined elsewhere)
my_model = tf.Module()
my_model.f = tf.function(jax2tf.convert(f), autograph=False, input_signature=[
    tf.TensorSpec(shape=[32], dtype=tf.int32, name="tokens"),
])
tf.saved_model.save(my_model, '/content/mysavedmodel', options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
```

And here's the error mentioned in the title:
Replies: 1 comment
The solution in my case (since I don't need gradients) was simply to set the `experimental_custom_gradients` option of `tf.saved_model.SaveOptions` (passed to `tf.saved_model.save`) to `False`.

If a JAX maintainer wants to investigate this and needs a minimal-ish reproduction, here it is: https://colab.research.google.com/gist/josephrocca/2ae1657ab909c6c827351f72ce6a6311/jax2tf-dall-e-mini.ipynb

Changing `experimental_custom_gradients` to `False` in the above-linked notebook fixes it.