How to shape Pytrees correctly for use with and after vjp
.
#9980
-
I'm trying to implement a Gauss-Newton vector product abstractly. The simplest implementation is: def hvp(fun, p, t):
return jax.jvp(jax.jacrev(fun, argnums=0), p, t)
def gvp(inner_fun, outer_fun, p, t):
    """Gauss-Newton-style vector product via a Hessian-vector product of the
    composition outer_fun(inner_fun(.)).

    p, t: tuples of primal and tangent inputs for inner_fun.
    Returns (inner primal output, outer jacobian primal, HVP of the composition).
    """
    # Fixed: the original accidentally linearized the free global `f` instead
    # of the `inner_fun` argument (see the reply below). The linear part was
    # unused anyway, so evaluate the primal output directly.
    y = inner_fun(*p)
    dz, Gv = hvp(lambda a: outer_fun(inner_fun(a)), p, t)
    return y, dz, Gv
# Test code
f = lambda x: jnp.sin(x)
g = lambda x: jnp.square(x) / 2
z = lambda x: g(x).sum()
s = jnp.pi / 8
print(gvp(f, g, (s, ), (s, ))[2])
print(gvp(f, z, (s, ), (s, ))[2])
print(gvp(f, g, (jnp.asarray([s]), ), (jnp.asarray([s]), ))[2])
print(gvp(f, z, (jnp.asarray([s]), ), (jnp.asarray([s]), ))[2])
>>> 0.27768016
>>> 0.27768016
>>> [[0.27768016]]
>>> [0.27768016]
# Test with some variable data containers
s = jnp.tile(s, 3)
print(gvp(f, g, (s, ), (s, ))[2])
print(gvp(f, z, (s, ), (s, ))[2])
print(gvp(f, g, (jnp.asarray([s]), ), (jnp.asarray([s]), ))[2])
print(gvp(f, z, (jnp.asarray([s]), ), (jnp.asarray([s]), ))[2])
>>> [[0.27768016 0. 0. ]
[0. 0.27768016 0. ]
[0. 0. 0.27768016]]
>>> [0.27768016 0.27768016 0.27768016]
>>> [[[[0.27768016 0. 0. ]]
[[0. 0.27768016 0. ]]
[[0. 0. 0.27768016]]]]
>>> [[0.27768016 0.27768016 0.27768016]] It seems that this method almost always works (at least for my testing). However, this is not really efficient. I don't want to call the Hessian of a linearized network when I know this will only contain zeros. Hence why I tried the composition of jvp and vjp: def gvp(inner_fun, outer_fun, p, t):
y, Jt = jax.jvp(inner_fun, p, t)
dz, HJt = hvp(outer_fun, (y,), (Jt,))
y, vjp_fun = jax.vjp(inner_fun, *p)
Gv = vjp_fun(HJt)
return y, dz, Gv
# Test code
s = jnp.pi / 8
print(gvp(f, g, (s, ), (s, ))[2])
print(gvp(f, z, (s, ), (s, ))[2])
>>> (DeviceArray(0.33518964, dtype=float32, weak_type=True),)
>>> (DeviceArray(0.33518964, dtype=float32, weak_type=True),)
s = jnp.tile(s, 3)
print(gvp(f, g, (s, ), (s, ))[2]) # ValueError: Shape of cotangent input to vjp pullback ... The first thing that is not going as expected is the ValueError, which should be fixed by vmapping the pullback: def gvp_vmapped(inner_fun, outer_fun, p, t):
y, Jt = jax.jvp(inner_fun, p, t)
dz, HJt = hvp(outer_fun, (y,), (Jt,))
HJt = jax.tree_map(jnp.atleast_1d, HJt) # vmap hack
y, vjp_fun = jax.vjp(inner_fun, *p)
Gv = vmap(vjp_fun)(*jax.tree_leaves(HJt))
return y, dz, Gv
# Test code
s = jnp.pi / 8
print(gvp_vmapped(f, g, (s, ), (s, ))[2])
print(gvp_vmapped(f, z, (s, ), (s, ))[2])
>>> (DeviceArray([0.33518964], dtype=float32, weak_type=True),)
>>> (DeviceArray([0.33518964], dtype=float32, weak_type=True),)
s = jnp.tile(s, 3)
print(gvp_vmapped(f, g, (s, ), (s, ))[2])
>>> (DeviceArray([[0.33518964, 0. , 0. ],
[0. , 0.33518964, 0. ],
[0. , 0. , 0.33518964]], dtype=float32, weak_type=True),)
print(gvp_vmapped(f, z, (s, ), (s, ))[2]) # ValueError: Shape of cotangent input to vjp ... Now, multiple things are going wrong... The scalar functions yield a result that is now a (1,) array inside of a tuple, and the batched case still raises the ValueError. So I'm asking for any guidance or tips on how to correctly manipulate the shapes and sizes, such that my custom gvp works for arbitrary pytree inputs. |
Beta Was this translation helpful? Give feedback.
Replies: 2 comments 14 replies
-
def gvp(inner_fun, outer_fun, p, t):
y, f_lin = jax.linearize(f, *p)
dz, Gv = hvp(lambda a: outer_fun(inner_fun(a)), p, t)
return y, dz, Gv I think it should be def gvp(inner_fun, outer_fun, p, t):
y, f_lin = jax.linearize(inner_fun, *p)
dz, Gv = hvp(lambda a: outer_fun(y + f_lin(jax.tree_map(lambda x, y: x - y, a, p))), p, t)
# or hvp(lambda a: outer_fun(y + f_lin(a)), jax.tree_map(lambda x: jnp.zeros_like(x), p), t)
return y, dz, Gv otherwise the hvp also differentiates through the nonlinearity of inner_fun. Emmm, are you sure it is not efficient? |
Beta Was this translation helpful? Give feedback.
-
@joeryjoery Emmm, I don't know why we need a nested tree_map here. (Assume: pack all arguments into single pytree.) import jax
import jax.numpy as jnp
def nested_vmap(fun, n: int):
    """Lift ``fun`` over ``n`` leading batch axes by applying jax.vmap n times.

    For n == 0 the function is returned unchanged.
    """
    mapped = fun
    for _level in range(n):
        mapped = jax.vmap(mapped)
    return mapped
