Context

I recently tried using `jnp.abs` on an `int4` array:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 2], dtype=jnp.int4)
print(jnp.abs(x))  # ❌ Fails with an MLIR verification error
```

This raises a very low-level error:

```
MLIRError: Verification failed:
error: type of return operand 0 ('tensor<3xi4>') doesn't match function result type ('tensor<3xi8>') in function @abs
...
ValueError: Cannot lower jaxpr with verifier errors
```

On the other hand, the same call works for `uint4`:

```python
x = jnp.array([1, 2, 2], dtype=jnp.uint4)
print(jnp.abs(x))  # ✅ Works and returns [1 2 2]
```

I understand from previous discussions that …
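Until `int4` is handled correctly for `abs`, one possible workaround is to widen the input to `int8`, which `jnp.abs` already supports, before taking the absolute value. This is only a sketch: it assumes that casting from `int4` to `int8` with `astype` lowers without problems on your backend (this discussion does not confirm that), and the helper name `abs_int4_via_int8` is hypothetical.

```python
import jax.numpy as jnp


def abs_int4_via_int8(x):
    """Hypothetical workaround: compute abs of an int4 array in int8.

    Assumes int4 -> int8 conversion via astype lowers correctly.
    """
    # Widen to int8 so jnp.abs runs on a dtype it already lowers correctly.
    wide = x.astype(jnp.int8)
    # Keep the int8 result: int4 covers [-8, 7], so abs(-8) == 8
    # has no int4 representation and narrowing back could overflow.
    return jnp.abs(wide)


x = jnp.array([1, -2, 2], dtype=jnp.int4)
print(abs_int4_via_int8(x))
```

The helper deliberately returns `int8` rather than casting back to `int4`, because the one negative `int4` value without a positive counterpart (`-8`) would not survive the round trip.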
---
Thanks for raising this issue: we should be able to catch errors like this during abstract evaluation. I've created #31644 from this discussion to track work on this.