Given this trivial pytrees reproducer:

    import jax
    import numpy as np

    jax.config.update("jax_platform_name", "cpu")


    class SpecialArray:
        def __init__(self, array):
            self._array = array

        def __mul__(self, other):
            assert isinstance(other, (int, float))
            return SpecialArray(self._array * other)

        def __sub__(self, other):
            assert isinstance(other, (int, float))
            return SpecialArray(self._array - other)

        def __len__(self):
            return len(self._array)


    def flatten_special(special):
        assert not isinstance(special._array, SpecialArray)
        return (special._array,), None


    def unflatten_special(aux_data, leaves):
        return SpecialArray(leaves[0])


    jax.tree_util.register_pytree_node(
        SpecialArray,
        flatten_special,
        unflatten_special,
    )


    def func(x):
        return x * 2 - 1


    array = np.array([[0, 1, 2], [3, 4, 5.0]])
    jac = jax.jacfwd(func)(array)

    special = SpecialArray(array)
    jac_special = jax.jacfwd(func)(special)
    print(jac_special)

I'm observing that the
Why is JAX giving me my own node type as a leaf in
Replies: 1 comment 1 reply
The jacobian, as an operation, computes the derivative of each output element with respect to the input. So, for example, if you have a function that accepts an array of length `n` and returns an array of length `m`, you get `m` derivatives, each of length `n`. More concretely, a function that maps 2 inputs to 3 outputs will have a jacobian of shape `(3, 2)`:

    import jax
    import jax.numpy as jnp

    def f(x):  # input vector of length n
        return jnp.append(x, x.mean())  # output vector of length m = n + 1

    x = jnp.array([1., 2.])  # n = 2
    jax.jacrev(f)(x)  # output has shape (m, n) = (3, 2)
    # DeviceArray([[1. , 0. ],
    #              [0. , 1. ],
    #              [0.5, 0.5]], dtype=float32)

You can think of this as a length-3 array, in which each element is a length-2 array giving the gradient of that single output element with respect to the full input.

Now how does this relate to pytrees? If your function takes a pytree as input and returns a tuple, your jacobian will be a tuple containing a pytree for each output value:

    from typing import NamedTuple

    class MyType(NamedTuple):
        x: jnp.ndarray
        y: jnp.ndarray

    def f(a):
        return (a.x + a.y, a.x - a.y)

    a = MyType(jnp.float32(1), jnp.float32(2))
    jax.jacrev(f)(a)
    # (MyType(x=DeviceArray(1., dtype=float32), y=DeviceArray(1., dtype=float32)),
    #  MyType(x=DeviceArray(1., dtype=float32), y=DeviceArray(-1., dtype=float32)))

But what if, instead of returning a tuple, you return another pytree? In that case you still get a pytree per output, but instead of those pytrees being embedded in a tuple, they are embedded in a pytree:

    def f(a):
        return MyType(a.x + a.y, a.x - a.y)

    a = MyType(jnp.float32(1), jnp.float32(2))
    jax.jacrev(f)(a)
    # MyType(x=MyType(x=DeviceArray(1., dtype=float32), y=DeviceArray(1., dtype=float32)),
    #        y=MyType(x=DeviceArray(1., dtype=float32), y=DeviceArray(-1., dtype=float32)))

The jacobian computes the gradient of each output value with respect to the full input. If you think about it this way, it's clear that when computing the jacobian of a function which maps a pytree to a pytree, a nested pytree is the most logical representation. Does that make sense?
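To connect this with a custom node type like the one in the question, here is a minimal sketch. It assumes a hypothetical `Box` class standing in for `SpecialArray` (the name and its single `value` attribute are made up for illustration), registered with `jax.tree_util.register_pytree_node`. The point is only to show the nesting: the outer `Box` mirrors the output structure, and its child is another `Box` that mirrors the input structure.

```python
import jax
import jax.numpy as jnp


class Box:
    """Hypothetical stand-in for SpecialArray: a pytree node with one child."""

    def __init__(self, value):
        self.value = value


jax.tree_util.register_pytree_node(
    Box,
    lambda box: ((box.value,), None),             # flatten -> (children, aux_data)
    lambda aux_data, children: Box(children[0]),  # unflatten
)


def func(box):
    # Same arithmetic as the question's func, applied to the wrapped array.
    return Box(box.value * 2 - 1)


x = Box(jnp.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]))
jac = jax.jacfwd(func)(x)

outer_is_box = isinstance(jac, Box)        # outer level mirrors the output pytree
inner_is_box = isinstance(jac.value, Box)  # inner level mirrors the input pytree
jac_shape = jac.value.value.shape          # output leaf shape + input leaf shape
print(outer_is_box, inner_is_box, jac_shape)
```

With a `(2, 3)` input and a `(2, 3)` output, the innermost array should have shape `(2, 3, 2, 3)`. So the "leaf" handed to the outer node's unflatten function is itself an instance of the node type, which is the nested-pytree representation described above.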