
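One way to do this is with the view method of JAX arrays. It reinterprets the raw bits of each pair of int32 values as a single int64, with no numeric conversion. For example: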

import jax
jax.config.update('jax_enable_x64', True)

import jax.numpy as jnp
x32 = jnp.arange(20, dtype='int32').reshape(10, 2)
x64 = x32.view('int64').ravel()
print(x64)
# [ 4294967296 12884901890 21474836484 30064771078 38654705672 47244640266
#  55834574860 64424509454 73014444048 81604378642]
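To see where these numbers come from, the sketch below recomputes the same values with integer arithmetic. It assumes a little-endian host (true for x86 and most ARM machines, and consistent with the output above). On such a host the first int32 of each pair supplies the low 32 bits and the second supplies the high 32 bits. The names viewed, lo, hi and manual are illustrative only.

```python
import jax
jax.config.update('jax_enable_x64', True)

import jax.numpy as jnp

x32 = jnp.arange(20, dtype='int32').reshape(10, 2)

# Bit-level reinterpretation: each (lo, hi) int32 pair becomes one int64.
viewed = x32.view('int64').ravel()

# Same result computed arithmetically, assuming little-endian layout:
# element 0 of each pair is the low word, element 1 is the high word.
lo = x32[:, 0].astype('uint32').astype('int64')  # zero-extend the low word
hi = x32[:, 1].astype('int64')                   # sign-extend the high word
manual = (hi << 32) | lo

print(bool(jnp.array_equal(viewed, manual)))  # True on little-endian hosts
```

For example, the first row [0, 1] becomes 1 * 2**32 + 0 = 4294967296, which matches the first printed value.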

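In TensorFlow the equivalent is tf.bitcast. It consumes the trailing dimension of size 2 directly, so no ravel is needed:

import tensorflow as tf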
x32 = tf.reshape(tf.range(20, dtype='int32'), (10, 2))
x64 = tf.bitcast(x32, 'int64')
print(x64)
# tf.Tensor(
# [ 4294967296 12884901890 21474836484 30064771078 38654705672 47244640266
#  55834574860 64424509454 73014444048 81604378642], shape=(10,), dtype=int64)
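If you want tf.bitcast's shape behavior while staying in JAX, one option is jax.lax.bitcast_convert_type. This is an alternative not mentioned above. As far as I know, it follows the same shape rule: widening int32 to int64 removes a trailing dimension of size 2, and narrowing adds it back. The sketch below also shows the round trip back to int32 pairs.

```python
import jax
jax.config.update('jax_enable_x64', True)

import jax.numpy as jnp
from jax import lax

x32 = jnp.arange(20, dtype='int32').reshape(10, 2)

# Like tf.bitcast: widening int32 -> int64 consumes the trailing
# dimension of size 2, so the result already has shape (10,).
x64 = lax.bitcast_convert_type(x32, jnp.int64)
print(x64.shape)  # (10,)

# Narrowing int64 -> int32 restores the trailing dimension of size 2.
roundtrip = lax.bitcast_convert_type(x64, jnp.int32)
print(roundtrip.shape)  # (10, 2)
print(bool(jnp.array_equal(roundtrip, x32)))  # True
```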

Answer selected by juliandwain
Category: Q&A
2 participants