Is it possible to apply assert_scalar_positive to a vector, i.e. assert that every element is positive? The MWE below shows what I'd like to do, but I can't see a way to do this with Chex.
```python
import jax
import jax.numpy as jnp
from chex import assert_scalar_positive

x_scalar = 1.
x_vector = jnp.array([1., 1.])

assert_scalar_positive(x_scalar)               # Works
jax.vmap(assert_scalar_positive)(x_vector)     # What I'd like to work
```
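A possible workaround, sketched under some assumptions: as far as I can tell, Chex's `assert_scalar_positive` inspects a concrete Python `int`/`float`, so under `jax.vmap` it receives a tracer rather than a number. The sketch below shows two alternatives. `assert_all_positive` is a hypothetical helper (not part of Chex) that checks a concrete array eagerly. `check_positive` uses `jax.experimental.checkify`, which stages the check into the computation so it can run on traced values under `vmap`. Exact error messages may vary between JAX versions.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def assert_all_positive(x):
    """Hypothetical helper (not part of Chex).

    Raises AssertionError unless every element of `x` is strictly positive.
    Only works on concrete arrays: bool(...) pulls the result back to Python,
    which is not possible for tracers inside jit/vmap.
    """
    x = jnp.asarray(x)
    if not bool(jnp.all(x > 0)):
        raise AssertionError(f"All elements must be positive, got {x}.")


def check_positive(x):
    # checkify.check records the condition inside the computation instead of
    # evaluating it in Python, so it is allowed on traced values under vmap.
    checkify.check(x > 0, "value must be positive")
    return x


# checkify.checkify wraps the function so it returns (error, output)
# instead of raising immediately.
checked_positive = checkify.checkify(jax.vmap(check_positive))

if __name__ == "__main__":
    x_vector = jnp.array([1., 1.])
    assert_all_positive(x_vector)      # passes silently

    err, _ = checked_positive(x_vector)
    err.throw()                        # no-op: every element is positive

    err, _ = checked_positive(jnp.array([1., -1.]))
    print(err.get())                   # error message string (None if no failure)
```

The eager helper is the simplest option outside of transformations. `checkify` is the heavier option, but it is what lets a value check survive `vmap` (and `jit`), because the failure is reported through the returned error object rather than a Python `assert`.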