jax.jit for in-place mutation of class attributes #21611
In the JAX documentation on using `jax.jit` with class methods ("How to use jit with methods"), there is a line: "If your class relies on in-place mutations (such as setting self.attr = ... within its methods), then your object is not really “static” and marking it as such may lead to problems. Fortunately, there’s another option for this case." I suppose the "another option" refers to the pytree registration strategy. So I wrote my code like this:

```python
from jax import jit
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class CustomClass:
    def __init__(self, a, b):
        self.a = a
        self.b = b

    # @jit  # showing difference between jitting and not jitting
    def some_method(self, diff_param):
        a = self.a
        self.a = diff_param * a
        return self.a

    def tree_flatten(self):
        children = (self.a, self.b)  # dynamic values
        aux_data = {}
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)

a = jnp.array([1.0, 2.0])
b = jnp.array([1.0, 2.0])
diff_param = 4.0
obj = CustomClass(a, b)
a = obj.some_method(diff_param)
print(obj.a)
```

If the class method is jitted, the result will be …
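To see the difference the question describes side by side, here is a minimal sketch, assuming the same pytree registration as above. The two method names `scale_eager` and `scale_jitted` are hypothetical, introduced only for this comparison. The point it illustrates: when a jitted method receives a registered pytree, JAX flattens the object into its leaves and rebuilds a new instance via `tree_unflatten` for tracing, so assigning to `self.a` inside the jitted method changes only that temporary copy, never the caller's object.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class CustomClass:
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def scale_eager(self, diff_param):
        # Hypothetical name. Plain Python method: mutates this very object.
        self.a = diff_param * self.a
        return self.a

    @jax.jit
    def scale_jitted(self, diff_param):
        # Hypothetical name. Under jit, `self` is a new object rebuilt by
        # tree_unflatten from traced leaves, so this assignment only
        # touches that temporary copy, not the caller's instance.
        self.a = diff_param * self.a
        return self.a

    def tree_flatten(self):
        return (self.a, self.b), {}

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)


eager_obj = CustomClass(jnp.array([1.0, 2.0]), jnp.array([1.0, 2.0]))
eager_obj.scale_eager(4.0)
print(eager_obj.a)  # the caller's object was updated in place

jit_obj = CustomClass(jnp.array([1.0, 2.0]), jnp.array([1.0, 2.0]))
out = jit_obj.scale_jitted(4.0)
print(out)          # the returned value reflects the scaling
print(jit_obj.a)    # the caller's object keeps its original values
```

Only values that leave the jitted function as outputs reach the caller; attribute assignments made during tracing do not.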
It seems like Equinox would be helpful (https://docs.kidger.site/equinox/). Let me know if this fits your use case.
The issue here is that your function is not pure: setting attributes of `self` without returning it is a side effect, and JAX transformations are not compatible with impure functions that rely on such side effects (see JAX sharp bits: pure functions).

If you want to update your object in a JIT-compiled function, you need to return the updated object. The Equinox example given by @ToshiyukiBandai is one way to do this, but you don't need to add a dependency on Equinox to make your function pure; you could do it with JAX alone like this:

```python
# defined inside CustomClass
@jit
def some_method(self, diff_param):
    a = self.a
    self.a = diff_param * a
    return self, self.a
```

Then you can run your code like this:

```python
obj, a = obj.some_method(diff_param)
```

And the returned object will have the updated parameters.
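The reply's version still assigns to `self.a` inside the jitted method and then returns `self`; that works because the modified copy leaves `jit` as an output. A variant that makes the purity explicit is to construct and return a new instance, leaving `self` untouched. This is a minimal sketch that mirrors the question's class; building the new object with `type(self)(...)` is a stylistic choice of this example, not something the thread prescribes.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class CustomClass:
    def __init__(self, a, b):
        self.a = a
        self.b = b

    @jax.jit
    def some_method(self, diff_param):
        # Build a new instance instead of mutating `self`; because the class
        # is a registered pytree, it can be returned as a jit output.
        new_obj = type(self)(diff_param * self.a, self.b)
        return new_obj, new_obj.a

    def tree_flatten(self):
        children = (self.a, self.b)  # dynamic values (traced by jit)
        aux_data = {}                # static metadata (none here)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)


obj = CustomClass(jnp.array([1.0, 2.0]), jnp.array([1.0, 2.0]))
obj, a = obj.some_method(4.0)  # rebind `obj` to the updated object
print(obj.a)
print(a)
```

The caller has to rebind `obj` to the returned object in both versions; the difference is only that this one never writes to the traced copy of `self`.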