How to quantize pre-trained JAX models? #11267
deshwalmahesh asked this question in Q&A (unanswered)
I have a pre-trained JAX model for MAXIM: Image Enhancement. To reduce runtime and use it in production, I'll have to quantize the weights. Since there is no direct conversion from JAX to ONNX, I have two options. Going with the second option, there is the function `tf.lite.TFLiteConverter.experimental_from_jax`.

Looking at the official example, the code block seems to use the model's `params` and a function `predict`, both of which are defined in the example itself during model building and training.
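Roughly, the conversion pattern from that example looks like the sketch below (the input shape is the MNIST one from the tutorial, not MAXIM's, and I've added the `optimizations` flag since my goal is quantization):

```python
import functools

import jax.numpy as jnp
import tensorflow as tf

# `predict` and `params` are defined by the example's own
# model-building and training code, which is exactly my problem.
serving_func = functools.partial(predict, params)
x_input = jnp.zeros((1, 28, 28))

converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[("input1", x_input)]])
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # dynamic-range quantization
tflite_model = converter.convert()
```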
My question is: how can I get these two required pieces, `params` and `predict`, for my pre-trained model, so that I can replicate the example for my own model?