Hi, thanks for the question! Unlike jit, make_jaxpr has no static_argnames option for marking keyword arguments as static, so the most straightforward approach is to close over the static arguments, for example with functools.partial:

from functools import partial
from jax import make_jaxpr
make_jaxpr(partial(f, y=y))(x)
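A minimal, self-contained sketch of the partial approach. Since f, x and y are not defined in the question, it assumes a hypothetical f whose keyword argument y is a Python bool used in ordinary if/else control flow, which is exactly the kind of argument that must stay static. It also shows, under the same assumption, what happens if y is passed straight to make_jaxpr instead.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in for f: `y` is a Python bool that picks a branch
# with plain Python control flow, so its value must be known at trace time.
def f(x, y=False):
    if y:
        return jnp.sin(x)
    return jnp.cos(x)


x = jnp.ones(3)

# Closing over y with partial bakes its value into the traced function;
# only x becomes an input variable of the resulting jaxpr.
jaxpr_sin = jax.make_jaxpr(partial(f, y=True))(x)
jaxpr_cos = jax.make_jaxpr(partial(f, y=False))(x)
print(jaxpr_sin)
print(jaxpr_cos)

# Without partial, make_jaxpr abstracts every argument, so y becomes a
# tracer and `if y:` cannot be evaluated to a concrete Python bool.
tracing_failed = False
try:
    jax.make_jaxpr(f)(x, True)
except jax.errors.ConcretizationTypeError as err:
    tracing_failed = True
    print("tracing failed:", type(err).__name__)
```

Each value of y produces a separate jaxpr, which mirrors how jit retraces once per distinct static argument value.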

Answer selected by trancenoid
Category
Q&A
2 participants