I have a neural network function that maps an input and a potential to an output. I want to take the Hessian of the function with respect to the input.
Replies: 1 comment 1 reply
If you have a function that truly operates independently across batches, you can do this by wrapping your `jax.hessian` call in a `vmap`. In your case it might look like this:

```python
netf = lambda x, p: net.apply(params, x, p).sum(-1)
out = jax.vmap(jax.hessian(netf, argnums=0))(input, potential)
print(out.shape)
# (4, 1, 2, 1, 2)
```

Note that because both `input` and `potential` are batched, they both need to be arguments to the `vmap`-transformed function.

Additionally, in order for this to work, your code needs to be written in a way that correctly handles un-batched inputs. It looks like in your case that can be done by changing these lines:

```python
n_particle_original = x.shape[1]
x = jnp.concatenate([x, potential], axis=1)
```

to this:

```python
n_particle_original = x.shape[-2]
x = jnp.concatenate([x, potential], axis=-2)
```
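For reference, here is a self-contained sketch of the same pattern. The original snippet depends on an undefined `net` and `params`, so this uses a hypothetical stand-in function `netf` with the shapes implied above (batch of 4, one particle, two dimensions); note how it uses `axis=-2` so that it works on un-batched inputs under `vmap`:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the user's network: maps un-batched
# x and p arrays of shape (n_particle, dim) to a scalar.
def netf(x, p):
    # Negative axis indexing keeps this correct whether or not a
    # leading batch axis is present (vmap strips the batch axis).
    h = jnp.concatenate([x, p], axis=-2)
    return jnp.sum(jnp.tanh(h) ** 2)  # scalar output

batch, n_particle, dim = 4, 1, 2
x = jnp.ones((batch, n_particle, dim))
p = jnp.ones((batch, n_particle, dim))

# Per-example Hessian w.r.t. the first argument, mapped over the
# leading batch axis of both inputs.
out = jax.vmap(jax.hessian(netf, argnums=0))(x, p)
print(out.shape)  # (4, 1, 2, 1, 2)
```

Each per-example Hessian has shape `(n_particle, dim, n_particle, dim)` because `jax.hessian` differentiates a scalar output twice with respect to an `(n_particle, dim)` input, and `vmap` prepends the batch axis.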