def gvp(inner_fun, outer_fun, p_in, t_in):
    """Gauss-Newton-style vector product built from linearize/linear_transpose.

    p_in: primal pytree for inner_fun; t_in: tangent pytree of the same
    structure. Returns (inner primal output, jacobian of outer_fun at that
    output, and the pulled-back Hessian-times-tangent pytree).
    """
    def _lift(fn, depth):
        # Apply jax.vmap `depth` times to map fn over the leading output axes.
        for _ in range(depth):
            fn = jax.vmap(fn)
        return fn

    # Linearize the inner function around the primal point.
    p_out, push_fwd = jax.linearize(inner_fun, p_in)
    pull_back_tuple = jax.linear_transpose(push_fwd, p_in)
    # linear_transpose returns a tuple with one entry per primal argument;
    # there is only a single primal here.
    pull_back = lambda ct: pull_back_tuple(ct)[0]
    # Push the tangent through the linearized inner function.
    tangent_out = push_fwd(t_in)
    # Forward-over-reverse: jacobian of outer_fun plus its directional
    # derivative (Hessian applied to the pushed-forward tangent).
    d_outer, h_jt = jax.jvp(jax.jacrev(outer_fun, argnums=0), (p_out,), (tangent_out,))
    # Output shapes of outer_fun tell us how many leading axes were prepended
    # to each Hessian leaf, i.e. how many times to vmap the pullback.
    out_shapes = jax.eval_shape(outer_fun, p_out)
    gnvp = jax.tree_map(
        lambda s, h: _lift(pull_back, len(s.shape))(h), out_shapes, h_jt)
    return p_out, d_outer, gnvp
def f(x):
    """Identity inner function: returns its input pytree unchanged."""
    return x
def g(x):
    """Square every leaf of the input pytree element-wise."""
    return jax.tree_map(lambda leaf: leaf ** 2, x)
x = (jnp.ones((2,2)), jnp.ones((2,2)))
print(gvp(f, g, x, x)[2]) Output: ((DeviceArray([[[[2., 0.],
[0., 0.]],
[[0., 2.],
[0., 0.]]],
[[[0., 0.],
[2., 0.]],
[[0., 0.],
[0., 2.]]]], dtype=float32),
DeviceArray([[[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.]]],
[[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.]]]], dtype=float32)),
(DeviceArray([[[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.]]],
[[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.]]]], dtype=float32),
DeviceArray([[[[2., 0.],
[0., 0.]],
[[0., 2.],
[0., 0.]]],
[[[0., 0.],
[2., 0.]],
[[0., 0.],
[0., 2.]]]], dtype=float32))) |
Beta Was this translation helpful? Give feedback.
@joeryjoery Emmm, I don't know why we need a nested
tree_map
here. (Assume: pack all arguments into single pytree).