Hi all.

# A class
class c1():
    def __init__(self, x):
        self.x = x

# A function
def f(x):
    y = x**2
    obj = c1(x)  # an object
    return y, obj

And I wish to calculate

This will fail as
Answered by jakevdp on Jun 19, 2022.
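In order to do this, you need to register your class c1 as a pytree. You can find more information here: https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees

For example:

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node

# A class
class c1():
    def __init__(self, x):
        self.x = x

def c1_flatten(obj):
    # The fields JAX should treat as leaves and trace through.
    children = (obj.x,)
    # Static metadata needed to rebuild the object; c1 has none.
    aux_data = None
    return children, aux_data

def c1_unflatten(aux_data, children):
    # Rebuild a c1 instance from its (possibly transformed) leaves.
    return c1(*children)

register_pytree_node(c1, c1_flatten, c1_unflatten)

# A function
def f(x):
    y = x**2
    obj = c1(x)  # an object
    return y, obj

jax.grad(f, has_aux=True)(0.0)
# (DeviceArray(0., dtype=float32), <__main__.c1 at 0x7fcdf7bd1fd0>)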
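The same registration can, as far as I know, also be written with the `jax.tree_util.register_pytree_node_class` decorator, where the class itself provides a `tree_flatten` method and a `tree_unflatten` classmethod. The sketch below assumes that decorator form and reuses the thread's `c1` and `f`; it evaluates the gradient at 3.0 (an arbitrary illustrative point) and then does a flatten/unflatten round trip to show that `x` becomes the single leaf.

```python
import jax
from jax.tree_util import register_pytree_node_class, tree_flatten, tree_unflatten


@register_pytree_node_class
class c1:
    def __init__(self, x):
        self.x = x

    def tree_flatten(self):
        # Leaves that JAX traces through, plus static aux data (none here).
        children = (self.x,)
        aux_data = None
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild an instance from its leaves.
        return cls(*children)


def f(x):
    y = x**2
    obj = c1(x)  # an object, now a valid pytree
    return y, obj


# grad differentiates the first output; the c1 instance comes back as aux data.
grad_val, aux = jax.grad(f, has_aux=True)(3.0)
print(grad_val)  # derivative of x**2 at 3.0, i.e. 6.0
print(type(aux).__name__, aux.x)

# Round trip: flattening exposes x as the only leaf, unflattening rebuilds a c1.
leaves, treedef = tree_flatten(c1(1.5))
rebuilt = tree_unflatten(treedef, leaves)
print(leaves, rebuilt.x)
```

Which form to use is mostly a matter of taste: the decorator keeps the flatten/unflatten logic next to the class, while `register_pytree_node` works for classes you cannot modify.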
Answer selected by GaoyuanWu.