Integration between a JAX library and a foreign model #9012
gianlucadetommaso asked this question in Q&A (unanswered)
Suppose that we want to train, in JAX, an ML model $F(\theta, x)$ written in another framework, say PyTorch. Given a loss function $\sum_i L(\theta, D_i)$ over a jax.numpy array of parameters $\theta$ and data $D_i = (x_i, y_i)$, one could write

$$\nabla_\theta \sum_i L(\theta, D_i) = \sum_i \nabla_F L(\theta, D_i)\, \nabla_\theta F(\theta, x_i),$$

where the Jacobian $\nabla_\theta F(\theta, x_i)$ is computed in PyTorch and passed as a numpy array, while $\nabla_F L(\theta, D_i)$ is computed directly via `jax.grad`. More generally, this is useful not only for training but for many cases of integration between a JAX library and a foreign model, where we do not want to force the user to translate their model into JAX.
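As a minimal sketch of this integration pattern, assuming a hypothetical `foreign_jacobian` helper on the PyTorch side (faked here in NumPy for the toy model $F(\theta, x) = \tanh(\theta x)$; a real version might call `torch.autograd.functional.jacobian` instead), the foreign Jacobian can be brought into a JAX computation through `jax.pure_callback`:

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical foreign-side helper: a real version would run the PyTorch
# model and return its Jacobian as a NumPy array. Faked here for the toy
# model F(theta, x) = tanh(theta @ x), with theta of shape (k, d).
def foreign_jacobian(theta, x):
    s = 1.0 - np.tanh(theta @ x) ** 2  # sech^2(theta @ x), shape (k,)
    # dF_a / dtheta_{bc} = sech^2((theta @ x)_a) * delta_{ab} * x_c
    return np.einsum('a,ab,c->abc', s, np.eye(theta.shape[0], dtype=theta.dtype), x)

def jacobian_in_jax(theta, x):
    # Declare the output shape/dtype so JAX can call out to the foreign code.
    k, d = theta.shape
    out = jax.ShapeDtypeStruct((k, k, d), theta.dtype)
    return jax.pure_callback(foreign_jacobian, out, theta, x)

theta = jnp.ones((3, 4))
x = jnp.arange(4.0)
print(jacobian_in_jax(theta, x).shape)  # (3, 3, 4)
```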
The problem is that computing the Jacobian on the right-hand side of the equation above is extremely expensive compared to directly computing the left-hand side via standard `jax.grad`. The following code is a proof-of-concept for a classification model, where I compute the gradient of the loss function first with the left-hand-side method and then with the right-hand-side one. On my laptop, the left-hand-side method takes around 0.15 seconds, while the right-hand-side one takes around 360 seconds.

Why this huge difference? Is there any way to accelerate the right-hand-side computation so that it becomes comparable to the direct gradient computation on the left?
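A minimal sketch of the two methods, assuming a toy linear classifier in place of the PyTorch model and with `jax.jacrev` standing in for the foreign per-example Jacobian:

```python
import time
import jax
import jax.numpy as jnp

# Toy linear classifier standing in for the foreign model F(theta, x).
def F(theta, x):
    return x @ theta  # logits, shape (k,)

def per_example_loss(theta, x, y):
    return -jax.nn.log_softmax(F(theta, x))[y]

def loss(theta, X, Y):
    return jnp.sum(jax.vmap(per_example_loss, (None, 0, 0))(theta, X, Y))

d, k, n = 64, 10, 512
theta = jax.random.normal(jax.random.PRNGKey(0), (d, k))
X = jax.random.normal(jax.random.PRNGKey(1), (n, d))
Y = jax.random.randint(jax.random.PRNGKey(2), (n,), 0, k)

# Left-hand side: direct gradient through the whole loss.
t0 = time.perf_counter()
g_direct = jax.grad(loss)(theta, X, Y).block_until_ready()
print("direct:", time.perf_counter() - t0)

# Right-hand side: materialize the per-example Jacobian dF/dtheta and
# contract it with dL/dF, as in the equation above.
def grad_via_jacobian(theta, X, Y):
    def one_example(x, y):
        J = jax.jacrev(F)(theta, x)  # shape (k, d, k)
        gF = jax.grad(lambda f: -jax.nn.log_softmax(f)[y])(F(theta, x))  # (k,)
        return jnp.einsum('a,abc->bc', gF, J)  # (d, k)
    return jnp.sum(jax.vmap(one_example)(X, Y), axis=0)

t0 = time.perf_counter()
g_chain = grad_via_jacobian(theta, X, Y).block_until_ready()
print("via Jacobian:", time.perf_counter() - t0)

print("max abs diff:", jnp.max(jnp.abs(g_direct - g_chain)))  # should be ~0
```

In this sketch the right-hand-side path materializes a full $(k, d, k)$ Jacobian per example before contracting it down to a $(d, k)$ gradient, whereas `jax.grad` only ever propagates a cotangent the size of the model output; that extra materialization is the kind of overhead the timings above reflect.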