JAX deliberately does not provide any API for this. Code with logic that branches on trace state is an anti-pattern and should be avoided: for example, it might cause autodiff to silently return wrong results if you use different codepaths when tracing and when not.
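To make the autodiff concern concrete, here is a minimal sketch of how it can go wrong. The `isinstance(x, Tracer)` check is only an illustrative stand-in for "branching on trace state", not a supported or recommended API, and `f` is a made-up function:

```python
import jax
from jax.core import Tracer

def f(x):
    # Hypothetical trace-state check, shown only to illustrate the pitfall.
    if isinstance(x, Tracer):
        return x ** 2   # path taken while tracing (e.g. under jax.grad)
    return x ** 3       # path taken when called eagerly

value = f(2.0)             # eager call evaluates x ** 3
slope = jax.grad(f)(2.0)   # grad traces f, so it differentiates x ** 2
print(value, slope)
```

The eager call evaluates the cubic, but `jax.grad` sees a tracer and differentiates the quadratic branch instead, so the reported gradient does not match the function you actually evaluate, and no error is raised.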

If your goal is to fail whenever a function is being traced, one way to do so is to write code that cannot run under tracing; e.g. `jnp.unique(0)` would raise a `ConcretizationTypeError` under tracing. That may not be a very satisfying answer though.
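A minimal sketch of that idea follows. The function name `eager_only` is made up for illustration, and this version calls `jnp.unique` on the function's own argument rather than on a constant, on the assumption that a traced input is the clearest way to trigger the error:

```python
import jax
import jax.numpy as jnp

def eager_only(x):
    # jnp.unique has a data-dependent output shape, so it needs concrete
    # values and cannot be applied to a traced array without a static size.
    jnp.unique(x)
    return x * 2

# Works when called eagerly on a concrete array.
eager_result = eager_only(jnp.arange(3))

# Fails when the same function is traced by jax.jit.
raised = False
try:
    jax.jit(eager_only)(jnp.arange(3))
except jax.errors.ConcretizationTypeError:
    raised = True

print(eager_result, raised)
```

The guard costs an extra `jnp.unique` call on every eager invocation, which is part of why this workaround is not very satisfying.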

Answer selected by cool-RR