Using a tensorflow tensor in a jax function #10141
Unanswered
KaleabTessera asked this question in Q&A
Replies: 1 comment · 1 reply
-
The best way to share memory buffers between JAX/TensorFlow/PyTorch is generally to use dlpack as an intermediary. For example:

```python
import jax
import jax.dlpack
import jax.numpy as jnp
import tensorflow as tf

def jax_func(x):
    return jnp.sin(x)

x_tf = tf.range(5, dtype='float32')
print(x_tf)
# tf.Tensor([0. 1. 2. 3. 4.], shape=(5,), dtype=float32)

# Hand the TensorFlow tensor to JAX via a dlpack capsule (no host copy).
x_dl = tf.experimental.dlpack.to_dlpack(x_tf)
x_jax = jax.dlpack.from_dlpack(x_dl)
print(repr(x_jax))
# DeviceArray([0., 1., 2., 3., 4.], dtype=float32)

y_jax = jax_func(x_jax)
print(repr(y_jax))
# DeviceArray([ 0. , 0.84147096, 0.9092974 , 0.14112 , -0.7568025 ], dtype=float32)

# Hand the result back to TensorFlow the same way.
y_dl = jax.dlpack.to_dlpack(y_jax)
y_tf = tf.experimental.dlpack.from_dlpack(y_dl)
print(y_tf)
# tf.Tensor([ 0. 0.84147096 0.9092974 0.14112 -0.7568025 ], shape=(5,), dtype=float32)
```
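A minimal sketch of how the two conversions could be wrapped into reusable helpers; the names `tf_to_jax` and `jax_to_tf` are just for illustration and are not part of either library:

```python
import jax
import jax.dlpack
import jax.numpy as jnp
import tensorflow as tf

def tf_to_jax(x_tf):
    # Hypothetical helper: move a TensorFlow tensor into JAX via a dlpack capsule.
    # A dlpack capsule can only be consumed once, so create it fresh each time.
    return jax.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(x_tf))

def jax_to_tf(x_jax):
    # Hypothetical helper: move a JAX array back into TensorFlow via a dlpack capsule.
    return tf.experimental.dlpack.from_dlpack(jax.dlpack.to_dlpack(x_jax))

# Usage, mirroring the round trip above:
x_tf = tf.range(5, dtype='float32')
y_tf = jax_to_tf(jnp.sin(tf_to_jax(x_tf)))
print(y_tf)
```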
-
I have a simple jax function that takes in a tensor and does some pruning (this function: https://github.com/deepmind/dm-haiku/blob/1254b9648647536f3f5dae1435362802cf1eecfb/examples/mnist_pruning.py#L34).
I would like to use this function with a TensorFlow tensor or tf.Variable (i.e. `value` in the function is a tf tensor/variable). Currently I call `tensor = jnp.array(tf_tensor)` before passing this tensor into the function, but this is super slow. Is there an alternative to this?
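For context, a minimal sketch of the pattern described here, with a placeholder `prune_fn` standing in for the linked pruning function (not defined in this snippet):

```python
import jax.numpy as jnp
import tensorflow as tf

tf_tensor = tf.random.uniform((1000, 1000))

# The approach described above: jnp.array() copies the data (typically via host
# memory), which is likely why it is slow for large tensors.
value = jnp.array(tf_tensor)

# pruned = prune_fn(value)  # prune_fn: the linked pruning function (placeholder)
```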