I read a previous discussion about the runtime cost of `lax.cond`, and I did a few experiments. It seems like this primitive is not aggressively pruned when the first … Should I be worried about this at all from an optimization perspective?
I'm not sure what experiments you did or what you were measuring, but it is the case that `lax.cond` will not compute the output for the unused path at runtime, even if the boolean argument is not a constant. Here's a simple example demonstrating this:

```python
import jax.numpy as jnp
import jax

# Expensive branch: eigendecomposition of an N x N matrix.
@jax.jit
def f_slow(x):
    M = jnp.outer(x, x)
    evals, evecs = jnp.linalg.eigh(M)
    return evals.sum()

# Cheap branch: elementwise square and sum.
@jax.jit
def f_fast(x):
    return (x ** 2).sum()

# The predicate depends on the runtime value of x, so it is not a constant.
@jax.jit
def f_cond(x):
    return jax.lax.cond(x.sum() > 0, f_slow, f_fast, x)

x = jnp.ones(2000)

# Each function is called once before timing so compilation is excluded.
# Timings use IPython's %timeit magic.
f_slow(x).block_until_ready()
%timeit f_slow(x).block_until_ready()
# 1.72 s ± 24 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

f_fast(x).block_until_ready()
%timeit f_fast(x).block_until_ready()
# 5.89 µs ± 87.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

# positive x: should take slow path
f_cond(x).block_until_ready()
%timeit f_cond(x).block_until_ready()
# 1.69 s ± 23.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# negative x: should take fast path
%timeit f_cond(-x).block_until_ready()
# 536 µs ± 10.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

When …
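The timings show the effect indirectly. A minimal sketch that checks it more directly, assuming a recent JAX release that provides `jax.debug.callback` and `jax.effects_barrier`: the branch bodies mirror `f_slow`/`f_fast` above with a host callback added, while the `executed` list and the names `slow_branch`/`fast_branch` are hypothetical, for illustration only. `jax.make_jaxpr` shows that both branches are staged into a single `cond` equation (so both are compiled), and the callbacks show that only the branch selected by the predicate actually runs.

```python
import jax
import jax.numpy as jnp

# Hypothetical log of which branch bodies actually ran at execution time.
executed = []

def slow_branch(x):
    # jax.debug.callback stages a host call into the program; the lambda
    # runs only when this branch executes, not when it is traced.
    jax.debug.callback(lambda _: executed.append("slow"), x)
    M = jnp.outer(x, x)
    evals, _ = jnp.linalg.eigh(M)
    return evals.sum()

def fast_branch(x):
    jax.debug.callback(lambda _: executed.append("fast"), x)
    return (x ** 2).sum()

def f_cond(x):
    # Same structure as the example above: the predicate is a runtime value.
    return jax.lax.cond(x.sum() > 0, slow_branch, fast_branch, x)

x = jnp.ones(8)

# 1) Tracing: both branches appear inside one `cond` equation,
#    including the eigh from the slow branch, so both get compiled.
jaxpr = jax.make_jaxpr(f_cond)(x)
print(jaxpr)

# 2) Execution: only the branch chosen by the predicate runs.
f = jax.jit(f_cond)
f(x).block_until_ready()
jax.effects_barrier()  # wait for pending debug callbacks to finish
f(-x).block_until_ready()
jax.effects_barrier()
print(executed)
```

With a positive and then a negative input, `executed` should end up as `['slow', 'fast']`. In other words, `lax.cond` pays compile time for both branches but run time for only one. The main case where both branches do run is when the conditional is turned into a select; the `lax.cond` docstring notes that this happens for a batched predicate under `jax.vmap`.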