I want to do something like this to count how many times an objective function is called as part of some library which doesn't count things for me:

```python
ncalls = 0

def foo(x):
    global ncalls
    ncalls = ncalls + 1
    return x**2

some_optimization_library(foo, x)
ncalls  # <-- now look at this
```

Is there anything in this spirit which is possible in JAX?
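For context on why a plain Python counter is not enough: if the library jit-compiles the objective (an assumption here; `jax.jit` stands in for whatever the library does internally), the body of `foo` runs as Python only while JAX traces it. Later calls reuse the compiled version and skip the Python side effect. A minimal sketch:

```python
import jax

ncalls = 0

@jax.jit
def foo(x):
    global ncalls
    # Plain Python side effect: runs while JAX traces foo,
    # not each time the compiled function executes.
    ncalls += 1
    return x ** 2

# Same argument type and shape each time, so foo is traced only once.
for _ in range(3):
    foo(2.0)

print(ncalls)  # 1, not 3
```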
Answered by jakevdp on Aug 23, 2022
You may be able to achieve what you have in mind using `jax.experimental.host_callback`. For example:

```python
from jax import jit
from jax.experimental import host_callback

calls = 0

def register_call(arg, transforms):
    # Runs on the host each time the compiled code reaches the tap.
    global calls
    calls += 1

@jit
def f(x):
    host_callback.id_tap(register_call, None)
    return x ** 2

@jit
def g(x):
    out = 0
    for i in range(5):
        out += f(x)
    return out

g(0.0)
host_callback.barrier_wait()  # tap callbacks are asynchronous; wait for them to finish
print(calls)
# 5
```

Note, however, that this will likely degrade performance because of the required communication with the host.
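`jax.experimental.host_callback` was deprecated and later removed in newer JAX releases. On a recent JAX, `jax.debug.callback` is one possible replacement. The sketch below assumes a JAX version that provides `jax.debug.callback` and `jax.effects_barrier`. The names `count_calls`, `g` and `h` are hypothetical, chosen for illustration. The same host-communication cost applies. JAX's documentation also warns that debug callbacks may be dropped, duplicated or reordered under transformations such as `grad` or `vmap`, so counts there should be treated with caution.

```python
import jax
import jax.numpy as jnp


def count_calls(fn):
    """Hypothetical helper: wrap `fn` so each *execution* bumps a counter.

    jax.debug.callback invokes a host-side Python function every time the
    compiled code runs, unlike plain Python side effects, which only run
    while JAX traces the function.
    """
    counter = {"n": 0}

    def _bump(_x):
        # Host-side callback; the argument value is ignored.
        counter["n"] += 1

    def wrapped(x):
        jax.debug.callback(_bump, x)
        return fn(x)

    return wrapped, counter


foo, counter = count_calls(lambda x: x ** 2)


@jax.jit
def g(x):
    # Python loop: unrolled at trace time into 5 separate calls to foo.
    out = 0.0
    for _ in range(5):
        out += foo(x)
    return out


@jax.jit
def h(x):
    # lax.fori_loop traces its body once, but the callback fires on every iteration.
    return jax.lax.fori_loop(0, 10, lambda i, acc: acc + foo(x), jnp.float32(0.0))


g(jnp.float32(0.0)).block_until_ready()
jax.effects_barrier()  # wait until all pending callbacks have completed
print(counter["n"])

h(jnp.float32(0.0)).block_until_ready()
jax.effects_barrier()
print(counter["n"])
```

The `fori_loop` case shows the advantage over a trace-time Python counter: the loop body is traced once, yet the callback runs on every iteration.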
Answer selected by marius311