You may want to check out https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md, in particular the tf2jax section.
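A minimal sketch of one route covered in that README, `jax2tf.call_tf`, assuming a recent JAX with `jax.experimental.jax2tf` and TensorFlow 2.x installed. The small Keras model is a hypothetical stand-in for the already-trained network (in practice you would load yours, e.g. with `tf.keras.models.load_model`). The wrapped model becomes an ordinary JAX-callable function that `jax.grad` can differentiate, with the gradient computed by TensorFlow's autodiff behind the scenes.

```python
import numpy as np
import tensorflow as tf
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf

# Hypothetical stand-in for the trained network (3 design inputs -> 1 force).
# In practice: tf_model = tf.keras.models.load_model(...)
tf.random.set_seed(0)
tf_model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(16, activation="tanh"),
    tf.keras.layers.Dense(1),
])

# call_tf wraps a TF callable so it can be called from JAX code; when
# jax.grad differentiates through it, TensorFlow computes the gradient.
nn_force = jax2tf.call_tf(lambda x: tf_model(x))

def objective(x):
    # Reduce to a scalar so jax.grad applies; any JAX code can surround this call.
    return jnp.sum(nn_force(x))

grad_objective = jax.grad(objective)

x = np.array([[0.5, -1.0, 2.0]], dtype=np.float32)
print(grad_objective(x))  # gradient w.r.t. the network inputs, shape (1, 3)
```

If you also want to wrap this in `jax.jit`, my understanding from the README is that the TensorFlow function must then be compilable with XLA; a plain stack of `Dense` layers should be.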
-
I am using a program that optimizes design parameters to minimize the drag of a ship. The force prediction combines standard formulas (currently written in NumPy) with neural networks that predict the forces on the body. I want to use autodiff to speed up the optimization. Is there any way to convert a TensorFlow model into a form that can be used with JAX?
The TensorFlow models are already fully trained and ready to be deployed, so the only functionality I need is the black box: numbers in, numbers out. I don't need to modify the networks in any way.
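For reference, once the NumPy formulas are ported to `jax.numpy`, the whole force prediction can become one differentiable function. A minimal sketch, assuming the ITTC-1957 friction line as an example "standard formula" and a tiny fixed-weight network as a hypothetical stand-in for the trained model (in the real program this would be the TensorFlow model, wrapped for JAX). The speed, fluid properties, weights and input scaling are all made-up illustrative values.

```python
import jax
import jax.numpy as jnp

RHO = 1025.0   # sea-water density [kg/m^3]
NU = 1.19e-6   # kinematic viscosity [m^2/s]
SPEED = 5.0    # ship speed [m/s] (assumed value)

def friction_drag(design):
    """ITTC-1957 frictional resistance, written with jax.numpy instead of NumPy."""
    length, wetted_area = design[0], design[1]
    reynolds = SPEED * length / NU
    c_f = 0.075 / (jnp.log10(reynolds) - 2.0) ** 2
    return 0.5 * RHO * SPEED**2 * wetted_area * c_f

# Hypothetical stand-in for the trained network: fixed weights, pure function.
W1 = jnp.array([[0.4, -0.3, 0.2],
                [0.1, 0.5, -0.6]])
W2 = jnp.array([120.0, -80.0, 50.0])
SCALE = jnp.array([100.0, 2000.0])  # crude input normalisation (assumed)

def nn_correction(design):
    hidden = jnp.tanh((design / SCALE) @ W1)
    return hidden @ W2

def total_drag(design):
    # Formula part + network part; jax.grad differentiates through both.
    return friction_drag(design) + nn_correction(design)

# One compiled call returns both the drag and its gradient w.r.t. the design.
drag_and_grad = jax.jit(jax.value_and_grad(total_drag))

design = jnp.array([100.0, 2000.0])  # [length in m, wetted area in m^2]
drag, grad = drag_and_grad(design)
print(drag, grad)
```

A function returning `(value, gradient)` like this can be handed to a gradient-based optimizer, for example `scipy.optimize.minimize(..., jac=True)` after converting the outputs to NumPy/float64.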
One approach I was considering is to use the XLA HLO graph generated from a TensorFlow model inside a JAX function, but I haven't tried this yet. I'm new to TensorFlow's backend, so I'm going into this with a minimal understanding of what could work.
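On the HLO idea: TensorFlow can return the XLA HLO it generates for a `tf.function` compiled with `jit_compile=True`, via `experimental_get_compiler_ir`. A sketch with a hypothetical stand-in model is below. As far as I know there is no simple public JAX API for loading such an HLO dump into a jitted function yourself, and my understanding is that `jax2tf.call_tf` already performs this kind of lowering internally when used under `jax.jit`, so it is usually the more practical path.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the trained network.
tf.random.set_seed(0)
tf_model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(16, activation="tanh"),
    tf.keras.layers.Dense(1),
])

@tf.function(jit_compile=True)
def predict(x):
    return tf_model(x)

x = tf.constant(np.zeros((1, 3), dtype=np.float32))
# experimental_get_compiler_ir(args) returns a callable; stage="hlo" yields the
# (unoptimized) HLO module as text, specialized to the shape/dtype of x.
hlo_text = predict.experimental_get_compiler_ir(x)(stage="hlo")
print(hlo_text[:500])
```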
Any suggestions on what I could try would be welcome.