
I'm not sure what experiments you ran or what you were measuring, but `lax.cond` does not compute the output of the unused branch at runtime, even when the boolean predicate is not a compile-time constant. Here's a simple example demonstrating this:

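```python
import jax.numpy as jnp
import jax

@jax.jit
def f_slow(x):
  # Expensive branch: eigendecomposition of a 2000x2000 outer-product matrix.
  M = jnp.outer(x, x)
  evals, evecs = jnp.linalg.eigh(M)
  return evals.sum()

@jax.jit
def f_fast(x):
  # Cheap branch: a single elementwise square and reduction.
  return (x ** 2).sum()

@jax.jit
def f_cond(x):
  # The predicate depends on the runtime value of x, so it is not a
  # compile-time constant; only the selected branch is executed.
  return jax.lax.cond(x.sum() > 0, f_slow, f_fast, x)

x = jnp.ones(2000)

# Call once first so compilation is excluded from the timing (%timeit is IPython magic).
f_slow(x).block_until_ready()
%timeit f_slow(x).block_until_ready()
# 1.72 s ± 24 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

f_fast(x).block_until…
```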
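If wall-clock timings are noisy, a more direct check is to record which branch body actually runs. The sketch below is an illustration, not part of the original comparison: it assumes a JAX version that provides `jax.debug.callback` and `jax.effects_barrier`, and the names `g`, `true_branch`, `false_branch`, `record` and `calls` are hypothetical. Each branch fires a host-side callback, and the predicate depends on runtime data, so it is not a compile-time constant.

```python
import jax
import jax.numpy as jnp

# Hypothetical log of which branch bodies executed at runtime.
calls = []

def record(name):
    # jax.debug.callback runs the Python function when the compiled
    # computation executes, not when the function is traced.
    jax.debug.callback(lambda: calls.append(name))

def true_branch(x):
    record("true")
    return x * 2.0

def false_branch(x):
    record("false")
    return x - 1.0

@jax.jit
def g(x):
    # The predicate is computed from x at runtime.
    return jax.lax.cond(x.sum() > 0, true_branch, false_branch, x)

g(jnp.ones(3)).block_until_ready()
jax.effects_barrier()  # wait until pending callbacks have finished
g(-jnp.ones(3)).block_until_ready()
jax.effects_barrier()

print(calls)  # ['true', 'false']
```

Both branches are traced and compiled once, but each call logs exactly one entry, which shows that only the selected branch ran.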

Answer selected by femtomc