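This is supported via "buffer donation". Donating an argument tells XLA that the caller no longer needs that input, so its memory can be reused for an output of matching shape and dtype instead of allocating a new buffer. Change your update function to the following: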

```python
import jax

def update(x):
    return x + 1

# Donate argument 0 (x) so XLA may write the result into x's buffer.
update = jax.jit(update, donate_argnums=0)
```
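A minimal usage sketch, assuming a recent JAX with `jax.Array` (the array sizes and the warning filter are illustrative choices, not part of the original answer). The key rule with donation is to rebind the name to the result and never touch the donated input again. Backends that cannot reuse a donated buffer (historically CPU) emit a warning and fall back to a copy, so whether the old array ends up deleted is backend-dependent.

```python
import warnings

import jax
import jax.numpy as jnp


def update(x):
    return x + 1


# Donate argument 0 so XLA may reuse x's memory for the output.
update = jax.jit(update, donate_argnums=0)

# Where donation is unsupported JAX warns instead of failing; silence that
# warning so this sketch runs on any backend.
warnings.filterwarnings("ignore", message="Some donated buffers were not usable")

x = jnp.zeros((1024,), dtype=jnp.float32)
for _ in range(3):
    # Rebind x to the result: the previous array may have been donated
    # and must not be read again.
    x = update(x)

old = jnp.zeros((4,), dtype=jnp.float32)
new = update(old)
# True if the backend actually reused old's buffer, False if it had to copy.
print(old.is_deleted())
```

The `x = update(x)` pattern is what makes donation safe in practice: the only live reference always points at the newest buffer.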

Answer selected by imoneoi
Category: Q&A
Participants: @tomhennigan, @imoneoi