TracerBoolConversionError while using jax.lax.cond #20376
I am attempting to write something related to Conway's Game of Life in JAX, given a 2D array of 1s (alive cells) and 0s (dead cells). Here is the Wikipedia explanation for the 'game':
Here is my code:

# helper function
_at = lambda arr, i, j: jax.lax.dynamic_slice(arr, (i, j), (1, 1))[0][0]

# tells if the (i, j)th cell will live or die in the next generation
def conway_i_jax(i, j, grid):
    live_neighbors = _at(grid, i-1, j-1) + _at(grid, i, j-1)\
                   + _at(grid, i+1, j-1) + _at(grid, i-1, j)\
                   + _at(grid, i+1, j) + _at(grid, i-1, j+1)\
                   + _at(grid, i, j+1) + _at(grid, i+1, j+1)
    return jax.lax.cond((_at(grid,i,j) == 1) and (live_neighbors < 2 or live_neighbors > 3),
                        lambda x: jnp.array(0., dtype=jnp.float32),
                        lambda x: jax.lax.cond(x == 0 and live_neighbors == 3,
                                               lambda x: jnp.array(1., dtype=jnp.float32),
                                               lambda x: x,
                                               x),
                        _at(grid,i,j))

# multiple i, j
_conway_jax = jax.vmap(jax.vmap(conway_i_jax, in_axes=(0, None, None)),
                       in_axes=(None, 0, None))

def conway_jax(grid):
    height, width = grid.shape
    rows = jnp.arange(height, dtype=jnp.int32)
    cols = jnp.arange(width, dtype=jnp.int32)
    return _conway_jax(rows, cols, grid)
conway_jax(arr)

Note: I am not worried about the boundaries for now. I get a TracerBoolConversionError pointing to the jax.lax.cond line. What is the issue here?
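For reference, the same error can probably be reproduced in isolation with a much smaller function. The sketch below is only an illustration: `python_and` and `bitwise_and` are hypothetical names that do not appear in the code above, and `jax.jit` is used in place of `jax.vmap` just to force tracing.

```python
import jax
import jax.numpy as jnp

def python_and(x):
    # `and` calls bool() on the traced result of (x == 1),
    # which JAX cannot evaluate while tracing.
    return jax.lax.cond((x == 1) and (x < 2),
                        lambda: jnp.float32(0.0),
                        lambda: jnp.float32(1.0))

def bitwise_and(x):
    # `&` stays a traced element-wise operation, so no bool() conversion occurs.
    return jax.lax.cond((x == 1) & (x < 2),
                        lambda: jnp.float32(0.0),
                        lambda: jnp.float32(1.0))

try:
    jax.jit(python_and)(jnp.float32(1.0))
    raised = False
except jax.errors.TracerBoolConversionError:
    raised = True

ok = jax.jit(bitwise_and)(jnp.float32(1.0))
```

Here `raised` records whether the `and` version failed during tracing, while the `&` version traces and runs normally.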
Replies: 2 comments
Hi - the issue is your use of Python's `and` and `or`. This attempts to cast the operands on either side to boolean, which leads to this error. If you want traceable boolean logic, you need to use `&`, `|`, and `~` instead of `and`, `or`, and `not` (for what it's worth, this is similar to what's required for element-wise operations on numpy arrays).

The fix would look like this:

return jax.lax.cond((_at(grid,i,j) == 1) & ((live_neighbors < 2) | (live_neighbors > 3)),
                    ...

Note the additional parentheses, which are required because `&` and `|` bind more tightly than comparison operators such as `<` and `==`.

Thanks! That works!
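To make the suggested fix concrete, here is a minimal sketch of the whole function with `&` and `|` in place of `and` and `or`. It is an illustration, not the asker's final code: the names `dies` and `born` and the 5x5 blinker test grid are assumptions introduced here, and the two `vmap` calls are nested in the opposite order from the question so that the output is indexed as [row, column]. Boundaries are still not handled, so only interior cells are meaningful.

```python
import jax
import jax.numpy as jnp

def _at(arr, i, j):
    # Read one element at a (possibly traced) index pair.
    return jax.lax.dynamic_slice(arr, (i, j), (1, 1))[0, 0]

def conway_i_jax(i, j, grid):
    live_neighbors = (_at(grid, i-1, j-1) + _at(grid, i, j-1)
                      + _at(grid, i+1, j-1) + _at(grid, i-1, j)
                      + _at(grid, i+1, j) + _at(grid, i-1, j+1)
                      + _at(grid, i, j+1) + _at(grid, i+1, j+1))
    cell = _at(grid, i, j)
    # Traceable boolean logic: & and | with explicit parentheses.
    dies = (cell == 1) & ((live_neighbors < 2) | (live_neighbors > 3))
    born = (cell == 0) & (live_neighbors == 3)
    return jax.lax.cond(
        dies,
        lambda x: jnp.array(0., dtype=jnp.float32),
        lambda x: jax.lax.cond(born,
                               lambda x: jnp.array(1., dtype=jnp.float32),
                               lambda x: x,
                               x),
        cell)

# Assumption for this sketch: inner vmap over columns, outer vmap over rows,
# so that out[i, j] corresponds to cell (i, j).
_conway_jax = jax.vmap(jax.vmap(conway_i_jax, in_axes=(None, 0, None)),
                       in_axes=(0, None, None))

def conway_jax(grid):
    height, width = grid.shape
    rows = jnp.arange(height, dtype=jnp.int32)
    cols = jnp.arange(width, dtype=jnp.int32)
    return _conway_jax(rows, cols, grid)

# Hypothetical test input: a horizontal blinker in the middle of a 5x5 grid.
grid = jnp.zeros((5, 5), dtype=jnp.float32).at[2, 1:4].set(1.0)
next_grid = conway_jax(grid)
```

After one step the interior of `next_grid` should hold the vertical phase of the blinker. Without the inner parentheses, an expression such as `cell == 1 & live_neighbors < 2` would parse as the chained comparison `cell == (1 & live_neighbors) < 2`, which is not the intended condition.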