
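To return an instance of your class c1 from a function differentiated with jax.grad(..., has_aux=True), you need to register c1 as a pytree. That tells JAX how to flatten the object into its array leaves and rebuild it afterwards. You can find more information here: https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees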

For example:

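```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node

# A plain Python class holding a (possibly traced) array value
class c1():
  def __init__(self, x):
    self.x = x

# Flatten: return the array leaves (children) and static metadata (aux_data)
def c1_flatten(obj):
  aux_data = None
  children = (obj.x,)
  return children, aux_data

# Unflatten: rebuild a c1 instance from aux_data and the leaves
def c1_unflatten(aux_data, children):
  return c1(*children)

# Tell JAX how to traverse c1 instances
register_pytree_node(c1, c1_flatten, c1_unflatten)

# A function returning the value to differentiate plus a c1 object as auxiliary output
def f(x):
  y = x**2
  obj = c1(x)  # an object
  return y, obj

jax.grad(f, has_aux=True)(0.0)
# Output from an older JAX version (newer versions print Array(...) instead of DeviceArray):
# (DeviceArray(0., dtype=float32), <__main__.c1 at 0x7fcdf…
```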
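If you prefer keeping the flatten/unflatten logic on the class itself, here is a sketch of the same idea using the jax.tree_util.register_pytree_node_class decorator, which expects a tree_flatten method and a tree_unflatten classmethod. The label field is a hypothetical addition, included only to show that non-array metadata belongs in aux_data rather than in children:

```python
import jax
from jax.tree_util import register_pytree_node_class, tree_flatten

@register_pytree_node_class
class c1:
    def __init__(self, x, label="c1"):
        self.x = x          # array-valued field: traced by JAX
        self.label = label  # hypothetical static field: stored in aux_data

    def tree_flatten(self):
        children = (self.x,)   # leaves that JAX traces and differentiates
        aux_data = self.label  # static metadata; should be hashable
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the object from the static metadata and the leaves
        return cls(*children, label=aux_data)

def f(x):
    y = x ** 2
    return y, c1(x, label="input")

# grad_y is d(x**2)/dx at 3.0; obj is the auxiliary c1 instance
grad_y, obj = jax.grad(f, has_aux=True)(3.0)
print(grad_y)
print(obj.x, obj.label)

# Only the array-valued field shows up as a leaf; label travels in the treedef
leaves, treedef = tree_flatten(c1(1.0, label="demo"))
print(leaves)
```

Here grad_y should be 6.0 and obj.x should be 3.0, while obj.label comes back unchanged because it never passes through tracing.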

Answer selected by GaoyuanWu