(This was copied from a StackOverflow question, where it was suggested that this discussion forum might be a better place to ask it.)
TensorFlow
TensorFlow (in graph mode) builds a computation graph, and `tf.gradients` is implemented as an operation on that graph which outputs a new graph for the gradient.
For execution, there are different options: the runtime can execute the graph directly, the graph can be compiled with XLA, it can be converted to TFLite, etc.
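For concreteness, here is a minimal sketch of what I mean by graph-mode differentiation (TF1-style API via `tf.compat.v1`; the function and values are made up purely for illustration):

```python
import tensorflow as tf

# Build a graph explicitly (graph mode), then let tf.gradients add
# gradient ops to that same graph.
graph = tf.Graph()
with graph.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=())
    y = x * x + 3.0 * x                      # forward computation as graph ops
    (dy_dx,) = tf.gradients(ys=y, xs=[x])    # new ops computing dy/dx

with tf.compat.v1.Session(graph=graph) as sess:
    print(sess.run(dy_dx, feed_dict={x: 2.0}))   # d/dx (x^2 + 3x) at x=2 -> 7.0
```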
JAX
JAX operates on functions. `jax.grad` takes a function and returns a function.
For execution, there are different options, but I think the most common is to compile to XLA.
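Again for concreteness, a minimal sketch of the same toy function in JAX, where `jax.grad` and `jax.jit` are just function-to-function transformations (values again made up):

```python
import jax

def f(x):
    return x * x + 3.0 * x

df = jax.grad(f)       # grad: function -> function
df_jit = jax.jit(df)   # composes with jit, which compiles to XLA

print(df(2.0))         # 7.0
print(df_jit(2.0))     # same value, compiled via XLA
```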
Question
So, both of these sound very similar to me. A computation graph is just another way to represent a function. Is there any conceptual difference?
I often see people say that `jax.vmap` is one big advantage of JAX, but you can do much the same in TensorFlow, e.g. with `tf.vectorized_map`.
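A rough side-by-side sketch of what I mean (the per-example function is made up, and I am not claiming the two are equivalent in every detail):

```python
import jax
import jax.numpy as jnp
import tensorflow as tf

def per_example(x):                 # written for a single example
    return jnp.sin(x) * x

batched = jax.vmap(per_example)     # JAX: function -> batched function
print(batched(jnp.arange(4.0)))

# Rough TF counterpart: map a per-example fn over the leading axis
print(tf.vectorized_map(lambda x: tf.sin(x) * x, tf.range(4.0)))
```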
From a very brief reading about how `jax.grad` works internally, it sounds like it is a bit different from `tf.gradients`. But as far as I understand, there is no reason why the algorithm used in JAX could not also be used in TF. That is basically my question here.
Is there any type of algorithm you could only do efficiently in JAX and not in TF, or vice versa?
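As an example of the kind of thing I have in mind: in JAX the transformations compose freely, e.g. a forward-over-reverse Hessian as below. TF has counterparts (nested `tf.GradientTape`s, `tf.hessians`), so this sketch is only meant to illustrate the style of the question, not to claim TF cannot do it:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) ** 2)    # toy scalar-valued function

# grad, jacfwd and jit are all function -> function transformations,
# so they compose directly, e.g. a forward-over-reverse Hessian:
hessian = jax.jit(jax.jacfwd(jax.grad(f)))
print(hessian(jnp.arange(3.0)))        # 3x3 Hessian matrix
```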
Related questions
`jax.grad` vs `tf.gradients`, but this is not really my question here