Two things to know: ShapedArray tracers carry only the array's shape and dtype, while ConcreteArray tracers carry the shape and dtype as well as the array's contents (its concrete values).
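Here is a small sketch of the ShapedArray side. It uses only public APIs (jax.jit, jax.eval_shape, jax.ShapeDtypeStruct). ConcreteArray is an internal class, so the sketch does not touch it. The function h and the list seen are hypothetical names used only for illustration. While jax.jit traces h, the tracer exposes its shape and dtype as ordinary Python values. jax.eval_shape can run the same trace from a shape/dtype description alone, without any concrete input.

```python
import jax
import jax.numpy as jnp

# Hypothetical instrumentation: records what the traced input exposes.
seen = []

def h(x):
    # During abstract tracing, shape and dtype are plain Python values,
    # but the array's contents are not available.
    seen.append((x.shape, x.dtype))
    return jnp.sin(x) * 2.0

jax.jit(h)(jnp.ones((2, 3), dtype=jnp.float32))
print(seen)  # [((2, 3), dtype('float32'))]

# No concrete input is needed at all, only a shape/dtype description,
# because that is all abstract tracing looks at.
out = jax.eval_shape(h, jax.ShapeDtypeStruct((2, 3), jnp.float32))
print(out.shape, out.dtype)  # (2, 3) float32
```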

jax.grad uses concrete tracers in order to support value-based control flow during tracing. For example:

```python
import jax

def f(x):
  if x < 0:
    return 0.0
  return x ** 2

print(jax.grad(f)(1.0))
# 2.0
```

If grad did not use concrete tracers in this case, the result of the condition x < 0 would not be knowable at trace time, so we could not differentiate functions that use this kind of control flow.
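One way to see the value being consulted at trace time is to record which Python branch runs on each call. This is a sketch: the branches list is hypothetical instrumentation added to the f from the example above. Because jax.grad is not wrapped in jax.jit here, f is traced again on every call, and each trace takes the branch selected by that call's input value.

```python
import jax

# Hypothetical instrumentation: which Python branch ran during each trace.
branches = []

def f(x):
    if x < 0:  # resolved from x's concrete value while grad traces f
        branches.append("negative")
        return 0.0
    branches.append("non-negative")
    return x ** 2

print(jax.grad(f)(1.0))   # 2.0
print(jax.grad(f)(-2.0))  # 0.0
print(branches)           # ['non-negative', 'negative']
```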

By contrast, jaxprs are created entirely from abstract properties of arrays (namely shape and dtype).
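The contrast can be checked directly with jax.make_jaxpr. The Python if from the earlier example cannot be staged into a jaxpr. A value-independent rewrite with jnp.where produces the same jaxpr for any float32 scalar input. This is a sketch under a few assumptions. f_where is a hypothetical rewrite, not part of the original answer. Catching jax.errors.ConcretizationTypeError assumes a JAX version in which the failed boolean conversion raises that error or a subclass of it.

```python
import jax
import jax.numpy as jnp

def f(x):
    if x < 0:  # needs a concrete value, which abstract tracing lacks
        return 0.0
    return x ** 2

try:
    jax.make_jaxpr(f)(1.0)
    staged = True
except jax.errors.ConcretizationTypeError:
    staged = False
print(staged)  # False

def f_where(x):
    # Hypothetical rewrite: both branches are computed and jnp.where
    # selects elementwise, so tracing never needs a Python bool.
    return jnp.where(x < 0, 0.0, x ** 2)

# Same shape/dtype, different values: identical jaxprs.
jaxpr_pos = jax.make_jaxpr(f_where)(1.0)
jaxpr_neg = jax.make_jaxpr(f_where)(-3.0)
print(str(jaxpr_pos) == str(jaxpr_neg))  # True

# A different shape gives a different jaxpr.
jaxpr_vec = jax.make_jaxpr(f_where)(jnp.zeros(3))
print(str(jaxpr_pos) == str(jaxpr_vec))  # False

print(jax.grad(f_where)(1.0))  # 2.0
```

When computing both sides is undesirable, jax.lax.cond is the structured control-flow alternative to jnp.where.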

Answer selected by riven314