Guidance on correct way to conditionally apply function #9233
-
Hi, I have a list of functions with the following signature:
I was hoping to use
so I tried an alternative: modifying the functions so that they concatenate their two outputs into a single
where the above error was generated when I tried running the following simple example:
Will I need to more directly use

Also, in case it's relevant, this question came up when I started looking into finally addressing #5647, the elimination of some redundant computation in
-
I think `jnp.piecewise` does require array outputs in order to work correctly, but you should be able to do this directly with `lax.cond`; for example:

```python
from jax import jit, lax

def f1(x):
    return x, x + 1

def f2(x):
    return x + 1, x + 2

@jit
def switch_f(x):
    # Calls f1(x) when x < 2, otherwise f2(x); both branches must return
    # the same structure (here a 2-tuple) with matching shapes and dtypes.
    return lax.cond(x < 2, f1, f2, x)

print(*switch_f(1))
# 1 2
```
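Since the question starts from a *list* of functions rather than exactly two, a possible generalization of the two-branch `lax.cond` example is `lax.switch(index, branches, *operands)`, which calls `branches[index]` and (like `cond`) accepts branches that return tuples. The sketch below assumes branch functions `f1`, `f2`, `f3` and a wrapper `apply_branch`; these are hypothetical illustrations, not the asker's actual functions. It also shows `jax.vmap` applying a different branch per element, which is closer to what `jnp.piecewise` does.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical branch functions; each returns a tuple of two values
# with the same structure, shapes, and dtypes.
def f1(x):
    return x, x + 1

def f2(x):
    return x + 1, x + 2

def f3(x):
    return x * 2, x * 3

branches = [f1, f2, f3]

@jax.jit
def apply_branch(index, x):
    # lax.switch clamps index into [0, len(branches) - 1] and then
    # calls branches[index](x).
    return lax.switch(index, branches, x)

# Scalar use: index 2 selects f3, so this returns (5 * 2, 5 * 3).
a, b = apply_branch(2, 5)
print(a, b)  # 10 15

# Per-element use: vmap maps over both the branch index and the input.
idx = jnp.array([0, 1, 2])
xs = jnp.array([1, 2, 3])
a_vec, b_vec = jax.vmap(apply_branch)(idx, xs)
print(a_vec, b_vec)  # [1 3 6] [2 4 9]
```

One caveat that may matter for the redundant-computation concern: when the index (or the `cond` predicate) is batched under `vmap`, JAX lowers the conditional to a select, so every branch is evaluated for every element and only the results are chosen. In the unbatched scalar case only the selected branch runs.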