I am experiencing an issue with taking the gradient of a TensorFlow-converted function using `tensorflow.GradientTape` when one of the function arguments is a custom `pytree_node_class`. There's a mismatch between the number of outputs and the number of generated gradients, presumably because the custom pytree node class is flattened into multiple parameters.
Below is a minimal(-ish) reproducible example:
```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Special(tf.Module):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        return ((self.x, self.y), None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def do_thing(self, a: jnp.ndarray):
        return self.x * a + self.y


@jax2tf.convert
def my_function(my_obj: Special, a: jnp.ndarray):
    return jnp.sum(my_obj.do_thing(a))


my_obj = Special(tf.Variable(2.0), tf.Variable(3.0))

with tf.GradientTape() as tape:
    # This next line raises a ValueError
    result = my_function(my_obj, tf.range(10))

grads = tape.gradient(result, [my_obj.trainable_variables])
```
Result:

```
Traceback (most recent call last):
  File ".../minimal_example.py", line 34, in <module>
    grads = tape.gradient(result, [my_obj.trainable_variables])
  File ".../python3.8/site-packages/tensorflow/python/eager/backprop.py", line 1084, in gradient
    flat_grad = imperative_grad.imperative_grad(
  File ".../python3.8/site-packages/tensorflow/python/eager/imperative_grad.py", line 71, in imperative_grad
    return pywrap_tfe.TFE_Py_TapeGradient(
  File ".../python3.8/site-packages/tensorflow/python/ops/custom_gradient.py", line 572, in actual_grad_fn
    raise ValueError(
ValueError: ('custom_gradient function expected to return', 3, 'gradients but returned', 2, 'instead.')
```
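For context on where the 3-vs-2 count presumably comes from: the `Special` instance contributes two pytree leaves (`x` and `y`), so the converted function effectively receives three flat inputs (`x`, `y`, and `a`) rather than two. The flattening can be inspected directly in plain JAX, without TensorFlow (a small sketch using the same registration as above, minus `tf.Module`):

```python
import jax
from jax.tree_util import register_pytree_node_class


# Same pytree registration as in the repro, but with plain floats,
# just to inspect how JAX flattens the object.
@register_pytree_node_class
class Special:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        return ((self.x, self.y), None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


leaves, treedef = jax.tree_util.tree_flatten(Special(2.0, 3.0))
print(leaves)  # the single object becomes two separate leaf arguments
```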
I really like the ability to register custom classes with JAX; it improves code readability and maintainability. However, I am working on a project that requires interop with TensorFlow.

Is there a way to make this work while keeping `my_function`'s call signature?