
Hi, thanks for the question! This is expected behavior. JAX transformations work by replacing JAX arrays with tracers (you can build more intuition for this in the "How To Think In JAX" guide). It's a really nice abstraction, but unfortunately it means that within transformed functions you can normally only call JAX code, not external code like PyTorch.
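To make the tracer point concrete, here is a minimal sketch. It uses NumPy as a stand-in for "external code" such as PyTorch. The function names `jax_only` and `calls_external` are hypothetical. Inside `jax.jit`, the argument is a tracer rather than a concrete array. JAX operations accept it, but handing it to code that needs real values (here `np.asarray`) raises a `TracerArrayConversionError`.

```python
import jax
import jax.numpy as jnp
import numpy as np

seen_types = []  # records the type of the argument observed while tracing

@jax.jit
def jax_only(x):
    # During tracing, x is a tracer standing in for the array, not concrete data
    seen_types.append(type(x).__name__)
    return jnp.sin(x) * 2.0

@jax.jit
def calls_external(x):
    # np.asarray needs concrete values, which a tracer does not have
    return np.sin(np.asarray(x))

x = jnp.arange(3.0)
ok = jax_only(x)  # works: only JAX operations are used

try:
    calls_external(x)
    external_failed = False
except jax.errors.TracerArrayConversionError:
    external_failed = True
```

Calling PyTorch on `x` inside the jitted function hits the same wall, because PyTorch also needs concrete host or device data.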

There is some new functionality that allows this sort of non-JAX call: jax.pure_callback. Be aware, though, that it requires a host sync, so, for example, interoperating between JAX and PyTorch on a GPU will incur major performance penalties from moving data to and from the host. I'm not sure if there's any example-driven documentation for pure_cal…
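A minimal sketch of the `jax.pure_callback` pattern might look like the following. It assumes NumPy as the "external" library, and `host_sin` is a hypothetical helper name. The callback runs on the host with concrete NumPy values. `jax.ShapeDtypeStruct` tells JAX the shape and dtype of the result in advance, so the surrounding function can still be traced and jitted.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
    # Runs on the host with concrete values; stands in for non-JAX code
    # (e.g. a PyTorch call). The result must match the declared shape/dtype.
    x = np.asarray(x)
    return np.sin(x).astype(x.dtype)

@jax.jit
def f(x):
    # Declare the callback's output shape and dtype so JAX can trace f
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    y = jax.pure_callback(host_sin, out_spec, x)
    return y * 2.0  # ordinary JAX code can use the callback's result

x = jnp.arange(3.0)
result = f(x)
```

In the PyTorch case, the callback body would convert its NumPy input to a tensor, call PyTorch, and convert the result back. That round trip through host memory is where the sync cost mentioned above comes from.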

@tvercaut
@jakevdp
@tvercaut
Answer selected by tvercaut
Category
Q&A
2 participants