how to use control flow under jit? #12441
Answered by jakevdp
yiminghwang asked this question in Q&A
Hi there, I have a problem when I try to use `jax.jit` to speed up my code. A simplified example is shown below:

```python
from functools import partial

import numpy as np
import jax.numpy as jnp
from jax import jit


class A:
    @partial(jit, static_argnums=(0, 1))
    def funcA(self, x):
        # some processing code here; suppose it is y = x * 0.5
        y = 0.5 * x
        p = jnp.float32(y)  # here p is a DeviceArray(()), generated by the above process
        t = self.accept_or_not(p)
        return t

    def accept_or_not(self, p):
        uni = jnp.asarray(np.random.uniform(low=0, high=1))
        if jnp.greater(p, uni) == True:  # if uni < p:  <-- here it raises an error
            return True
        else:
            return False


a = A()
result = a.accept_or_not(jnp.float32(0.6))
```

In class `A`, I want to speed up `funcA` with `jax.jit`, but it raises the error "Abstract tracer value encountered where concrete value is expected". How can I fix it?
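A minimal sketch of why the error appears, assuming a recent JAX version: inside a `jit`-compiled function the arguments are abstract tracers that carry a shape and dtype but no value, so a Python `if` cannot decide which branch to take. The function name `accept_python_if` is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def accept_python_if(p, uni):
    # Under jit, p and uni are tracers, so `p > uni` is also a tracer and
    # Python cannot turn it into a concrete True/False for the `if`.
    if p > uni:
        return True
    return False


try:
    accept_python_if(jnp.float32(0.6), jnp.float32(0.3))
    raised = False
except jax.errors.ConcretizationTypeError:
    # Recent JAX versions raise TracerBoolConversionError here, which is a
    # subclass of ConcretizationTypeError.
    raised = True

print(raised)  # True
```

The same thing happens in the question's `accept_or_not` once it runs inside the jitted `funcA`: `jnp.greater(p, uni) == True` is still a traced array, not a Python `bool`.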
Beta Was this translation helpful? Give feedback.
Answered by jakevdp on Sep 21, 2022 (1 comment, 3 replies)
You can replace your `if` statement with a `where` statement:

```python
def accept_or_not(self, p):
    uni = jnp.asarray(np.random.uniform(low=0, high=1))
    return jnp.where(uni < p, True, False)
```

or, even more simply for this particular case (since you're just returning `True` or `False`), you could do this:

```python
def accept_or_not(self, p):
    uni = jnp.asarray(np.random.uniform(low=0, high=1))
    return uni < p
```
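Putting the answer together with `jit`, here is a minimal sketch that is not from the thread. One caveat the snippets above leave open: `np.random.uniform` is plain NumPy, so under `jit` it runs only once, at trace time, and its result is baked into the compiled function as a constant; every later call compares against the same "random" number. The sketch therefore swaps in JAX's explicit-key PRNG, `jax.random.uniform`, which is a substitution not present in the original answer. The extra `key` argument and the choice to mark only `self` as static are assumptions of this sketch.

```python
from functools import partial

import jax
import jax.numpy as jnp


class A:
    # Only `self` is static; x and key are traced, so new values do not force recompilation.
    @partial(jax.jit, static_argnums=0)
    def funcA(self, x, key):
        y = 0.5 * x  # stand-in for the real processing, as in the question
        p = jnp.asarray(y, dtype=jnp.float32)
        return self.accept_or_not(p, key)

    def accept_or_not(self, p, key):
        # Draw uni ~ U[0, 1) from the key on every call, not once at trace time.
        uni = jax.random.uniform(key, minval=0.0, maxval=1.0)
        # A traced boolean array; no Python `if` is needed.
        return uni < p


a = A()
keys = jax.random.split(jax.random.PRNGKey(0), 1000)
# vmap is used only to draw many independent samples in one call.
results = jax.vmap(lambda k: a.funcA(jnp.float32(1.2), k))(keys)
print(results.dtype)   # bool
print(results.mean())  # roughly 0.6, since p = 0.5 * 1.2
```

Each call needs a fresh key (for example from `jax.random.split`); reusing the same key returns the same draw.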
Beta Was this translation helpful? Give feedback.
Answer selected by yiminghwang
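When the two branches need to compute different values rather than just return a boolean, `jax.lax.cond` is JAX's traceable counterpart of `if`/`else`. A minimal sketch, where `step` and its two branch lambdas are hypothetical:

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(x, p, uni):
    # lax.cond takes a scalar boolean predicate, two branch functions, and
    # their operands. Both branches are traced, so they must accept the same
    # operands and return outputs of the same shape and dtype.
    return jax.lax.cond(
        uni < p,
        lambda v: v * 2.0,  # "accept" branch
        lambda v: v - 1.0,  # "reject" branch
        x,
    )


x = jnp.float32(3.0)
print(step(x, jnp.float32(0.6), jnp.float32(0.2)))  # 6.0
print(step(x, jnp.float32(0.6), jnp.float32(0.9)))  # 2.0
```

For a simple choice between two already-computed values, `jnp.where` as in the answer is usually enough; `lax.cond` is for when the branches do different work.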
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
You can replace your
if
statement with awhere
statement:or, even more simply for this particular case (since you're just returning
True
orFalse
you could do this: