How to assert with respect to traced values inside vmap? #12785
-
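Sorry if this has been asked before. This is what I want, so that I can make sure the function is never used under unexpected conditions:

```python
import jax.numpy as jnp
from jax import vmap

def check_value(a):
    assert a.shape == (1,)  # works: shapes are static during tracing
    assert a > 0            # doesn't work: a is a traced value here

vmap(check_value)((jnp.arange(10) - 5).reshape(10, 1))
```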
-
You cannot write assertions with respect to traced values inside JAX transforms. There's more background on this in Common Gotchas: Python Control Flow. That section specifically discusses control flow, but an assertion is basically control flow in disguise, since it branches on the value of its condition. The How To Think In JAX document is also good background reading for understanding why JAX behaves this way. There may be workarounds depending on what you were originally trying to do: can you say more about your goal here?
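One possible workaround, sketched here under the assumption that a functional check is acceptable in place of a Python `assert`, is `jax.experimental.checkify`. `checkify.check` stages the condition into the computation, and `checkify.checkify` returns an error value alongside the output, so it can be inspected or raised after the transformed function has run. `jnp.all` reduces the per-element predicate to the scalar boolean that `check` expects; the `(10, 1)` reshape and the `xs_bad`/`xs_good` inputs are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def check_value(a):
    # Shapes are static during tracing, so a plain Python assert is fine here.
    assert a.shape == (1,)
    # Values are traced, so register a functional check instead of asserting.
    checkify.check(jnp.all(a > 0), "a must be positive")
    return a

# checkify(vmap(...)) returns a single (unbatched) error plus the output.
checked = checkify.checkify(jax.vmap(check_value))

xs_bad = (jnp.arange(10) - 5).reshape(10, 1)
err, out = checked(xs_bad)
print(err.get())  # an error message, because some entries are <= 0

xs_good = (jnp.arange(10) + 1).reshape(10, 1)
err, out = checked(xs_good)
print(err.get())  # no message, because every check passed
err.throw()       # raises only if a check failed
```

Wrapping the whole `vmap` in `checkify` collapses the batch into one error; applying `checkify` inside `vmap` instead would give a batched error per element.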
-
Given the debugging tools that have appeared recently (e.g. `jax.debug.print`), is there a way to assert with respect to the values of traced arrays?
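One possible approach, assuming a host-side check is acceptable, uses `jax.debug.callback`, which belongs to the same family of tools as `jax.debug.print`: it stages a Python function that later runs with concrete values, even under `vmap` and `jit`. This sketch records offending values instead of raising, because how an exception raised inside a callback surfaces may depend on the JAX version and on `jit`. The helper `_record_non_positive` and the `violations` list are hypothetical names for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical host-side record of values that fail the check.
violations = []

def _record_non_positive(a):
    # Runs on the host with concrete values, so ordinary comparisons work.
    a = np.asarray(a)
    violations.extend(a[a <= 0].tolist())

def check_value(a):
    assert a.shape == (1,)  # static shape check, fine at trace time
    # Stage a host callback instead of asserting on the traced value.
    jax.debug.callback(_record_non_positive, a)
    return a

xs = (jnp.arange(10) - 5).reshape(10, 1)
jax.jit(jax.vmap(check_value))(xs)
jax.effects_barrier()      # wait for any pending callbacks to finish
print(sorted(violations))  # the non-positive values seen across the batch
```

The callback could raise or log instead of appending to a list; collecting values keeps the sketch independent of how errors propagate out of compiled code.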