To directly answer your question: yes, jit-compiled functions will generally allocate new memory for their outputs, even when the function happens to be an identity. This is because JAX arrays are immutable and cannot share memory with other arrays, and the jit-of-identity case is not important enough to warrant an exception to the normal XLA computation path.
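
One way to observe this yourself is to compare the device buffer addresses of the input and the output of a jitted identity. This is only a minimal sketch: the names `identity`, `x`, and `y` are illustrative, it assumes a JAX version where `jax.Array.unsafe_buffer_pointer()` is available, and whether the two addresses actually differ may depend on your JAX version and backend.

```python
import jax
import jax.numpy as jnp

@jax.jit
def identity(x):
    # Returns its input unchanged.
    return x

x = jnp.arange(4.0)
y = identity(x)

# unsafe_buffer_pointer() returns the address of the underlying
# device buffer as a Python int (single-device arrays only).
x_ptr = x.unsafe_buffer_pointer()
y_ptr = y.unsafe_buffer_pointer()

print("input buffer :", hex(x_ptr))
print("output buffer:", hex(y_ptr))
print("same buffer  :", x_ptr == y_ptr)
```

If the last line prints `False`, the output was written to a freshly allocated buffer rather than aliasing the input.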

There are generally two ways around this. First, you could wrap your function in an outer jit that repeatedly calls the identity; XLA will then optimize away the repeated identity calls and the copies they would generate (you can see this in your example). Second, on GPU or TPU (not CPU) you could use buffer do…
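
As a rough illustration of the first approach only, the sketch below calls a jitted identity many times inside an outer jit. The function names and the loop count of 100 are hypothetical choices for the example, not taken from the original question. Because the inner calls are traced into a single XLA computation, the compiler is free to simplify them instead of materializing a copy per call.

```python
import jax
import jax.numpy as jnp

@jax.jit
def identity(x):
    # Inner jitted identity; called on its own, each call goes
    # through the normal XLA computation path.
    return x

@jax.jit
def repeated_identity(x):
    # Under the outer jit, the inner calls are traced into one
    # computation that XLA compiles and optimizes as a whole.
    for _ in range(100):
        x = identity(x)
    return x

x = jnp.ones((1000,))
out = repeated_identity(x)
out.block_until_ready()  # wait for the asynchronous dispatch to finish
```

Timing `repeated_identity(x)` against a plain Python loop of 100 separate `identity(x)` calls is one way to check whether the copies are being elided on your setup.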

Answer selected by simon-bachhuber