I think `jnp.piecewise` does require array outputs to work correctly, but you should be able to do this directly with `lax.cond`; for example:

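```python
from jax import jit, lax

# Both branches must return outputs with the same structure
# (here, a tuple of two values) and matching shapes/dtypes.
def f1(x):
  return x, x + 1

def f2(x):
  return x + 1, x + 2

@jit
def switch_f(x):
  # Calls f1(x) if x < 2, otherwise f2(x). The predicate must be
  # a scalar boolean, so x should be a scalar here.
  return lax.cond(x < 2, f1, f2, x)

print(*switch_f(1))
# 1 2
```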
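If there are more than two pieces, `lax.switch` is the multi-branch counterpart of `lax.cond`, and it likewise accepts branches that return tuples. The sketch below is only an illustration: the third branch `f3` and the breakpoints at 2 and 4 are hypothetical choices, not part of the original question. Because `lax.switch` (like `lax.cond`) takes a scalar index, `jax.vmap` is used to apply it elementwise over an array, roughly mirroring how `jnp.piecewise` evaluates over `x`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f1(x):
    return x, x + 1


def f2(x):
    return x + 1, x + 2


def f3(x):
    # Hypothetical third piece, added only for illustration.
    return x + 2, x + 3


@jax.jit
def switch_f3(x):
    # Branch index: 0 if x < 2, 1 if 2 <= x < 4, 2 if x >= 4.
    index = (x >= 2).astype(jnp.int32) + (x >= 4).astype(jnp.int32)
    # Every branch returns a tuple of two values with matching dtypes.
    return lax.switch(index, [f1, f2, f3], x)


# switch_f3 expects a scalar, so vmap maps it over each element of xs
# and stacks each half of the returned tuple into its own array.
xs = jnp.arange(6)
a, b = jax.vmap(switch_f3)(xs)
print(a)  # [0 1 3 4 6 7]
print(b)  # [1 2 4 5 7 8]
```

One design note: when the index (or `cond` predicate) is batched by `vmap`, JAX evaluates all branches and selects the results, so this is not a way to skip expensive branches over arrays; it only preserves the tuple-valued structure.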

Replies: 1 comment, 4 replies (@DanPuzzuoli, @jakevdp, @soraros, @DanPuzzuoli)
Answer selected by DanPuzzuoli
Category: Q&A
3 participants