Disregard custom pytree nodes in tree_util functions? #20729
-
I have a custom pytree node with a static auxiliary field. Now I'm running into the issue that I can't get a tree structure of the class anymore without baking the static argument into it. This leads to wrong results when unflattening one object with the tree_def of another. (I need to flatten/unflatten because I'm feeding the object into a TensorFlow data-loading function that only allows a tuple of arrays as output. Because of that restriction I can't pass the actual object through.)

Example:

```python
import jax
from typing import NamedTuple

# Class def
class MyNode(NamedTuple):
    value: jax.numpy.ndarray
    static_size: int

def mynode_flatten(node):
    return [node.value], [node.static_size]

def mynode_unflatten(aux, xs):
    return MyNode(xs[0], aux[0])

jax.tree_util.register_pytree_node(MyNode, mynode_flatten, mynode_unflatten)

# Store MyNode structure for unflattening later
inputs = MyNode(jax.numpy.zeros((1,)), 1)
global_def = jax.tree_util.tree_structure(inputs)
print(global_def)

tf_func = lambda x: x  # TensorFlow function that may only return a tuple of tensors

def tf_pipeline(inp):
    res = tf_func(lambda x: jax.tree_util.tree_flatten(x)[0])(inp)
    out = jax.tree_util.tree_unflatten(global_def, res)
    return out

res = tf_pipeline(inputs)
assert res == inputs  # ok

new_inputs = MyNode(jax.numpy.zeros((2,)), 2)
new_res = tf_pipeline(new_inputs)  # <- uses static_size from global_def
assert new_res == new_inputs, new_res  # not ok
```

Output:

```
AssertionError: MyNode(value=Array([0., 0.], dtype=float32), static_size=1)
```

What I would like is a way to flatten/unflatten without the static argument being baked into the treedef.
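The root cause can be checked directly (a runnable sketch, assuming `jax` is installed): the aux data returned by the flatten rule becomes part of the treedef itself, so two instances that differ only in `static_size` have different structures, and one treedef cannot faithfully unflatten the other's leaves.

```python
import jax
from typing import NamedTuple

class MyNode(NamedTuple):
    value: jax.numpy.ndarray
    static_size: int

def mynode_flatten(node):
    return [node.value], [node.static_size]

def mynode_unflatten(aux, xs):
    return MyNode(xs[0], aux[0])

jax.tree_util.register_pytree_node(MyNode, mynode_flatten, mynode_unflatten)

# The aux data ends up inside the treedef, not in the leaves:
def1 = jax.tree_util.tree_structure(MyNode(jax.numpy.zeros((1,)), 1))
def2 = jax.tree_util.tree_structure(MyNode(jax.numpy.zeros((2,)), 2))
assert def1 != def2  # structures differ because static_size differs
```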
-
Why not use a local treedef rather than a global treedef?

```python
def tf_pipeline(inp):
    treedef = None
    def func(x):
        nonlocal treedef
        x_flat, treedef = jax.tree_util.tree_flatten(x)
        return x_flat
    res = tf_func(func)(inp)
    out = jax.tree_util.tree_unflatten(treedef, res)
    return out
```

When I use this version of …
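For reference, a self-contained check of this local-treedef pattern (with `tf_func` stubbed as an identity wrapper, as in the question's example, and the same `MyNode` registration) shows it round-trips both inputs, since each call unflattens with the treedef captured from its own input:

```python
import jax
from typing import NamedTuple

class MyNode(NamedTuple):
    value: jax.numpy.ndarray
    static_size: int

jax.tree_util.register_pytree_node(
    MyNode,
    lambda n: ([n.value], [n.static_size]),
    lambda aux, xs: MyNode(xs[0], aux[0]),
)

tf_func = lambda f: f  # stand-in for the TensorFlow function

def tf_pipeline(inp):
    treedef = None
    def func(x):
        nonlocal treedef
        x_flat, treedef = jax.tree_util.tree_flatten(x)
        return x_flat
    res = tf_func(func)(inp)
    return jax.tree_util.tree_unflatten(treedef, res)

a = MyNode(jax.numpy.zeros((1,)), 1)
b = MyNode(jax.numpy.zeros((2,)), 2)
out_b = tf_pipeline(b)
assert out_b.static_size == 2   # taken from b's own treedef, not a stale global one
assert out_b.value is b.value   # the leaf array round-trips unchanged
```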
You can do this using the …
-
You could do it in a bit of a workaround, like:

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

# Class def
class MyNode(NamedTuple):
    value: jnp.ndarray
    static_size: int

def mynode_flatten(node):
    return [node.value], [node.static_size]

def mynode_unflatten(aux, xs):
    return MyNode(xs[0], aux[0])

jax.tree_util.register_pytree_node(MyNode, mynode_flatten, mynode_unflatten)

# Wrapper function
def tf_pipeline_wrapper(inp):
    flattened_input = jax.tree_util.tree_flatten(inp)[0]
    tf_output = tf_func(flattened_input)  # Assuming tf_func is defined elsewhere
    return jax.tree_util.tree_unflatten(global_def, tf_output)

# Usage
inputs = MyNode(jax.numpy.zeros((1,)), 1)
global_def = jax.tree_util.tree_structure(inputs)
print(global_def)

res = tf_pipeline_wrapper(inputs)
assert res == inputs  # ok

new_inputs = MyNode(jax.numpy.zeros((2,)), 2)
new_res = tf_pipeline_wrapper(new_inputs)
assert new_res == new_inputs  # ok
```

This wrapper function `tf_pipeline_wrapper` takes an input of type `MyNode`, flattens it, passes the flattened structure to TensorFlow's function (`tf_func`), and then unflattens the result using the global tree structure. This way, it bypasses the static size issue you have.
Oh I see, I missed that you were overriding the default flattening rule. To answer your question: no, I don’t think there’s any way to have context-dependent flattening rules like what you have in mind.
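One detail worth noting (a fact about JAX's defaults, not something stated in the thread): a `NamedTuple` that is *not* registered with a custom rule is already handled by JAX's built-in namedtuple flattening, which treats every field as a child. The static field then travels as a leaf rather than being baked into the treedef, so the structure no longer depends on its value:

```python
import jax
from typing import NamedTuple

# NOT registered with a custom rule: JAX's default namedtuple handling applies
class PlainNode(NamedTuple):
    value: jax.numpy.ndarray
    static_size: int

leaves, treedef = jax.tree_util.tree_flatten(PlainNode(jax.numpy.zeros((2,)), 2))
assert len(leaves) == 2  # both fields are leaves, including the int

other_def = jax.tree_util.tree_structure(PlainNode(jax.numpy.zeros((1,)), 1))
assert treedef == other_def  # structure is independent of the field values
```

The trade-off is that `static_size` would then have to pass through the TensorFlow pipeline as a tensor, which may or may not be acceptable in your setup.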