-
The following code is an example from https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html:

import jax
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
    output = []
    for i in range(1, len(x)-1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

convolve(x, w)
# DeviceArray([11., 20., 29.], dtype=float32)
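(Editor's aside, not in the original post: a minimal sketch continuing from the snippet above.) Jitting this function works, and keeps working for new input values, because JAX treats only shapes as static; compare this with the traced PyTorch version below:

jit_convolve = jax.jit(convolve)
jit_convolve(x, w)
# DeviceArray([11., 20., 29.], dtype=float32)
jit_convolve(jnp.arange(10., 15.), w)
# DeviceArray([101., 110., 119.], dtype=float32)  <- correct: values are never baked in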
If I write the equivalent in PyTorch:

import torch

x = torch.arange(5, dtype=torch.float32)
w = torch.tensor([2., 3., 4.])

def convolve(x, w):
    output = []
    for i in range(1, len(x)-1):
        output.append(torch.dot(x[i-1:i+2], w))
    return torch.tensor(output)

convolve(x, w)
# tensor([11., 20., 29.])

and then trace it:

module = torch.jit.trace(convolve, (x, w))
'''
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:8: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:10: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
# Remove the CWD from sys.path while we load stuff.
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:10: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
# Remove the CWD from sys.path while we load stuff.
'''
module(x, w)
# tensor([11., 20., 29.])

x = torch.arange(10, 15, dtype=torch.float32)
module(x, w)
# tensor([11., 20., 29.])  <- wrong output: the correct result for this x would be
# tensor([101., 110., 119.]), but the trace registered the original outputs as
# constants (see the TracerWarnings above)
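(Editor's sketch, not from the original post.) One way around this in PyTorch is torch.jit.script, which compiles the Python control flow itself rather than recording one concrete execution, so it generalizes to new inputs:

import torch

w = torch.tensor([2., 3., 4.])

@torch.jit.script
def convolve_scripted(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # The loop is compiled rather than unrolled for one fixed input,
    # so x.shape[0] stays dynamic.
    output = []
    for i in range(1, x.shape[0] - 1):
        output.append(torch.dot(x[i-1:i+2], w))
    return torch.stack(output)

convolve_scripted(torch.arange(10, 15, dtype=torch.float32), w)
# tensor([101., 110., 119.])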
The TensorFlow version:

import tensorflow as tf

x = tf.range(5, dtype=tf.float32)
w = tf.Variable([2., 3., 4.])

@tf.function(jit_compile=True)
def convolve(x, w):
    output = []
    for i in range(1, len(x)-1):
        output.append(tf.tensordot(x[i-1:i+2], w, 1))
    return tf.convert_to_tensor(output)

convolve(x, w)
# <tf.Tensor: shape=(3,), dtype=float32, numpy=array([11., 20., 29.], dtype=float32)>

TensorFlow, like JAX, also recommends against Python side-effects, and if I call the function repeatedly with inputs of different shapes I get:

WARNING:tensorflow:6 out of the last 7 calls to <function convolve at 0x7f1221141cb0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.

Does JAX also retrace, similar to TF, when the input shape changes? I am referring to https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#to-jit-or-not-to-jit
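(Editor's sketch, not part of the original question, assuming only default tf.function behavior.) The retracing described in the warning above can be observed directly, since a Python print runs only while tracing:

import tensorflow as tf

@tf.function
def f(x):
    print("tracing for shape", x.shape)  # Python side-effect: runs only during tracing
    return x * 2

f(tf.constant([1., 2.]))      # prints: tracing for shape (2,)
f(tf.constant([3., 4.]))      # same shape and dtype: cached, nothing printed
f(tf.constant([1., 2., 3.]))  # new shape triggers a retrace: prints again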
-
Thanks for the question! JAX's tracing behavior here is perhaps a bit confusing, but essentially it will evaluate and flatten any Python control flow based on static quantities like array shapes. In the case of your convolve function, you can see this by calling jax.make_jaxpr, which prints the jaxpr for the function:

from jax import make_jaxpr
make_jaxpr(convolve)(x, w)

You'll see there is no loop in the resulting jaxpr: the Python for loop has been unrolled into a fixed sequence of dot operations for the traced input shape. Based on this, I think the answers to your other questions may be more clear:
Typically in functional programming, "side-effects" refer to changes to global values accessed outside the function. Here the entire life-cycle of the list output is contained within the function, so appending to it is not a side-effect in that sense.
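(Editor's sketch, not part of the original reply.) A genuine side-effect, by contrast, interacts badly with jit because it executes only during tracing:

import jax
import jax.numpy as jnp

log = []  # global state: a genuine side-effect

@jax.jit
def f(x):
    log.append("called")  # runs only while tracing, not on every call
    return x * 2

f(jnp.arange(3))
f(jnp.arange(3))
print(len(log))  # 1, not 2: the second call reuses the cached trace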
Yes, jitted functions in JAX will be re-traced when faced with inputs of a new shape: this is true regardless of the content of the function.
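(Editor's sketch illustrating this, assuming only standard jax.jit behavior.) A print statement inside a jitted function runs once per trace, so it shows exactly when retracing happens:

import jax
import jax.numpy as jnp

@jax.jit
def g(x):
    print("tracing for shape", x.shape)
    return x * 2

g(jnp.arange(3))  # prints: tracing for shape (3,)
g(jnp.arange(3))  # cached trace reused: nothing printed
g(jnp.arange(5))  # new shape, so re-traced: prints again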
"Static" in this context refers to the staticness of a quantity within a single function call. The shape of An example of a non-static shape would be something like this: @jit
def broken(x):
return jnp.arange(x[0]) This attemps to return an array whose shape depends on the first value of Does that answer your question? |
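(Editor's addendum, a sketch of one common workaround; the function name fixed is hypothetical.) Marking the value the shape depends on as a static argument makes the output shape known at trace time, at the cost of one trace per distinct value:

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=0)
def fixed(n):
    # n is a plain Python value here, so the output shape is known when tracing
    return jnp.arange(n)

fixed(4)
# DeviceArray([0, 1, 2, 3], dtype=int32)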