XLA doesn't have any support for 1-bit variables, but recent versions of JAX do have some experimental support for 4-bit integers via the `int4` dtype:

```
In [1]: import jax.numpy as jnp
In [2]: x = jnp.arange(4, dtype='int4')
In [3]: print(x)
[0 1 2 3]
```

The only way I know of to use bit arrays is via the `jnp.packbits` and `jnp.unpackbits` functions:

```
In [4]: x = jnp.array([1, 0, 1, 1, 0, 0, 1, 0], dtype='uint8')
In [5]: bits = jnp.packbits(x)
In [6]: print(bits)
[178]
In [7]: print(jnp.unpackbits(bits))
[1 0 1 1 0 0 1 0]
```
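Building on `packbits`, here is a minimal sketch of the 256-booleans-to-8-`uint32` case from the question: `jnp.packbits` along the last axis gives 4 bytes per group of 32 booleans, and those bytes are then shifted into one `uint32` word. The helper names `pack_to_uint32` and `unpack_from_uint32` are hypothetical, and the big-endian bit order (first boolean in the most-significant bit) is simply the `packbits` default, not something the hardware requires.

```python
import jax
import jax.numpy as jnp

# Byte offsets inside a uint32 word: the first packed byte is most significant.
BYTE_SHIFTS = jnp.array([24, 16, 8, 0], dtype=jnp.uint32)

def pack_to_uint32(bits):
    """Pack a boolean array of length 32*k into k uint32 words (hypothetical helper).

    jnp.packbits uses bitorder='big' by default, so bits[32*w] lands in the
    most-significant bit of word w.
    """
    # (k, 32) booleans -> (k, 4) uint8 bytes, widened to uint32 for shifting.
    as_bytes = jnp.packbits(bits.reshape(-1, 32), axis=-1).astype(jnp.uint32)
    # The shifted bytes occupy disjoint bit ranges, so summing them acts as an OR.
    return jnp.sum(as_bytes << BYTE_SHIFTS, axis=-1, dtype=jnp.uint32)

def unpack_from_uint32(words):
    """Inverse of pack_to_uint32: k uint32 words -> 32*k booleans."""
    as_bytes = ((words[:, None] >> BYTE_SHIFTS) & 0xFF).astype(jnp.uint8)
    return jnp.unpackbits(as_bytes, axis=-1).reshape(-1).astype(bool)

bits = jnp.zeros(256, dtype=bool).at[0].set(True).at[255].set(True)
words = pack_to_uint32(bits)
print(words.shape, words.dtype)  # (8,) uint32; word 0 == 2**31, word 7 == 1
```

Because the bytes are combined with explicit shifts rather than reinterpreting memory, the result does not depend on the machine's byte order.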
Can I use JAX to compress 32 boolean variables into one uint32 variable, like C++'s bitset functionality? I want to implement architectures like binary neural networks in JAX.
For example, I have a boolean array of length 256 that I want to compress into a uint32 array of length 8.
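A possible approach that skips `packbits` and does the shifting directly is sketched below, using the least-significant-bit-first order of `std::bitset` (index 0 is bit 0 of the word). All function names are hypothetical. `binary_dot` shows one common way packed words are used in binary networks, the XNOR/popcount dot product of ±1 vectors via `jax.lax.population_count`; treat it as an illustration under that assumption, not an established JAX API for BNNs.

```python
import jax
import jax.numpy as jnp

SHIFTS = jnp.arange(32, dtype=jnp.uint32)  # bit positions 0..31

@jax.jit
def pack_lsb_first(bits):
    """bool[32*k] -> uint32[k]; bits[32*w + i] becomes bit i of word w,
    like std::bitset where index 0 is the least-significant bit."""
    words = bits.reshape(-1, 32).astype(jnp.uint32)
    # Each position contributes a distinct power of two, so the sum equals a bitwise OR.
    return jnp.sum(words << SHIFTS, axis=-1, dtype=jnp.uint32)

@jax.jit
def unpack_lsb_first(words):
    """uint32[k] -> bool[32*k], the inverse of pack_lsb_first."""
    return ((words[:, None] >> SHIFTS) & 1).astype(bool).reshape(-1)

@jax.jit
def binary_dot(a_words, b_words):
    """Dot product of two ±1 vectors stored as packed bits (1 -> +1, 0 -> -1).

    Matching bits contribute +1 and differing bits -1, so with n bits and
    m = popcount(a XOR b) mismatches, dot = (n - m) - m = n - 2m.
    """
    n = 32 * a_words.shape[0]  # static under jit
    mismatches = jnp.sum(jax.lax.population_count(a_words ^ b_words).astype(jnp.int32))
    return n - 2 * mismatches

ka, kb = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.bernoulli(ka, 0.5, (256,))
b = jax.random.bernoulli(kb, 0.5, (256,))
a_words, b_words = pack_lsb_first(a), pack_lsb_first(b)  # two uint32[8] arrays
print(binary_dot(a_words, b_words))
```

The sum-of-shifts form keeps everything as ordinary elementwise ops plus a reduction, so it traces cleanly under `jax.jit`; an explicit bitwise-OR reduction would give the same result.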