
I know it's been a while, but I thought this was interesting. Here's an approach using the experimental host callback API (jax.experimental.host_callback):

from functools import partial  # jax.partial no longer exists; use functools.partial
from jax import lax, jit
from jax.experimental import host_callback as hcb
import numpy as np
import time

# hcb.call always passes an operand to the callback, so the first argument is ignored
def stop_after_time(_, time_limit=None, start_time=None):
  cur_time = time.time()
  return np.array((cur_time - start_time) < time_limit)

@partial(jit, static_argnums=(1, 2,))
def _cond_fn(input, start_time=None, time_limit=None):
  # Partial application to avoid HCB Cache
  hcb_fun = partial(stop_after_time, time_limit=time_limit, start_time=start_time)
  # Send a blank value
  return hcb.call(h…
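
Note that jax.experimental.host_callback has since been deprecated and removed from recent JAX releases. Here is a minimal sketch of the same idea using jax.experimental.io_callback, assuming a recent JAX version. The names `_within_time_limit` and `run_with_time_limit`, the `clock` dict, and the placeholder loop body are hypothetical, for illustration only. Unlike the snippet above, the clock starts at the first evaluation of the loop condition, so tracing and compilation time is not counted against the limit.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback


def _within_time_limit(i, time_limit, clock):
    # Executed on the host by io_callback each time the loop condition is evaluated.
    # `clock` is a hypothetical mutable dict that remembers when the loop started.
    # The clock starts at the first evaluation (i == 0), so compile time is excluded.
    now = time.monotonic()
    if int(i) == 0:
        clock["start"] = now
    return np.array(now - clock["start"] < time_limit)


def run_with_time_limit(x0, time_limit):
    clock = {}
    # Bind the Python-side arguments with partial, as in the snippet above.
    check = partial(_within_time_limit, time_limit=time_limit, clock=clock)

    def cond_fn(carry):
        i, _ = carry
        # The loop counter is the operand passed to the host callback; it also
        # lets the callback detect the first evaluation.
        return io_callback(check, jax.ShapeDtypeStruct((), jnp.bool_), i)

    def body_fn(carry):
        i, x = carry
        return i + 1, x + 1.0  # placeholder for the real per-iteration work

    init = (jnp.int32(0), jnp.asarray(x0, dtype=jnp.float32))
    return jax.lax.while_loop(cond_fn, body_fn, init)
```

For example, `n, x = run_with_time_limit(0.0, 0.5)` runs the placeholder body for roughly half a second and returns the number of completed iterations. Each condition check is a round trip to Python, so this adds some overhead to every iteration.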

Replies: 1 comment, 3 replies (@seanmor5, @jeremiecoullon, @jeremiecoullon)

Answer selected by jeremiecoullon
Category: Q&A | Labels: none | 2 participants