-
Hello, I have recently started using JAX and want to apply it to an automatic differentiation problem. I have a function that takes three inputs: two class instances and one scalar. I want the sensitivity of my output (a 6x6 matrix) with respect to a scalar attribute of one of the classes. At the moment, to work with partial derivatives, I am passing a fourth argument, Cf (Cf is a tensor, and the line that calculates it is commented out in the MT method). The goal is to remove the Cf input and take the derivative with respect to Fiber.EL directly.
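Here is a minimal sketch of the setup (the class contents, the numeric values, and the MT body are simplified placeholders for my real model):

```python
import jax.numpy as jnp
from jax import jacfwd

class Fiber:
    def __init__(self, EL, ET, nu):
        self.EL = EL  # longitudinal modulus: the attribute I want sensitivities for
        self.ET = ET
        self.nu = nu

class Matrix:
    def __init__(self, E, nu):
        self.E = E
        self.nu = nu

def MT(fiber, matrix, Vf, Cf):
    # Cf = stiffness_tensor(fiber.EL, fiber.ET, fiber.nu)  # commented out; Cf is passed in instead
    Cm = matrix.E * jnp.eye(6)           # placeholder matrix stiffness
    return Vf * Cf + (1.0 - Vf) * Cm     # placeholder mixing rule; the real output is also 6x6

fiber = Fiber(230.0, 15.0, 0.2)
matrix = Matrix(3.5, 0.35)
Cf = fiber.EL * jnp.eye(6)               # placeholder for the real fiber stiffness tensor

# Current workaround: differentiate with respect to the extra Cf argument
dC_dCf = jacfwd(MT, argnums=3)(fiber, matrix, 0.6, Cf)
print(dC_dCf.shape)  # (6, 6, 6, 6)
```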
-
The easiest way to do this would be to ensure your class is registered as a pytree, and then define a function with respect to its flattened contents and take the gradient with respect to the desired argument. That's admittedly not a very clean solution here; #10614 is a feature request for more expressive gradient argument specifications that might allow a cleaner solution in this case.
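For instance, registering the class with `jax.tree_util` makes JAX treat it like any other container, so gradients can flow through its attributes. A minimal sketch, assuming a hypothetical `Fiber` class and a stand-in `output` function:

```python
import jax
import jax.numpy as jnp
from jax import tree_util

@tree_util.register_pytree_node_class
class Fiber:
    def __init__(self, EL, ET, nu):
        self.EL, self.ET, self.nu = EL, ET, nu

    def tree_flatten(self):
        # The leaves are the differentiable attributes; no static aux data here.
        return (self.EL, self.ET, self.nu), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

def output(fiber, Vf):
    # Stand-in for the real model: any function of the pytree leaves works.
    return Vf * fiber.EL * jnp.eye(6)

# grad with respect to a pytree argument returns a pytree of the same
# structure, so fiber_grad.EL holds d(loss)/d(EL).
loss = lambda fiber, Vf: jnp.sum(output(fiber, Vf))
fiber_grad = jax.grad(loss, argnums=0)(Fiber(230.0, 15.0, 0.2), 0.6)
print(fiber_grad.EL)  # 0.6 * 6 = 3.6
```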
-
What I mean by "define a function with respect to its flattened contents" is defining a function that accepts plain arrays or scalars as arguments, rather than class instances. For example:
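Something along these lines (a sketch; the argument names and the function body are hypothetical stand-ins for the original model):

```python
import jax
import jax.numpy as jnp

def output_flat(fiber_EL, fiber_ET, fiber_nu, matrix_E, matrix_nu, Vf):
    # Rebuild whatever the model needs from plain scalars, then
    # compute the 6x6 output as before.
    return Vf * fiber_EL * jnp.eye(6)

# jacfwd gives the full 6x6 sensitivity with respect to fiber_EL (argument 0).
dC_dEL = jax.jacfwd(output_flat, argnums=0)(230.0, 15.0, 0.2, 3.5, 0.35, 0.6)
print(dC_dEL.shape)  # (6, 6)
```

At the call site you unpack the attributes yourself, e.g. `output_flat(fiber.EL, fiber.ET, ...)`.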
It's not very convenient, but since there's currently no way to construct the gradient with respect to a class attribute, this is the best workaround.