Custom PyTree Tree Map Error #19449
I have been trying to create a PyTree superclass that recognizes all JAX array attributes and registers them as leaves for JAX transformations. The end goal is a neural network `Module` class, like this:

```python
import jax
import tree
from jax import tree_util
from collections import defaultdict


class Placeholder:
    pass


@tree_util.register_pytree_node_class
class Module:
    _set_name = defaultdict(int)

    def __init__(self, name="Module") -> None:
        # Name handling
        self.name = f"{name}_{Module._set_name[name]}"
        Module._set_name[name] += 1
        self._no_grad: set = set()  # Handles the frozen trainable_vars.

    def __init_subclass__(cls) -> None:
        tree_util.register_pytree_node_class(cls)

    def freeze(self, key):
        self._no_grad.add(key)

    def unfreeze(self, key):
        if key in self._no_grad:
            self._no_grad.remove(key)

    def tree_flatten(self):
        # Copying attribute dict for manipulation
        instance_vars = vars(self).copy()
        frozen_vars: dict = {}
        # Removing frozen nodes:
        for key in self._no_grad:
            frozen_vars[key] = instance_vars.pop(key)
        leaves = []

        def search(x):
            if isinstance(x, jax.Array) or isinstance(x, Module):
                leaves.append(x)
                return Placeholder()
            else:
                return x

        aux_data = tree.map_structure(search, instance_vars)
        aux_data.update(frozen_vars)  # Re-adding frozen nodes to aux_data
        return leaves, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, leaves):
        instance = cls.__new__(cls)  # Creating a new instance
        vars(instance).update(aux_data)  # Momentarily updating the instance. Probably will remove as unnecessary
        # Replacing the Placeholders in aux_data with the updated arrays.
        aux_nodes = tree.flatten(aux_data)
        leaves_new = list(leaves).copy()
        for index, val in enumerate(aux_nodes):
            if isinstance(val, Placeholder) and len(leaves_new):
                aux_nodes[index] = leaves_new[0]
                leaves_new.pop(0)
        # Recreating instance attribute dict and updating:
        instance_dict = tree.unflatten_as(aux_data, aux_nodes)
        vars(instance).update(instance_dict)
        return instance

    def __repr__(self) -> str:
        return str(vars(self))
```

I am using the `Placeholder` class to mark where to put my arrays back in during the unflattening procedure. This implementation works when I create a class and compute its gradient, like this:

```python
from typing import Any

class Dense(Module):
    def __init__(self, name="Dense") -> None:
        super().__init__(name)
        self.a = jax.numpy.array([1., 2., 3.])
        self.b = [self.a, 'a', jax.numpy.array([1., 2., 2.])]
        self.kernel = jax.numpy.array([[2.]])
        self.bias = jax.numpy.array([1.])

    def __call__(self, x) -> Any:
        return x @ self.kernel + self.bias


class Sequential(Module):
    def __init__(self, name="Sequential") -> None:
        super().__init__(name)
        self.a = jax.numpy.array([1., 2., 3.])
        self.b = [self.a, 'a', jax.numpy.array([1., 2., 2.])]
        self.c = Dense()

    def __call__(self, x) -> Any:
        return self.c(x)


instance = Sequential()
instance.freeze('c')  # Testing _no_grad

# Computing gradient
x = jax.numpy.array([[1.]])

def fun(model):
    return model(x).sum()

grad = jax.grad(fun)(instance)
print(grad)
```

However, it returns an error when I try to optimize my params with the grads:

```python
params = jax.tree_map(lambda x, y: x + y, instance, grad)
```

```
ValueError: Mismatch custom node data:
{'name': 'Sequential_1', '_no_grad': {'c'}, 'a': <__main__.Placeholder object at 0x7f728010b410>, 'b': [<__main__.Placeholder object at 0x7f725c52b4d0>, 'a', <__main__.Placeholder object at 0x7f725c50ee10>], 'c': {'name': 'Dense_1', '_no_grad': set(), 'a': Array([1., 2., 3.], dtype=float32), 'b': [Array([1., 2., 3.], dtype=float32), 'a', Array([1., 2., 2.], dtype=float32)], 'kernel': Array([[2.]], dtype=float32), 'bias': Array([1.], dtype=float32)}}
!=
{'name': 'Sequential_1', '_no_grad': {'c'}, 'a': <__main__.Placeholder object at 0x7f7280166090>, 'b': [<__main__.Placeholder object at 0x7f7280165f10>, 'a', <__main__.Placeholder object at 0x7f72801648d0>], 'c': {'name': 'Dense_1', '_no_grad': set(), 'a': Array([1., 2., 3.], dtype=float32), 'b': [Array([1., 2., 3.], dtype=float32), 'a', Array([1., 2., 2.], dtype=float32)], 'kernel': Array([[2.]], dtype=float32), 'bias': Array([1.], dtype=float32)}};
value:
{'name': 'Sequential_1', '_no_grad': {'c'}, 'a': Array([0., 0., 0.], dtype=float32), 'b': [Array([0., 0., 0.], dtype=float32), 'a', Array([0., 0., 0.], dtype=float32)], 'c': {'name': 'Dense_1', '_no_grad': set(), 'a': Array([1., 2., 3.], dtype=float32), 'b': [Array([1., 2., 3.], dtype=float32), 'a', Array([1., 2., 2.], dtype=float32)], 'kernel': Array([[2.]], dtype=float32), 'bias': Array([1.], dtype=float32)}}.
```

I have hit this message with my previous approaches as well, and I don't understand why it is happening.
Unfortunately, this kind of approach will not work. PyTree flattening needs to return the same number of leaves regardless of the type of the leaves: dynamically changing the number of leaves based on their Python type breaks the contract of PyTree flattening that is relied upon by transformations like `jit`. See #16170 for a related discussion.

Your best way forward here is probably to abandon the notion of relying on `isinstance` checks during flattening. One alternative that might work, if done carefully, would be to determine at `__init__` time what the flattening will look like, store that information for use when `tree_flatten` is called, and then be sure to avoid this `__init__` logic during unflattening.
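
A minimal sketch of that alternative, assuming the set of leaf attributes is recorded once at construction time (the class name `StaticModule` and the `_leaf_names` attribute are illustrative, not part of JAX or of the code above):

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class StaticModule:
    def __init__(self, **params):
        # Decide once, at construction time, which attributes are leaves.
        # The leaf structure is then fixed and never depends on runtime types.
        self._leaf_names = tuple(sorted(params))
        for name, value in params.items():
            setattr(self, name, value)

    def tree_flatten(self):
        leaves = [getattr(self, name) for name in self._leaf_names]
        aux_data = self._leaf_names  # static, hashable, comparable with ==
        return leaves, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, leaves):
        # Bypass __init__ so no construction-time logic reruns on traced values.
        instance = object.__new__(cls)
        instance._leaf_names = aux_data
        for name, value in zip(aux_data, leaves):
            setattr(instance, name, value)
        return instance


# Every flattening of an equivalent module yields the same treedef,
# so parameters and gradients can be combined with tree_map.
m = StaticModule(kernel=jnp.array([[2.0]]), bias=jnp.array([1.0]))
x = jnp.array([[1.0]])
grads = jax.grad(lambda mod: (x @ mod.kernel + mod.bias).sum())(m)
updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, m, grads)
```

Because `aux_data` here is a plain tuple of strings, flattening the model and flattening its gradient produce equal tree structures, which is what `jax.tree_util.tree_map` requires. By contrast, the fresh `Placeholder` instances stored in the original `aux_data` compare by identity, so two flattenings can never compare equal, which is what surfaces as the "Mismatch custom node data" error.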