Defining new JAX primitives: looking for help understanding vjp rule arg shapes #19884
Replies: 2 comments
-
Hi - it looks like the array shapes here are of the expected shape: you define …

Backing up, though, I think the solution in the post you link to is probably not the right template for you to follow here, because it uses non-public APIs, and because Primitives are a pretty heavy solution for the problem you have. You might have more luck instead using an approach based on …
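The follow-up reply mentions the pure callback option. As a rough illustration only, here is one way that approach could look: `jax.pure_callback` calls the non-JAX code, and `jax.custom_vjp` supplies the backward pass, so no custom Primitive is needed. The functions `external_outer` and `external_outer_jacobian` are hypothetical NumPy stand-ins for a wrapped C++ routine (such as the basix call in the question), and the outer-product example is an assumption chosen because its output is multidimensional. If the real library can compute a VJP directly, calling that in the backward pass would avoid materializing the full Jacobian.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical stand-in for an external (non-JAX) library call.
# It only ever sees NumPy arrays. Output shape: (n, n).
def external_outer(x):
    x = np.asarray(x)
    return np.outer(x, x)

# Hypothetical stand-in for an external routine returning the Jacobian,
# with shape output.shape + input.shape = (n, n, n).
def external_outer_jacobian(x):
    x = np.asarray(x)
    eye = np.eye(x.shape[0], dtype=x.dtype)
    # J[i, j, k] = d(x_i * x_j) / d x_k = delta_ik * x_j + x_i * delta_jk
    return eye[:, None, :] * x[None, :, None] + x[:, None, None] * eye[None, :, :]

@jax.custom_vjp
def outer(x):
    # pure_callback needs the result shape and dtype declared up front.
    out_type = jax.ShapeDtypeStruct((x.shape[0], x.shape[0]), x.dtype)
    return jax.pure_callback(external_outer, out_type, x)

def outer_fwd(x):
    # Save the input as the residual needed by the backward pass.
    return outer(x), x

def outer_bwd(x, ct):
    # ct has the same shape as the primal output: (n, n).
    jac_type = jax.ShapeDtypeStruct((x.shape[0],) * 3, x.dtype)
    jac = jax.pure_callback(external_outer_jacobian, jac_type, x)
    # Contract ct against the two output axes of the Jacobian; the result
    # has the input's shape (n,). Return one entry per primal argument.
    return (jnp.tensordot(ct, jac, axes=2),)

outer.defvjp(outer_fwd, outer_bwd)

# Usage: the VJP of the callback-backed function.
x = jnp.array([1.0, 2.0, 3.0])
_, outer_vjp = jax.vjp(outer, x)
(x_bar,) = outer_vjp(jnp.ones((3, 3)))   # shape (3,), same as x
```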
-
(I edited my original post because I left those …)

@jakevdp Thank you for pointing out the pure callback option, but perhaps my question is more fundamental: how to write the vjp itself. In my understanding, the vjp is the transpose of a Jacobian times a cotangent vector, … Even if I take the …
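For a function from R^n to R^m, that understanding is correct: the cotangent has the output's shape (m,), the Jacobian has shape (m, n), and the VJP is J^T applied to the cotangent, giving a result of the input's shape (n,). The small check below uses a hypothetical function `f` (not from the thread) and compares the hand-computed product against `jax.vjp`.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function f: R^3 -> R^2.
def f(x):
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])
ct = jnp.array([1.0, -1.0])      # cotangent: same shape as f(x), i.e. (2,)

J = jax.jacobian(f)(x)           # shape (2, 3) = output shape + input shape
manual = J.T @ ct                # transpose of Jacobian times cotangent: (3,)

_, f_vjp = jax.vjp(f, x)
(auto,) = f_vjp(ct)              # JAX's VJP, also shape (3,)
```

In practice a VJP rule rarely needs to build J explicitly; it only has to produce the same result as `J.T @ ct`.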
-
Hello!
Similar to this post, I want to wrap an external library function in a JAX primitive. The code in the linked solution mostly makes sense, but it uses a scalar function as its example, while my function is multidimensional. I am having trouble understanding what shapes my arrays need to have when defining the vjp rule.
Here is a minimal, working pure-JAX implementation of my function (for comparison purposes).
Below, I import a wrapped C++ function from the basix library that does the same thing and follow the code from the previous post, but I run into trouble defining the vjp rule: the shapes of the cotangent and my Jacobian don't match. I would very much appreciate it if someone could explain the shape of the cotangent (e.g. where do the outer dimensions come from in JAX?) and how to write the transpose correctly!
Thank you,
Olek
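On the shape question in the post: for multidimensional inputs and outputs, the cotangent has the output's shape, the Jacobian has shape `output.shape + input.shape`, and "transpose times cotangent" becomes a contraction over all of the Jacobian's output axes. One possible source of extra leading dimensions (a guess, since the original code is not shown) is batching: `jax.jacrev` builds a Jacobian by vmapping the VJP over one basis cotangent per output element, so a vmapped rule sees cotangents with an added leading batch axis. The function `g` below is hypothetical and only illustrates the shapes.

```python
import jax
import jax.numpy as jnp

# Hypothetical multidimensional function: input shape (3,), output shape (2, 3).
def g(x):
    return jnp.stack([jnp.sin(x), x ** 2])

x = jnp.array([0.5, 1.0, 1.5])
y, g_vjp = jax.vjp(g, x)

# 1. The cotangent has the shape (and dtype) of the primal output.
ct = jnp.ones_like(y)                        # shape (2, 3)

# 2. The full Jacobian has shape output.shape + input.shape.
J = jax.jacobian(g)(x)                       # shape (2, 3, 3)

# 3. Contract the cotangent against every output axis of the Jacobian;
#    the result has the input's shape.
manual = jnp.tensordot(ct, J, axes=ct.ndim)  # shape (3,)
(auto,) = g_vjp(ct)                          # shape (3,)

# 4. A batch of cotangents adds a leading axis. Vmapping the VJP over one
#    basis cotangent per output element recovers the Jacobian row by row.
basis = jnp.eye(y.size).reshape((y.size,) + y.shape)   # shape (6, 2, 3)
(rows,) = jax.vmap(g_vjp)(basis)                        # shape (6, 3)
J_from_vjp = rows.reshape(y.shape + x.shape)            # shape (2, 3, 3)
```

So a vjp rule for one unbatched call should accept a cotangent shaped like the output and return something shaped like the input; any additional leading axis indicates a batch of cotangents rather than a different Jacobian layout.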