I may not be understanding your question, but given a function with two inputs and one output, you can pass `argnums` to `jax.grad` to differentiate with respect to the first input only:

```python
import jax

def f(x, y):
    return x * y

df_dx = jax.grad(f, argnums=0)(1.0, 2.0)
print(df_dx)  # 2.0
```

I can't think of any more efficient way to do this with JAX.
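A minimal sketch of what `argnums` controls, using the same `f(x, y) = x * y`. With `argnums=0`, `jax.grad` returns only $\partial f / \partial x$. With `argnums=(0, 1)` it returns a tuple containing both partials, which is the "whole gradient" case the question wants to avoid.

```python
import jax

def f(x, y):
    return x * y

# argnums=0: differentiate w.r.t. x only; d(x*y)/dx = y
df_dx = jax.grad(f, argnums=0)(1.0, 2.0)

# argnums=(0, 1): a tuple of both partials, (df/dx, df/dy) = (y, x)
dx, dy = jax.grad(f, argnums=(0, 1))(1.0, 2.0)

print(df_dx)   # 2.0
print(dx, dy)  # 2.0 1.0
```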
-
Consider a function with two inputs, $f(x, y): \mathbb{R}^2 \to \mathbb{R}$. I want to calculate the partial derivative with respect to $x$, i.e., $\frac{\partial f}{\partial x}$, at a certain point $(x, y)$.

Using `grad(f)` and selecting the first element can definitely do the job, but is there a more efficient way? `grad(f)` calculates the whole gradient vector, but I only want one value of it; I don't need the result of $\frac{\partial f}{\partial y}$.
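This assumes `f` takes a single vector argument, as "select the first element" suggests. In that case, one alternative to computing the full gradient and indexing it is forward mode. `jax.jvp` pushes the unit tangent $e_x = (1, 0)$ through `f` and returns the directional derivative $\nabla f \cdot e_x = \frac{\partial f}{\partial x}$, which is one column of the Jacobian per pass. The function `sin(x) * y` and the point `(1, 2)` are purely illustrative. This is a sketch, not a benchmark: for a function of only two inputs, it is unlikely to be measurably faster than `jax.grad`.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function R^2 -> R taking one vector argument v = (x, y).
def f(v):
    x, y = v[0], v[1]
    return jnp.sin(x) * y

point = jnp.array([1.0, 2.0])

# Reverse mode: compute the full gradient, then keep only the x-component.
df_dx_grad = jax.grad(f)(point)[0]

# Forward mode: push the unit tangent e_x = (1, 0) through f.
# jax.jvp returns (f(point), grad f(point) . e_x), i.e. (f value, df/dx).
e_x = jnp.array([1.0, 0.0])
f_val, df_dx_jvp = jax.jvp(f, (point,), (e_x,))

print(df_dx_grad, df_dx_jvp)  # both approximately 2 * cos(1.0) = 1.0806
```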