The issue here is that your function is not pure: setting attributes of self without returning it is a side effect, and JAX transformations are not compatible with impure functions that rely on such side effects (see "Pure functions" in JAX - The Sharp Bits).
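To illustrate the point about side effects, here is a minimal sketch (not code from the original question). The `Counter` class and the `scale` function are hypothetical names chosen for the example. It shows that Python-level state changes inside a jitted function happen only while JAX traces the function, not on every call:

```python
import jax
import jax.numpy as jnp


class Counter:
    """Hypothetical holder for some Python-side state."""

    def __init__(self):
        self.calls = 0


counter = Counter()


@jax.jit
def scale(x):
    # Side effect: mutates an object outside the function.
    # This line runs during tracing only, not in the compiled computation.
    counter.calls += 1
    return 2.0 * x


first = scale(jnp.ones(3))   # traces and compiles; the counter is incremented
second = scale(jnp.ones(3))  # same shape and dtype: the cached compilation is reused

print(counter.calls)  # 1, even though scale was called twice
```

The returned arrays are correct on both calls, but the mutation is silently dropped on the second one, which is why an update has to be returned as an output instead.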

If you want to update your object in a JIT-compiled function, you need to return the updated object. The Equinox example given by @ToshiyukiBandai is one way to do this, but you don't need to add a dependency on Equinox to make your function pure; you could do it with JAX alone like this:

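    from jax import jit

    # jit accepts `self` as an argument only if its class is registered as a pytree.
    @jit
    def some_method(self, diff_param):
        a = self.a
        self.a = diff_param * a
        # Return the updated object so the change is an output, not a side effect.
        return self, self.a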

Then you can run your code lik…
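For `@jit` to accept `self`, the class has to be something JAX can flatten into arrays, for example a registered pytree. The following is a minimal sketch under that assumption, not necessarily the code the answer went on to show. The class name `Model`, the single attribute `a`, and the input values are hypothetical; `register_pytree_node_class` is one of several ways to register a class.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Model:  # hypothetical class name
    def __init__(self, a):
        self.a = a

    def tree_flatten(self):
        # Children are the array leaves; there is no static auxiliary data here.
        return (self.a,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    @jax.jit
    def some_method(self, diff_param):
        # Build and return a new object instead of mutating self in place.
        new_a = diff_param * self.a
        return Model(new_a), new_a


old_model = Model(jnp.array(2.0))
new_model, a = old_model.some_method(3.0)

print(new_model.a)  # 6.0
print(old_model.a)  # 2.0, the original object is unchanged
```

The caller rebinds the name to the returned object (`model, a = model.some_method(...)`) to carry the update forward. This variant constructs a new `Model` rather than assigning to `self.a`, which keeps the method free of mutation altogether.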

Answer selected by Jinghong-Zhang