

The Jacobian, as an operation, computes the derivative of each output element with respect to the input. For example, if you have a function that accepts an array of length n and returns an array of length m, you get m derivatives (one per output element), each of length n, stacked into an array of shape (m, n). More concretely, a function that maps 2 inputs to 3 outputs will have a Jacobian of shape (3, 2):

import jax
import jax.numpy as jnp

def f(x):  # input vector of length n
  return jnp.append(x, x.mean())  # output vector of length m = n + 1

x = jnp.array([1., 2.])  # n = 2
jax.jacrev(f)(x)  # output has shape (m, n) = (3, 2)
# DeviceArray([[1. , 0. ],
#              [0. , 1. ],
#              [0.5, 0.5]], dtype=float32)
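A short sketch, assuming a recent JAX release, that checks the row interpretation above in three ways: forward-mode `jax.jacfwd` produces the same matrix as `jax.jacrev`; row i of the Jacobian equals `jax.grad` of the scalar output `f(x)[i]`; and for non-vector arrays the Jacobian shape is the output shape followed by the input shape. The function `g` and array `X` are hypothetical, added only for illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Same function as in the example above: n inputs -> m = n + 1 outputs.
    return jnp.append(x, x.mean())

x = jnp.array([1., 2.])  # n = 2

# Reverse mode builds the Jacobian from vector-Jacobian products (one per output row);
# forward mode builds it from Jacobian-vector products (one per input column).
J_rev = jax.jacrev(f)(x)
J_fwd = jax.jacfwd(f)(x)
print(J_rev.shape)                  # (3, 2)
print(jnp.allclose(J_rev, J_fwd))   # both modes give the same matrix

# Row i of the Jacobian is the gradient of the scalar output f(x)[i] w.r.t. x.
rows = [jax.grad(lambda v, i=i: f(v)[i])(x) for i in range(J_rev.shape[0])]
print(jnp.stack(rows))              # matches J_rev row by row

# For non-vector arrays the Jacobian shape is output shape + input shape.
# Hypothetical example: `g` and `X` are illustrative, not from the original post.
def g(X):
    return X.sum(axis=0)            # maps shape (2, 3) -> (3,)

X = jnp.ones((2, 3))
J_g = jax.jacrev(g)(X)
print(J_g.shape)                    # (3, 2, 3)
```

Forward mode is usually the cheaper choice when there are few inputs and many outputs (tall Jacobians), reverse mode when there are many inputs and few outputs (wide Jacobians); for a tiny function like this one the choice makes no practical difference.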

You can think of …

Answer selected by agoose77