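jax.lax.cond traces (but does not execute) both of its branches, so both must be valid computations for the operands you pass: each must trace without error, and both must return outputs with the same tree structure, shapes, and dtypes. If you need logic that branches at trace time based on static attributes (such as array shapes or Python booleans), you can use ordinary Python control flow instead. For example: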

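import jax

# case_1, case_2, tree1, tree2 and same_shape come from the original question.
def some_fn(x, y, condition):
  # condition is a static Python bool, so the branch is chosen at trace time
  return (case_1 if condition else case_2)(x, y)

# jax.tree_map is deprecated; jax.tree_util.tree_map is the stable spelling.
result = jax.tree_util.tree_map(some_fn, tree1, tree2, same_shape)
print(result)
# Output from an older JAX version (newer versions print Array instead of DeviceArray):
# (DeviceArray([0., 0.], dtype=float32), DeviceArray([1., 1., 1., 1., 1.], dtype=float32))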
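For a fully self-contained version, here is a minimal sketch. The definitions of case_1, case_2, tree1, tree2 and same_shape below are hypothetical stand-ins (the question's originals are not shown here), chosen so that case_1 only works when both inputs have the same shape. The sketch also shows the same shape check inside jax.jit, and what happens if the same branches are handed to jax.lax.cond instead.

```python
import jax
import jax.numpy as jnp

# Hypothetical branch functions standing in for the question's case_1/case_2.
def case_1(x, y):
    # Only valid when x and y have the same shape.
    return x - y

def case_2(x, y):
    # Fallback that returns an array shaped like y.
    return jnp.ones_like(y)

def some_fn(x, y, condition):
    # `condition` is a plain Python bool, so Python picks the branch
    # while tracing and only the chosen function is ever traced.
    return (case_1 if condition else case_2)(x, y)

# Hypothetical inputs: the second pair of leaves has mismatched shapes.
tree1 = (jnp.zeros(2), jnp.zeros(3))
tree2 = (jnp.zeros(2), jnp.zeros(5))

# Shapes are static, so this yields a tuple of Python bools.
same_shape = jax.tree_util.tree_map(lambda a, b: a.shape == b.shape, tree1, tree2)

result = jax.tree_util.tree_map(some_fn, tree1, tree2, same_shape)
print(result)

# The same shape check also works inside jax.jit, because array shapes are
# known at trace time even though array values are not.
@jax.jit
def jitted_fn(x, y):
    return case_1(x, y) if x.shape == y.shape else case_2(x, y)

jitted_result = jax.tree_util.tree_map(jitted_fn, tree1, tree2)

# By contrast, lax.cond traces BOTH branches, so the mismatched leaves fail
# even though the predicate is False and case_1 would never be selected.
def cond_fn(x, y, condition):
    return jax.lax.cond(condition, case_1, case_2, x, y)

try:
    jax.tree_util.tree_map(cond_fn, tree1, tree2, same_shape)
    cond_error = None
except Exception as err:  # exact exception type may vary across JAX versions
    cond_error = err
```

Because the lax.cond version receives the same Python bool predicates, its failure comes purely from tracing the unselected branch on incompatible shapes, which is the behavior described above.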

Replies: 1 comment, 2 replies (participants: @anh-tong, @jakevdp)

Answer selected by anh-tong
Category: Q&A
Labels: none
Participants: 2