-
What is the recommended way to handle multiple conditions in jitted functions?

key = random.PRNGKey(0)
# Derive two independent subkeys from `key`. Passing the same key to
# random.normal twice returns the same samples, which would make x and y
# identical arrays.
key_x, key_y = random.split(key)
x = random.normal(key_x, (100000,))
y = random.normal(key_y, (100000,))
def p1(x, y):
    """Return whether both ``x`` and ``y`` are greater than 10.

    Python's ``and`` calls ``bool()`` on its operands, and a traced JAX array
    cannot be converted to a Python bool under ``jit``/``vmap``. The
    element-wise ``&`` works on traced arrays. The parentheses are required
    because ``&`` binds more tightly than ``>``.
    """
    return (x > 10) & (y > 10)
def p2(x, y):
    """Return whether both ``x`` and ``y`` exceed 10, selected via ``lax.cond``.

    The predicate is combined with ``&`` rather than Python ``and``. ``and``
    needs a concrete ``bool``, so it fails on traced values under
    ``jit``/``vmap``.
    """
    both = (x > 10) & (y > 10)
    return lax.cond(both, lambda _: True, lambda _: False, None)
def p3(x, y):
    """Return whether both ``x`` and ``y`` exceed 10, combining flags by multiplication.

    Without parentheses, ``x > 10 * y > 10`` is a Python chained comparison
    meaning ``(x > 10 * y) and (10 * y > 10)``. That tests the wrong
    condition, and its implicit ``and`` fails on traced values under
    ``jit``/``vmap``. Parenthesising each comparison multiplies two boolean
    flags instead, which is true only where both are true.
    """
    both = (x > 10) * (y > 10)
    return lax.cond(both, lambda _: True, lambda _: False, None)
def p4(x, y):
    """Return whether both ``x`` and ``y`` exceed 10, using one ``lax.cond`` per test."""
    c1 = lax.cond(x > 10, lambda _: True, lambda _: False, None)
    c2 = lax.cond(y > 10, lambda _: True, lambda _: False, None)
    # `&` instead of Python `and`: c1 and c2 are JAX arrays (traced under
    # jit/vmap) and cannot be coerced to a Python bool.
    return c1 & c2
def p5(x, y):
    """Return the product of two boolean flags, each produced by ``lax.cond``.

    One flag marks ``x > 10`` and the other marks ``y > 10``. Each branch
    returns a Python bool, so the product is true only when both conditions
    hold.
    """
    def _true(_):
        return True

    def _false(_):
        return False

    x_flag = lax.cond(x > 10, _true, _false, None)
    y_flag = lax.cond(y > 10, _true, _false, None)
    return x_flag * y_flag
def p6(x, y):
    """Return whether both ``x`` and ``y`` exceed 10, using integer 0/1 flags.

    Each ``lax.cond`` yields 1 or 0. The product of the two flags is non-zero
    only when both are 1.
    """
    def _one(_):
        return 1

    def _zero(_):
        return 0

    x_flag = lax.cond(x > 10, _one, _zero, None)
    y_flag = lax.cond(y > 10, _one, _zero, None)
    return (x_flag * y_flag) != 0
# Vectorise each predicate over the leading axis of both inputs
# (vmap in_axes=(0, 0)) and compile the result with jit.
# NOTE(review): pv1-pv4 are presumably disabled because combining traced
# comparisons with Python `and` (or an unparenthesised chained comparison)
# raises under jit/vmap. Combine conditions with `&` and explicit
# parentheses instead.
#pv1 = jit(vmap(p1,(0,0)))
#pv2 = jit(vmap(p2,(0,0)))
#pv3 = jit(vmap(p3,(0,0)))
#pv4 = jit(vmap(p4,(0,0)))
pv5 = jit(vmap(p5,(0,0)))
pv6 = jit(vmap(p6,(0,0)))
# Benchmarks: `%timeit` is an IPython magic, so this section only runs in
# IPython/Jupyter. block_until_ready() makes each timed call wait for its
# result, because JAX dispatches work asynchronously.
#%timeit pv1(x,y).block_until_ready()
#%timeit pv2(x,y).block_until_ready()
#%timeit pv3(x,y).block_until_ready()
#%timeit pv4(x,y).block_until_ready()
%timeit pv5(x,y).block_until_ready()
%timeit pv6(x,y).block_until_ready()

Converting the result from boolean to integer and back seems to incur some overhead:
Is there a way to use logical operators in JAX instead of having to mimic them using arithmetic on integers?
Beta Was this translation helpful? Give feedback.
Replies: 2 comments
-
You could use `jnp.logical_and` and friends, or their operator form, like `a & b`. Not so sure about the overhead part, as I don't see any on my machine.
Beta Was this translation helpful? Give feedback.
-
Thanks :) |
Beta Was this translation helpful? Give feedback.
You could use `jnp.logical_and` and friends, or their operator form, like `a & b`. Do be aware of their precedence, though, as they bind more tightly than comparison operators. So `x > 10 and y > 10` should be `(x > 10) & (y > 10)`, and the parentheses cannot be omitted. Not so sure about the overhead part, as I don't see any on my machine.