-
Thanks for the question. 10x sounds suspicious! The easiest way to tell what's going on is to grab a profile. If you provide a fully runnable repro, i.e. your full timing script, I might be able to help with that. (What backend are you running on? GPU?)

But one shot-in-the-dark guess is that you're timing compile time. That is, if you're timing this:

```python
loss, grad = jax.jit(jax.value_and_grad(batch_jax_loss))(p_params, *batch)
```

that will include compile time. You might want to separate out the compile time, especially if you plan to evaluate this function more than once. Maybe something like:

```python
import time
import jax

# batch_jax_loss, jax_rollout2, p_params, env, and key come from your script.
jax_rollout2 = jax.jit(jax_rollout2)  # jit this function too if you haven't already
gradfun = jax.jit(jax.value_and_grad(batch_jax_loss))

# First call: traces, compiles, and executes.
tic = time.time()
batch = jax_rollout2(p_params, env, key)
loss, grad = gradfun(p_params, *batch)
loss.block_until_ready()  # wait for the async computation before stopping the clock
print('compile and first execution: ', time.time() - tic)

# Second call: reuses the cached compilation, so this measures execution only.
tic = time.time()
batch = jax_rollout2(p_params, env, key)
loss, grad = gradfun(p_params, *batch)
loss.block_until_ready()
print('second execution time: ', time.time() - tic)
```
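To make the compile-versus-execute split concrete, here's a minimal self-contained sketch (the function, names, and shapes are made up for illustration; they're not from the repro above). The first call to a jitted function traces and compiles it for the given input shapes; later calls with the same shapes reuse the cached executable:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def toy_loss(w, x):
    # A small dense computation standing in for batch_jax_loss.
    return jnp.mean((x @ w) ** 2)

w = jnp.ones((512, 512))
x = jnp.ones((64, 512))

tic = time.time()
toy_loss(w, x).block_until_ready()  # trace + compile + execute
print('first call (includes compile):', time.time() - tic)

tic = time.time()
toy_loss(w, x).block_until_ready()  # cached compilation: execute only
print('second call (execution only):', time.time() - tic)
```

If the 10x gap disappears on the second call, compilation was the culprit.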
-
Consider the two PyTorch and JAX functions below, which are the same apart from the call that computes which action to take (and those calls, in isolation, take the same amount of time to compute), i.e.

```python
a, log_prob = torch_policy(obs)
```

vs.

```python
a, log_prob = jax_policy(p_params, obs, key)
```

However, when trying to backprop through them, JAX is 10x slower than PyTorch.

[the two full function definitions, PyTorch version vs. JAX version, did not survive extraction]

I even tried separating the rollout from the loss computation (which costs an extra forward pass over the entire batch); that was faster, but still ~4x slower than PyTorch's. I wouldn't expect such a big performance difference. Is there an optimization I'm missing, or something? Thanks
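For context on the signature difference in the question: JAX functions are pure, so parameters and the PRNG key are passed in explicitly, whereas a PyTorch module holds its parameters and uses global RNG state. A hypothetical minimal stand-in for jax_policy (illustrative only; not the actual policy from this benchmark) might look like:

```python
import jax
import jax.numpy as jnp

def jax_policy(p_params, obs, key):
    # Hypothetical linear Gaussian policy with explicit params and PRNG key.
    mean = obs @ p_params['w'] + p_params['b']
    a = mean + jax.random.normal(key, mean.shape)        # sample an action
    log_prob = -0.5 * jnp.sum((a - mean) ** 2, axis=-1)  # unit-variance Gaussian, up to a constant
    return a, log_prob

# Example call mirroring the signature in the question:
key = jax.random.PRNGKey(0)
p_params = {'w': jnp.zeros((4, 2)), 'b': jnp.zeros(2)}
a, log_prob = jax_policy(p_params, jnp.ones((8, 4)), key)
```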