Speed up loop... #9467
Unanswered · jecampagne asked this question in Q&A
Replies: 3 comments 1 reply
-
Well, I have managed to cook up this code:

```python
import jax.numpy as jnp
import jax.random as rnd
from jax import lax

# init, get_params, update, dloglike, data_batched and key are assumed to be
# defined earlier (optimizer triple and the value-and-gradient of the log-likelihood).
def do_sgd(params, data_batched, n_epochs):
    n_batch = data_batched.shape[0]
    loss = jnp.zeros(n_epochs * n_batch)
    state = init(params)

    def body(i, carry):
        idx, loss, state = carry
        obs = data_batched[i]
        params = get_params(state)
        val, g = dloglike(params, obs)
        loss = loss.at[idx + i].set(val)
        state = update(i, g, state)
        return idx, loss, state

    for epoch in range(n_epochs):
        idx = epoch * n_batch
        _, loss, state = lax.fori_loop(0, n_batch, body, (idx, loss, state))
    return loss, state

key, key_mu, key_cov = rnd.split(key, 3)
mu = rnd.normal(key_mu, shape=(5,))
untransformed_cov = rnd.normal(key_cov, shape=(5, 5))
params = mu, untransformed_cov
loss, state_opt = do_sgd(params, data_batched, 5000)
```

It's working, but I don't know if it's the most efficient.
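One way to drop the Python loop over epochs entirely would be a single jitted `lax.fori_loop` over all `n_epochs * n_batch` steps, mapping the flat step index back onto a batch with a modulo. A minimal sketch, assuming the same `init`/`get_params`/`update`/`dloglike` helpers and `data_batched` as above:

```python
from jax import jit, lax
import jax.numpy as jnp

def do_sgd_flat(params, data_batched, n_epochs):
    n_batch = data_batched.shape[0]
    loss = jnp.zeros(n_epochs * n_batch)
    state = init(params)

    def body(i, carry):
        loss, state = carry
        obs = data_batched[i % n_batch]            # wrap the flat step index onto the batches
        val, g = dloglike(get_params(state), obs)
        loss = loss.at[i].set(val)
        state = update(i, g, state)
        return loss, state

    # One compiled loop over every step of every epoch.
    return lax.fori_loop(0, n_epochs * n_batch, body, (loss, state))

# n_epochs is static because it fixes the size of the loss buffer.
loss, state_opt = jit(do_sgd_flat, static_argnums=2)(params, data_batched, 5000)
```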
-
You could do something like this:

```python
from functools import partial
from jax import jit, value_and_grad, lax
import jax.numpy as jnp

# get_params, update, loglike, data_batched and the initial optimizer `state`
# are assumed to be defined as in the snippet above.
def main(state, i):
    def f(state, obs):
        params = get_params(state)  # fetch the current params at every step, as in the fori_loop version
        val, g = value_and_grad(loglike)(params, obs)
        state = update(i, g, state)
        return state, val
    return lax.scan(f, state, data_batched)

state, loss = jit(partial(lax.scan, main))(state, jnp.arange(5000))
loss = loss.ravel()
```
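`lax.scan` traces `f` once and reuses the compiled body for every batch, so under `jit` the whole optimisation compiles to a single XLA program instead of a Python loop dispatching one op at a time. For completeness, a minimal self-contained setup under which the snippet above runs; the Adam optimizer, the toy `loglike` and the random `data_batched` are illustrative assumptions, not part of the original code:

```python
import jax.numpy as jnp
import jax.random as rnd
from jax.example_libraries.optimizers import adam

# Hypothetical stand-ins for the helpers used above.
init, update, get_params = adam(1e-3)

def loglike(params, obs):
    # Toy objective: mean squared distance between the observations and the mean parameter.
    mu, untransformed_cov = params
    return jnp.mean((obs - mu) ** 2)

key = rnd.PRNGKey(0)
key, key_mu, key_cov, key_data = rnd.split(key, 4)
mu = rnd.normal(key_mu, shape=(5,))
untransformed_cov = rnd.normal(key_cov, shape=(5, 5))
data_batched = rnd.normal(key_data, shape=(10, 32, 5))  # (n_batch, batch_size, dim)

state = init((mu, untransformed_cov))
```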
-
Great, it works perfectly. Thanks!
-
Hi, I'm pretty sure this question will be solved in five minutes by the experts here.
Here is a working snippet.
But it is quite slow (a few minutes on a CPU), and I gather the double for-loops are the culprit.
How can I speed up this code? (N.B. the loss evaluation is there so I can plot diagnostics.)
Thanks