run lax.while_loop for a certain amount of time #4839
Is it possible to run a `lax.while_loop` for a certain amount of time? An example of what I mean: define … I've noticed that calling … I would imagine that this isn't possible, but I just thought I'd check!
I know it's been a while, but I thought this was interesting. Here's an approach using the experimental Host Callback API (`jax.experimental.host_callback`):

```python
from functools import partial
import time

import jax
import numpy as np
from jax import lax, jit
from jax.experimental import host_callback as hcb


# hcb.call always passes a value to the callback, so we ignore the first arg
def stop_after_time(_, time_limit=None, start_time=None):
    cur_time = time.time()
    return np.array((cur_time - start_time) < time_limit)


@partial(jit, static_argnums=(1, 2))
def _cond_fn(input, start_time=None, time_limit=None):
    # Partial application to avoid the HCB cache
    hcb_fun = partial(stop_after_time, time_limit=time_limit, start_time=start_time)
    # Send a blank value; only the returned boolean matters
    return hcb.call(hcb_fun, (), result_shape=jax.ShapeDtypeStruct((), np.bool_))


@jit
def body_fn(acc):
    return acc + 1


@partial(jit, static_argnums=(1, 2))
def sum_until_time(init_value, start_time, time_limit):
    cond_fn = partial(_cond_fn, start_time=start_time, time_limit=time_limit)
    return lax.while_loop(cond_fn, body_fn, init_value)


sum_until_time(0, time.time(), 10)
```
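`jax.experimental.host_callback` has since been deprecated, so here is a minimal sketch of the same idea using `jax.experimental.io_callback`, assuming a JAX version that provides it. `count_until_deadline` and `check_every` are hypothetical names for illustration. Two things differ from the version above: the host-side time check moves into the loop body and its result is carried as a boolean flag, so `cond_fun` is just a traced value; and an inner `lax.fori_loop` does `check_every` increments between host round-trips, since each callback forces a device-to-host sync.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.experimental import io_callback


def count_until_deadline(init_value, time_limit, check_every=100):
    """Hypothetical helper: increment a counter inside lax.while_loop until
    roughly `time_limit` seconds of wall-clock time have elapsed."""
    # The deadline stays on the host as a Python float, so it is never
    # rounded to float32 inside the compiled loop.
    deadline = time.monotonic() + time_limit

    def still_has_time(_acc):
        # Runs on the host; the argument is ignored and only ties the
        # callback to the loop state.
        return np.array(time.monotonic() < deadline, dtype=np.bool_)

    def cond_fn(state):
        _, keep_going = state
        return keep_going

    def body_fn(state):
        acc, _ = state
        # A chunk of device-side work between host checks.
        acc = lax.fori_loop(0, check_every, lambda _, a: a + 1, acc)
        keep_going = io_callback(
            still_has_time, jax.ShapeDtypeStruct((), jnp.bool_), acc
        )
        return acc, keep_going

    init_state = (jnp.asarray(init_value, dtype=jnp.int32), jnp.array(True))
    acc, _ = jax.jit(lambda s: lax.while_loop(cond_fn, body_fn, s))(init_state)
    return acc


print(count_until_deadline(0, 0.5))
```

Because a fresh closure is jitted on every call, each call recompiles (much like the static `start_time` above), and compile time counts against the budget since the deadline is set first. The result is always a multiple of `check_every`, and the loop runs past the deadline by roughly one chunk plus one callback.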