Bitcast array from int32 to int64 #8784
Hi,
For context:
with jax.experimental.enable_x64():
    a = jax.lax.bitcast_convert_type(...)
As I am currently pretty much stuck here, I would appreciate any help I can get! Thanks in advance!
Best,
Replies: 1 comment 8 replies
One way to do this is with the view method of JAX arrays. For example:
import jax
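jax.config.update('jax_enable_x64', True)  # enable 64-bit types; JAX defaults to 32-bit
import jax.numpy as jnp
x32 = jnp.arange(20, dtype='int32').reshape(10, 2)
# Reinterpret each row (a pair of int32 values) as one int64.
# The view has shape (10, 1), so ravel() flattens it to (10,).
x64 = x32.view('int64').ravel()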
print(x64)
# [ 4294967296 12884901890 21474836484 30064771078 38654705672 47244640266
# 55834574860 64424509454 73014444048 81604378642]
This matches what tf.bitcast produces in TensorFlow:
import tensorflow as tf
x32 = tf.reshape(tf.range(20, dtype='int32'), (10, 2))
x64 = tf.bitcast(x32, 'int64')
print(x64)
# tf.Tensor(
# [ 4294967296 12884901890 21474836484 30064771078 38654705672 47244640266
# 55834574860 64424509454 73014444048 81604378642], shape=(10,), dtype=int64)
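To see where these numbers come from, here is a minimal sketch using only the Python standard library. It assumes little-endian byte order, which the outputs above are consistent with. The helper name bitcast_pairs_to_int64 is hypothetical and is not part of JAX or TensorFlow.

```python
import struct

def bitcast_pairs_to_int64(values):
    """Reinterpret consecutive int32 pairs as little-endian int64 values.

    Hypothetical helper for illustration only; not a JAX or TensorFlow API.
    """
    if len(values) % 2:
        raise ValueError("need an even number of int32 values")
    # Pack every value as a 4-byte little-endian signed integer.
    raw = struct.pack(f"<{len(values)}i", *values)
    # Read the same bytes back as 8-byte little-endian signed integers.
    return list(struct.unpack(f"<{len(values) // 2}q", raw))

x64 = bitcast_pairs_to_int64(list(range(20)))
print(x64[:3])  # [4294967296, 12884901890, 21474836484]

# For non-negative inputs this equals low + (high << 32) for each pair.
expected = [lo + (hi << 32) for lo, hi in zip(range(0, 20, 2), range(1, 20, 2))]
assert x64 == expected
```

So each pair (low, high) of int32 values becomes low + high * 2**32. For example, the first row [0, 1] gives 4294967296.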