The first time selu_jit() is called, it is traced and compiled using abstract values that represent the input arrays (their shape and dtype), not the input arrays themselves. Once compilation is finished, the actual arrays are passed to the compiled function for execution.
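A minimal sketch of this behavior follows. It assumes selu is defined roughly as in the JAX quickstart (the alpha and lmbda defaults here are illustrative) and that selu_jit = jax.jit(selu). The trace_log list is a hypothetical helper added only for this demo. Python side effects such as appending to a list run only while the function is being traced, so they reveal what the function actually receives: a tracer carrying the input's shape and dtype, not the concrete array.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper for this demo: records what selu receives while tracing.
trace_log = []

def selu(x, alpha=1.67, lmbda=1.05):
    # This Python side effect runs only during tracing, not when the
    # compiled function is executed later.
    trace_log.append(x)
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)

x = jnp.arange(5.0)

# First call: selu is traced with an abstract value (a tracer with shape (5,)
# and dtype float32), compiled, and then the compiled code runs on x.
y1 = selu_jit(x)

# Second call with the same shape and dtype: the cached compiled function is
# reused, so selu is not traced again and trace_log does not grow.
y2 = selu_jit(x)

print(len(trace_log))             # 1
print(type(trace_log[0]).__name__)  # a Tracer subclass, not a concrete array
print(trace_log[0].shape)         # (5,)
```

Calling selu_jit with an input of a different shape or dtype would trigger a fresh trace and compilation for that new abstract signature.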

As a side note, very early versions of JAX did this differently: the first call ran unoptimized, and only subsequent calls were compiled. JAX has not worked that way for a long time.

Answer selected by GoktugGuvercin