You may be able to achieve what you have in mind using jax.experimental.host_callback. For example:

from jax.experimental import host_callback
from jax import jit

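calls = 0  # host-side counter, incremented once per executed tap

def register_call(arg, transforms):
  # Runs on the host each time the tapped point in f executes.
  global calls
  calls += 1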

@jit
def f(x):
  host_callback.id_tap(register_call, None)
  return x ** 2

@jit
def g(x):
  out = 0
  for i in range(5):
    out += f(x)
  return out

g(0.0)
host_callback.barrier_wait()  # taps run asynchronously; wait for them before reading the counter
print(calls)
# 5

Note, however, that this will likely degrade performance, because every tapped call requires communication with the host.
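jax.experimental.host_callback has since been deprecated in favor of newer callback APIs, so on a recent JAX release the same idea might be sketched with jax.debug.callback instead. This is only an illustration, assuming a JAX version that provides jax.debug.callback and jax.effects_barrier; the counter and function names simply mirror the example above.

```python
import jax
from jax import jit

calls = 0  # host-side counter, as in the example above


def register_call():
    # Runs on the host each time the callback point in f executes.
    global calls
    calls += 1


@jit
def f(x):
    # Assumption: jax.debug.callback is available in the installed JAX version.
    jax.debug.callback(register_call)
    return x ** 2


@jit
def g(x):
    out = 0.0
    for _ in range(5):
        out += f(x)
    return out


result = g(0.0)
jax.effects_barrier()  # wait for pending callbacks before reading the counter
print(calls)
```

As with id_tap, each callback is a round trip to the host, so the same performance caveat should apply.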

Answer selected by marius311