Hi, I am trying to write the following code:

```python
import jax.numpy as jnp
import jax

tree1 = (jnp.ones((2,)), jnp.ones((3,)))
tree2 = (jnp.ones((2,)), jnp.ones((5,)))
shape1 = jax.tree_util.tree_map(lambda x: jnp.array(x.shape), tree1)
shape2 = jax.tree_util.tree_map(lambda x: jnp.array(x.shape), tree2)
same_shape = jax.tree_util.tree_map(lambda x, y: jnp.allclose(x, y), shape1, shape2)

def case_1(a, b):
    return a - b

def case_2(a, b):
    return b

def some_fn(x, y, condition):
    return jax.lax.cond(condition,
                        case_1,
                        case_2,
                        x,
                        y)

print("same_shape={}".format(same_shape))
result = jax.tree_util.tree_map(some_fn, tree1, tree2, same_shape)
```

Here is the Colab link: https://colab.research.google.com/drive/1Y9qVEQjS8Zxz-ozYDBZ8k0IBlRE9P4fR?usp=sharing

Is the above code viable in JAX, or is this a bug?
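The failure can be isolated to the second leaf pair. The sketch below, assuming a recent JAX release, calls `jax.lax.cond` directly with a concrete `False` predicate; `try_cond` is a hypothetical helper that returns the exception instead of raising it. Because `jax.lax.cond` traces both branches, `case_1` still evaluates `a - b` on shapes `(3,)` and `(5,)`, and that shape error surfaces before any branch is selected. The exact exception type and message may vary between JAX versions, so the sketch catches both `TypeError` and `ValueError`.

```python
import jax
import jax.numpy as jnp


def case_1(a, b):
    return a - b


def case_2(a, b):
    return b


def try_cond(a, b, pred):
    """Hypothetical helper: run lax.cond, returning the exception on failure."""
    try:
        return jax.lax.cond(pred, case_1, case_2, a, b)
    except (TypeError, ValueError) as err:
        return err


# Mismatched shapes: fails even though pred is False, because case_1
# (a - b) is traced with shapes (3,) and (5,), which cannot broadcast.
bad = try_cond(jnp.ones((3,)), jnp.ones((5,)), False)
print(type(bad).__name__, bad)

# Matching shapes: both branches trace cleanly, so lax.cond succeeds
# and returns case_1's result (all zeros).
good = try_cond(jnp.ones((2,)), jnp.ones((2,)), True)
print(good)
```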
`jax.lax.cond` will trace (but not compute) both branches of its input, so both need to be viable computations. Here, for the second leaf pair, `case_1` computes `a - b` on shapes `(3,)` and `(5,)`, which cannot broadcast, so tracing fails even though `condition` is `False`. If you need logic that branches at trace time based on static attributes, you can use normal Python control flow. For example:

```python
# case_1, case_2, tree1, tree2 and same_shape are defined as in the question.
def some_fn(x, y, condition):
    return (case_1 if condition else case_2)(x, y)

result = jax.tree_util.tree_map(some_fn, tree1, tree2, same_shape)
print(result)
# (DeviceArray([0., 0.], dtype=float32), DeviceArray([1., 1., 1., 1., 1.], dtype=float32))
```
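The version above works because `same_shape` holds concrete arrays computed eagerly, so `if condition` can convert each one to a Python `bool`; if `some_fn` were called inside `jax.jit` with a traced `condition`, that conversion would raise a tracer error. One jit-compatible alternative, sketched here under the assumption of the same `case_1` and `case_2`, compares `x.shape == y.shape` directly: shapes are static Python tuples even during tracing, so the branch is chosen at trace time and only the selected branch is traced. The `combine` wrapper is a hypothetical name for illustration.

```python
import jax
import jax.numpy as jnp


def case_1(a, b):
    return a - b


def case_2(a, b):
    return b


def some_fn(x, y):
    # x.shape and y.shape are Python tuples even while tracing under jit,
    # so this comparison is an ordinary Python bool and only the chosen
    # branch is traced.
    if x.shape == y.shape:
        return case_1(x, y)
    return case_2(x, y)


@jax.jit
def combine(tree1, tree2):
    # Hypothetical wrapper: apply some_fn leaf-by-leaf inside jit.
    return jax.tree_util.tree_map(some_fn, tree1, tree2)


tree1 = (jnp.ones((2,)), jnp.ones((3,)))
tree2 = (jnp.ones((2,)), jnp.ones((5,)))
result = combine(tree1, tree2)
print(result)
```

Because the decision depends only on shapes, `jax.jit` retraces `combine` whenever the input shapes change, which is the expected cost of branching on static attributes.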