In a sense, JAX has two tiers of compilation ("op by op" and jit), whereas TF has three ("op by op", tf.function without jit, and tf.function with jit). We have sometimes talked about adding a third tier to JAX between "op by op" and jit, but we have not yet found a use case where someone needs that middle tier.
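To make the two tiers concrete, here is a minimal sketch that runs the same function both ways. The function `f` and its input are illustrative only. Called directly, each jax.numpy operation is dispatched one at a time ("op by op"); wrapped in jax.jit, the whole function is traced once and compiled with XLA as a single unit.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative function: an elementwise sin, a multiply, and a reduction.
    return jnp.sum(jnp.sin(x) * 2.0)

x = jnp.arange(4.0)

# Tier 1: "op by op" (eager) execution; each operation is dispatched separately.
eager_result = f(x)

# Tier 2: jit traces f once and compiles the whole function with XLA.
f_jit = jax.jit(f)
jit_result = f_jit(x)  # first call traces and compiles; later calls with the same shapes/dtypes reuse it

print(eager_result, jit_result)
```

Both calls compute the same value; the difference is in how the work is dispatched and compiled, not in the result.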

If you have a compelling use case that needs an intermediate point between the two, we'd love to hear it. Usually, if you care about speed, the right design point is "use jit".

By the way, even if you don't jit your code, many JAX library functions use jit internally. There's no real notion of an "op" in JAX distinct from jit: we have primitives (the operations mainly in jax.lax) and most …
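One way to see the point about primitives is to trace a function written with jax.numpy and inspect what it is built from. The sketch below uses jax.make_jaxpr on an illustrative function `g`. The exact printed program depends on the JAX version: library functions that use jit internally may appear as nested `pjit` (or `jit`) equations rather than as flat jax.lax primitives, so treat the printed output as indicative rather than fixed.

```python
import jax
import jax.numpy as jnp
from jax import lax

def g(x):
    # Illustrative function written with jax.numpy; these calls lower to jax.lax primitives.
    return jnp.sum(jnp.sin(x) * 2.0)

x = jnp.arange(4.0)

# make_jaxpr traces g and returns the primitive-level program (a ClosedJaxpr).
closed = jax.make_jaxpr(g)(x)
print(closed)

# Names of the primitives at the top level of the traced program
# (internally jitted library functions may show up here as a single nested call).
top_level_prims = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(top_level_prims)

# jax.lax exposes primitives directly; lax.sin and jnp.sin agree on this input.
print(jnp.allclose(lax.sin(x), jnp.sin(x)))
```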

Answer selected by pipme
Category: Q&A