A simple MWE is included below, which computes up to the third derivative of sin(x) for an array of values of x. I accomplish this with jacfwd and vmap. There are no issues up to the second derivative, and the results appear correct when plotted. I get this error with the third derivative:
and am a bit unsure what to make of it. It says dfdx2 has no shape, when I would expect it to be N x 1 x 1 (which is confirmed when dfdx2 is evaluated at x). Am I going about computing the derivatives in a way that doesn't make sense in JAX?
The axis in `vmap` refers to an axis of the input arrays, not to the function that is mapped. So the relevant axes here are not anything associated with `dfdx2`, but rather the axes of the input argument `x`.

The input argument is `x`, which has shape `[100, 1]`. The first `vmap` maps over the first axis, of length 100; the second, nested `vmap` maps over the second axis, of length 1; the third nested `vmap` has no axis left to map over, so it raises this error.

From your example, it's not entirely clear to me what you're trying to compute (what exactly is a triply nested vmapped Jacobian? I'm honestly not sure). That said, if you want to compute the element-wise third derivative of a scalar function applied to a 2D input, one option might look like this:

```python
from jax import vmap, grad

# mySin and x come from the MWE in the question; x has shape (100, 1)
df3 = grad(grad(grad(mySin)))  # third derivative of a scalar -> scalar function
out = vmap(vmap(df3))(x)       # map over axis 0 (length 100), then axis 1 (length 1)
print(out.shape)
# (100, 1)
```

We use …

Does that answer your question?