-
I wonder if there is an efficient method to compute the Jacobian determinant for a function `f: R^d -> R^d`, i.e. something cheaper than materializing the full Jacobian with `jax.jacfwd` and calling `jnp.linalg.det` on it.
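Concretely, the baseline I have in mind is the obvious one below (a minimal sketch; the map `f` and the dimension are just stand-ins), which materializes the whole d×d Jacobian before taking its determinant:

```python
import jax
import jax.numpy as jnp

def jacobian_det(f, x):
    # Materialize the full (d, d) Jacobian, then take its determinant.
    # This costs O(d^2) memory plus O(d^3) time for the determinant.
    J = jax.jacfwd(f)(x)
    return jnp.linalg.det(J)

# Stand-in nonlinear map R^3 -> R^3 for illustration.
f = lambda x: jnp.tanh(x) + 0.1 * jnp.roll(x, 1)
x = jnp.array([0.5, -1.0, 2.0])
print(jacobian_det(f, x))
```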
-
Thanks for the question!

My guess is it's not possible in general to get better asymptotic efficiency. Consider the case where `f` is just a general linear transformation, i.e. `f = lambda x: jnp.dot(A, x)` for a given dense/unstructured `A`. Computing the Jacobian determinant of this function is exactly computing the determinant of `A`, and we need to pay O(d^3) for that.

Intuitively, evaluating the determinant of the dense Jacobian is doing the right thing for the general case: we want to know the volume of the parallelepiped which is the image of an axis-aligned unit cube under the locally-linearized function. Finding the image of each standard basis vector is exactly what `jacfwd` does, as efficiently as possible.

But clearly we're leaving something on the table for structured functions. For example, what if our function is applied elementwise, say `f = jnp.tanh`? Then the Jacobian is diagonal, and its determinant is just the product of the d diagonal entries, which costs O(d) rather than O(d^3).

One way to think about taking advantage of structure is to break the function down into a composition of primitive functions. So long as we have a formula for computing the Jacobian determinant of each primitive, the chain rule gives us the determinant of the whole composition, since `det(J_{f∘g}(x)) = det(J_f(g(x))) * det(J_g(x))`.

So you could imagine writing a custom jaxpr interpreter along with a table of Jacobian determinant rules to do this. That interpreter would be able to exploit structure in the function `f` to get better asymptotic efficiency than the det-of-jacfwd approach in some cases where it's possible.

Actually, it turns out that @sharadmv wrote that custom jaxpr interpreter tutorial as he was learning about jaxprs, and he was particularly interested in the function-inverse case because he cared about something along these lines. He was working on probabilistic programming, and automatically computing change-of-volume quantities is helpful in computing reparameterized densities for MCMC methods. For that reason he was interested in inverse-log-det-Jacobians. Ultimately he went on to build Oryx, and Oryx likely has tools to compute Jacobian determinants in the structure-exploiting compositional way described here. If you're looking for a library to do these kinds of computations, check it out!

WDYT?
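To make the compositional idea above concrete, here is a minimal sketch, assuming hypothetical per-layer helpers (these aren't a real JAX or Oryx API): each layer returns its output together with its log|det J| contribution, and composing layers just sums the contributions.

```python
import jax.numpy as jnp

def tanh_layer(x):
    # Elementwise primitive: the Jacobian is diagonal with entries
    # 1 - tanh(x)^2, so log|det J| is a sum of d terms -- O(d) work,
    # and the dense Jacobian is never formed.
    y = jnp.tanh(x)
    return y, jnp.sum(jnp.log1p(-y ** 2))

def scale_layer(x, s):
    # Diagonal linear primitive: the Jacobian is diag(s).
    return s * x, jnp.sum(jnp.log(jnp.abs(s)))

def composed_logdet(x, s):
    # Chain rule for determinants:
    # log|det J_{f∘g}(x)| = log|det J_f(g(x))| + log|det J_g(x)|.
    y, ld1 = scale_layer(x, s)
    z, ld2 = tanh_layer(y)
    return z, ld1 + ld2
```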
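And here is a hedged sketch of what such a custom jaxpr interpreter could look like, with a one-entry rule table (the `LOGDET_RULES` name and its contents are invented for illustration; a real version would need rules for many more primitives, plus shape and arity checks):

```python
import jax
import jax.numpy as jnp

# Hypothetical table mapping primitive names to log|det J| rules for
# unary, shape-preserving primitives.
LOGDET_RULES = {
    "tanh": lambda x: jnp.sum(jnp.log1p(-jnp.tanh(x) ** 2)),
}

def logdet_via_jaxpr(f, x):
    closed_jaxpr = jax.make_jaxpr(f)(x)
    env = {closed_jaxpr.jaxpr.invars[0]: x}
    total = 0.0
    for eqn in closed_jaxpr.jaxpr.eqns:
        (invar,) = eqn.invars
        val = env[invar]
        # Accumulate this primitive's log-det contribution...
        total = total + LOGDET_RULES[eqn.primitive.name](val)
        # ...and evaluate the primitive to propagate values forward.
        env[eqn.outvars[0]] = eqn.primitive.bind(val, **eqn.params)
    return total

# For f = tanh ∘ tanh, this pays O(d) per layer instead of O(d^3).
print(logdet_via_jaxpr(lambda v: jnp.tanh(jnp.tanh(v)), jnp.ones(4)))
```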
-
Hi @nalzok, I'm facing a similar issue (running OOM when trying to compute the dense Jacobian with `jax.jacfwd`).
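One memory-trading sketch I'm experimenting with (untested at scale, and it still materializes the final d×d matrix for the determinant): build the Jacobian one column at a time with `jax.jvp` under `lax.map`, so only one tangent computation is live at once rather than all d at once as in `jacfwd`.

```python
import jax
import jax.numpy as jnp

def logabsdet_jacobian(f, x):
    d = x.shape[0]

    def jvp_column(v):
        # Forward-mode product J @ v for a single basis vector v.
        _, col = jax.jvp(f, (x,), (v,))
        return col

    # lax.map runs the JVPs sequentially, so peak memory holds one
    # column's intermediates instead of d batched tangents.
    cols = jax.lax.map(jvp_column, jnp.eye(d))  # row i is J @ e_i
    J = cols.T
    # slogdet avoids overflow/underflow in the determinant itself.
    sign, logdet = jnp.linalg.slogdet(J)
    return sign, logdet
```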