
If your function truly operates independently on each element of the batch, you can compute per-example Hessians by wrapping your jax.hessian call in jax.vmap. In your case it might look like this:

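import jax

# params is closed over; inside vmap, x and p are a single (un-batched) example
netf = lambda x, p: net.apply(params, x, p).sum(-1)
# vmap maps over the leading (batch) axis of both input and potential
out = jax.vmap(jax.hessian(netf, argnums=0))(input, potential)
print(out.shape)
# (4, 1, 2, 1, 2)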
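As a self-contained check of this pattern, here is a minimal sketch in which a tiny fixed function, toy_apply (hypothetical, standing in for net.apply(params, x, p)), replaces the real network. It shows that vmap over jax.hessian gives the same result as computing each example's Hessian in a Python loop, and reproduces the (4, 1, 2, 1, 2) shape for a batch of 4 inputs of shape (1, 2).

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for net.apply(params, x, p): a tiny fixed "network"
# that maps ONE example x, p of shape (n_particle, dim) to 3 outputs.
W = jnp.linspace(-1.0, 1.0, 6).reshape(2, 3)

def toy_apply(x, p):
    return jnp.tanh((x + p) @ W).sum(0)   # shape (3,)

# Same pattern as in the answer: reduce the outputs to a scalar per example.
netf = lambda x, p: toy_apply(x, p).sum(-1)

key_x, key_p = jax.random.split(jax.random.PRNGKey(0))
x_batch = jax.random.normal(key_x, (4, 1, 2))   # (batch, n_particle, dim)
p_batch = jax.random.normal(key_p, (4, 1, 2))

out = jax.vmap(jax.hessian(netf, argnums=0))(x_batch, p_batch)
print(out.shape)   # (4, 1, 2, 1, 2)

# Reference: compute each example's Hessian separately and stack the results.
loop = jnp.stack([jax.hessian(netf)(x_batch[i], p_batch[i]) for i in range(4)])
print(jnp.allclose(out, loop, atol=1e-5))   # True
```

The vmapped version traces the Hessian computation once and runs it over the whole batch, rather than re-invoking it per example as the loop does.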

Note that because both input and potential are batched, both need to be passed as arguments to the vmap-transformed function, so that vmap maps over their leading axes together.
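To illustrate why both batched arrays must be vmap arguments, the sketch below (again using a hypothetical toy netf in place of the real network) compares passing p_batch through vmap with closing over it. When the batched potential is captured by closure, each per-example call sees the whole batch of potentials, broadcasting silently mixes examples, and the output shape changes. If the potential were instead shared by the whole batch, in_axes=(0, None) tells vmap not to map over it.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the real network, written for ONE example:
# x, p have shape (n_particle, dim) and the result is reduced to a scalar.
W = jnp.linspace(-1.0, 1.0, 6).reshape(2, 3)

def netf(x, p):
    return jnp.tanh((x + p) @ W).sum(0).sum(-1)

key_x, key_p = jax.random.split(jax.random.PRNGKey(0))
x_batch = jax.random.normal(key_x, (4, 1, 2))   # (batch, n_particle, dim)
p_batch = jax.random.normal(key_p, (4, 1, 2))

hess = jax.hessian(netf, argnums=0)

# Correct: both batched arrays are vmap arguments, so example i of x is
# paired with example i of p.
good = jax.vmap(hess)(x_batch, p_batch)
print(good.shape)    # (4, 1, 2, 1, 2)

# Wrong: p_batch is closed over, so each per-example call sees all 4
# potentials; x + p broadcasts to (4, 1, 2) and examples get mixed.
bad = jax.vmap(jax.hessian(lambda x: netf(x, p_batch)))(x_batch)
print(bad.shape)     # (4, 1, 1, 2, 1, 2)

# If one potential were shared by the whole batch, in_axes=(0, None)
# maps over x only and passes p unchanged to every call.
shared = jax.vmap(hess, in_axes=(0, None))(x_batch, p_batch[0])
print(shared.shape)  # (4, 1, 2, 1, 2)
```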
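Additionally, for this to work, your code needs to handle un-batched inputs correctly: inside vmap, each call sees a single example with the leading batch axis removed, so shape lookups such as x.shape[1] refer to a different axis than they do on the batched array. It looks like in your case that can be done by changing these lines: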

        n_particle_original = x.shape[1]
        x = jnp.concat…
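The effect of vmap removing the batch axis can be seen directly. The sketch below uses a hypothetical function, per_example, that records the shape it is traced with and looks up the particle axis by counting from the end (x.shape[-2]), which works whether or not a leading batch axis is present. It only illustrates the principle; it is not the specific change to your code.

```python
import jax
import jax.numpy as jnp

seen_shapes = []

def per_example(x):
    # Python side effect: runs at trace time and records the traced shape.
    seen_shapes.append(x.shape)
    # Counting axes from the END is unaffected by a leading batch axis:
    # x.shape[-2] is the particle axis for both (batch, n, d) and (n, d).
    n_particle = x.shape[-2]
    return x.sum() * n_particle

x_batch = jnp.ones((4, 1, 2))   # (batch, n_particle, dim)

# Indexing from the front changes meaning once the batch axis is gone:
print(x_batch.shape[1])      # 1  (particle axis on the batched array)
print(x_batch[0].shape[1])   # 2  (same index now hits dim on one example)

out = jax.vmap(per_example)(x_batch)
print(seen_shapes[0])        # (1, 2)
print(out)                   # [2. 2. 2. 2.]
```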

Answer selected by Binbose
Category
Q&A