Thanks for the question! Actually, the autodiff cookbook talks about reverse-over-reverse too: that's the first example in the section you linked.

To compute the Hessian of $f = f_2 \circ f_1$, you can certainly compute Jacobians and Hessians of $f_2$ and $f_1$ at the appropriate points and then contract them (i.e. multiply them) appropriately. But that's just the second-order version of computing the Jacobian of $f$ by computing the Jacobians of $f_2$ and $f_1$ and multiplying. That's not usually a good idea, because forming dense Jacobians that way typically throws away sparsity structure that the program represents directly as data dependence (i.e. not all outputs of $f_1$ depend on all inputs to $f_1$, and not all outputs of $f_2$ depend on all inputs to $f_2$).

I'm not sure exactly what you mean by a 'pass' here, but we can't use a single application of forward mode, because we've got to compute second derivatives of primitives somehow. That is, if the first derivative of the primitive $g_0$ involves applications of the primitive $g_1$, and the first derivative of $g_1$ involves applications of the primitive $g_2$, then we've got to generate $g_2$ somehow. In other words, if we take $f_1$ and $f_2$ to be primitives in your example, how would a single application of forward mode compute $H_{f_1}$ and $H_{f_2}$?
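For concreteness, here is a minimal sketch of the "two applications" point in JAX; the toy function `f` and input below are made up for illustration. One application of reverse mode produces a gradient *function*, and a second application of forward mode differentiates that function, which is the forward-over-reverse recipe from the cookbook:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) ** 2)   # toy scalar function, just for illustration

x = jnp.arange(3.0)

grad_f = jax.grad(f)          # first application: reverse mode gives a gradient *function*
hess_f = jax.jacfwd(grad_f)   # second application: forward mode differentiates that function

print(hess_f(x).shape)                              # (3, 3)
print(jnp.allclose(hess_f(x), jax.hessian(f)(x)))   # True: matches jax.hessian
```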
I'm not sure exactly what you mean by 'symbolic derivatives', but maybe it's helpful to write out these functions in Python-like syntax, something like:

```python
from math import sin, cos

# JVP (forward-mode) rule for the primitive sin
def jvp_sin(x, xdot):
    y = sin(x)
    ydot = cos(x) * xdot
    return y, ydot

# Linearization rule for sin: primal output plus a linear map on tangents
def lin_sin(x):
    y = sin(x)
    cos_x = cos(x)
    return y, lambda xdot: cos_x * xdot

# VJP (reverse-mode) rule for sin: primal output plus the transposed linear map
def vjp_sin(x):
    y = sin(x)
    cos_x = cos(x)
    return y, lambda ybar: cos_x * ybar

# Composition rules: given the rules for f and g, build the rules for f . g
def jvp_compose(jvp_f, jvp_g):
    def jvp_fg(x, xdot):
        y, ydot = jvp_g(x, xdot)
        z, zdot = jvp_f(y, ydot)
        return z, zdot
    return jvp_fg

def lin_compose(lin_f, lin_g):
    def lin_fg(x):
        y, g_lin = lin_g(x)
        z, f_lin = lin_f(y)
        return z, lambda xdot: f_lin(g_lin(xdot))
    return lin_fg

def vjp_compose(vjp_f, vjp_g):
    def vjp_fg(x):
        y, g_vjp = vjp_g(x)
        z, f_vjp = vjp_f(y)
        return z, lambda zbar: g_vjp(f_vjp(zbar))
    return vjp_fg
```

WDYT?
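As a quick sanity check of the rules sketched above (using the helper names from the sketch, which are not JAX APIs), composing the `sin` rule with itself recovers the value and first derivative of `sin(sin(x))`:

```python
from math import sin, cos

x = 0.7
z, fg_vjp = vjp_compose(vjp_sin, vjp_sin)(x)
print(z, sin(sin(x)))                      # same primal value
print(fg_vjp(1.0), cos(sin(x)) * cos(x))   # same first derivative
```

To go one order higher you would then need rules for `cos` and multiplication as well, which is exactly the "we've got to generate $g_2$ somehow" point above.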
I understand how to use JAX to compute the Hessian using automatic differentiation; however, I am having difficulty understanding how it works. In particular, I don't understand why we need two ~~passes~~ applications of automatic differentiation (i.e., reverse then forward or forward then reverse)*.

Assume that $f$ is defined as a composition of functions:

$$f = f_2 \circ f_1$$

where $f_1:\mathbb{R}^n \rightarrow \mathbb{R}^{m_1}$ and $f_2:\mathbb{R}^{m_1} \rightarrow \mathbb{R}$. We can compute the Jacobian $J_f(x)$ using the chain rule:
$$J_f(x) = J_{f_2}(f_1(x))J_{f_1}(x)$$
and we can also compute the Hessian $H_f(x)$:
$$H_f(x) = J_{f_1}(x)^TH_{f_2}(f_1(x))J_{f_1}(x) + J_{f_2}(f_1(x))H_{f_1}(x)$$
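For what it's worth, here is a quick numerical check of that formula in JAX; the particular $f_1$, $f_2$, and $x$ below are made up for illustration (with $f_2$ scalar-valued, so every term is a matrix):

```python
import jax
import jax.numpy as jnp

def f1(x):                                  # R^3 -> R^2, chosen arbitrarily
    return jnp.array([jnp.sin(x[0]) * x[1], x[1] * x[2] ** 2])

def f2(y):                                  # R^2 -> R, chosen arbitrarily
    return jnp.exp(y[0]) + y[1] ** 3

f = lambda x: f2(f1(x))
x = jnp.array([0.3, -1.2, 0.8])

J1 = jax.jacfwd(f1)(x)                      # (2, 3)
J2 = jax.grad(f2)(f1(x))                    # (2,)
H1 = jax.hessian(f1)(x)                     # (2, 3, 3)
H2 = jax.hessian(f2)(f1(x))                 # (2, 2)

H_manual = J1.T @ H2 @ J1 + jnp.tensordot(J2, H1, axes=1)
print(jnp.allclose(H_manual, jax.hessian(f)(x)))   # True
```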
The JAX Autodiff Cookbook describes two methods for computing the Hessian using automatic differentiation: forward over reverse and reverse over forward.

Why can't we use a single forward-mode ~~pass~~ application to compute $f_1(x)$, $J_{f_1}(x)$, and $H_{f_1}(x)$ in the same way we compute $f_1(x)$ and $J_{f_1}(x)$ for the Jacobian? We could then compute $f_2(f_1(x))$, $J_{f_2}(f_1(x))$, and $H_{f_2}(f_1(x))$, and then put everything together.

\* If, for example, we choose to do forward mode and then reverse mode, then the output of the forward mode must include ~~symbolic derivatives~~ some function, not just the Jacobian at a specific point. Otherwise, how would the reverse mode work? Perhaps I am missing something, so it would be helpful if someone could provide a simple example.
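Regarding the footnote, one concrete illustration (the function `g` and the point `x` here are just an example): in JAX, an application of forward mode can indeed return a function rather than a Jacobian matrix, and reverse mode can then act on that function by transposing it.

```python
import jax
import jax.numpy as jnp

def g(x):
    return jnp.sin(x)

x = jnp.array(0.5)

y, g_lin = jax.linearize(g, x)        # y = g(x); g_lin is the linear map v -> J_g(x) v
print(g_lin(jnp.ones_like(x)))        # cos(0.5)

g_vjp = jax.linear_transpose(g_lin, jnp.ones_like(x))   # reverse mode = transpose of the linear map
print(g_vjp(jnp.ones_like(y)))        # (cos(0.5),)
```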