I'm not sure I understand your question. Is it possible that you meant for jitted_mult to be defined like this?

    import jax

    @jax.jit
    def jitted_mult(input_var):
        # weights is read from the enclosing (global) scope, not passed in
        return weights.dot(input_var)

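If that's the case, the function isn't respecting the update because it is not pure (see JAX Sharp Bits: Pure Functions). Specifically, its output depends on an input, weights, that is not explicitly passed to the function. When jit traces the function on the first call, it captures the current value of weights as a constant in the compiled computation. Later calls with inputs of the same shape and dtype reuse that cached computation, so reassigning weights has no effect. This violates the assumptions made by jit and other JAX transformations, which leads to unexpected behavior.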
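
To make the failure concrete, here is a minimal sketch (not from the original thread) that assumes weights is a small JAX array defined at module level; the values jnp.eye(2) and the input x are hypothetical. The first call traces the function and bakes in the identity matrix, so rebinding the global afterwards does not change the result of later calls with the same input shape and dtype.

```python
import jax
import jax.numpy as jnp

# Hypothetical global weights, read implicitly by the function below
weights = jnp.eye(2)

@jax.jit
def jitted_mult(input_var):
    # weights is not an argument, so jit captures its value at trace time
    return weights.dot(input_var)

x = jnp.array([1.0, 2.0])

first = jitted_mult(x)     # traces and compiles with weights == identity
weights = 2 * jnp.eye(2)   # rebind the global to new values
second = jitted_mult(x)    # same shape/dtype: cached trace reused, old weights used
expected = weights.dot(x)  # eager (non-jitted) result using the new weights

# first and second both hold [1., 2.]; expected holds [2., 4.]
```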

To fix it, I would make this implicit input explicit so that your function is pure. It might look something like this:

    def wra…
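
One way to make the dependency explicit is sketched below. It is not the truncated snippet above; it assumes weights is a JAX array, and the two-argument jitted_mult signature is hypothetical. Because weights is now an ordinary argument, jit treats it as a traced input rather than a constant, so updated values are used on every call without recompiling, as long as the shape and dtype stay the same.

```python
import jax
import jax.numpy as jnp

@jax.jit
def jitted_mult(weights, input_var):
    # weights is now an explicit input, so the function is pure:
    # its output depends only on its arguments
    return weights.dot(input_var)

x = jnp.array([1.0, 2.0])

w = jnp.eye(2)
first = jitted_mult(w, x)   # traces once for (2, 2) and (2,) float32 inputs

w = 2 * jnp.eye(2)          # "update" the weights
second = jitted_mult(w, x)  # cache hit, but the new values flow in as an argument

# first holds [1., 2.]; second holds [2., 4.]
```

A side benefit of this design is that weights can now be differentiated with jax.grad, which only differentiates with respect to explicit arguments.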

Replies: 2 comments, 3 replies

@IanQS
@IanQS
@jakevdp
Answer selected by IanQS
Category: Q&A
Labels: None yet
3 participants