Is this gradient plot OK? #8146
Unanswered
jecampagne asked this question in Q&A
Replies: 1 comment 6 replies
-
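I don't think your function is returning the gradients that you think it is. Here is a version that computes them with `jax.grad`, vectorized over the grid with two nested `jax.vmap` calls:

```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

f = lambda x, y: -(x ** 2 + y ** 2)

# 20x20 grid on [-5, 5]^2 (a complex step such as 20j means "20 points", as in NumPy)
x, y = jnp.mgrid[-5:5:20j, -5:5:20j]
z = f(x, y)

# jax.grad differentiates a scalar function; the nested vmaps map it over both grid axes
grad_x = jax.vmap(jax.vmap(jax.grad(f, argnums=0)))(x, y)  # df/dx = -2x
grad_y = jax.vmap(jax.vmap(jax.grad(f, argnums=1)))(x, y)  # df/dy = -2y

fig, ax = plt.subplots(figsize=(8, 6))
ax.contourf(x, y, z, levels=100)
ax.quiver(x, y, grad_x, grad_y)
# Equal aspect ratio, so that right angles in data space are drawn as right angles
ax.set_aspect("equal")
plt.show()
```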
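As a sanity check, the computed gradients can be compared against the analytic gradient of f(x, y) = -(x² + y²), which is (-2x, -2y). The sketch below also assumes one might prefer a single call: passing `argnums=(0, 1)` to `jax.grad` returns both partial derivatives as a tuple, and `jax.vmap` handles the tuple output directly.

```python
import jax
import jax.numpy as jnp

f = lambda x, y: -(x ** 2 + y ** 2)
x, y = jnp.mgrid[-5:5:20j, -5:5:20j]

# argnums=(0, 1) returns (df/dx, df/dy) as a tuple, so one nested vmap
# over the 2-D grid yields both gradient components at once.
grad_f = jax.vmap(jax.vmap(jax.grad(f, argnums=(0, 1))))
grad_x, grad_y = grad_f(x, y)

# The analytic gradient (-2x, -2y) points toward the origin, i.e. radially,
# which is perpendicular to the circular contour lines of f.
print(bool(jnp.allclose(grad_x, -2 * x)), bool(jnp.allclose(grad_y, -2 * y)))  # True True
```

If these checks pass but the arrows still do not look perpendicular to the contours, the plot rather than the gradient is the likely suspect (for example, an unequal axis aspect ratio).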
-
Hi, here is a snippet that uses `jax.vmap` and `jax.grad` to compute a gradient flow and overlay it on a contour plot.
The result looks strange to me, as I would have expected the arrows to be orthogonal to the contour levels. Am I wrong somewhere?
Thanks
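The expectation is correct: the gradient of f is always perpendicular to its level curves. For f(x, y) = -(x² + y²) the level curves are circles centred at the origin, whose tangent direction at (x, y) is (-y, x). The sketch below checks numerically that the gradient's dot product with that tangent is (up to float32 rounding) zero on the whole grid; the tangent formula is specific to this particular f.

```python
import jax
import jax.numpy as jnp

f = lambda x, y: -(x ** 2 + y ** 2)
x, y = jnp.mgrid[-5:5:20j, -5:5:20j]

# Both partial derivatives at every grid point
grad_x, grad_y = jax.vmap(jax.vmap(jax.grad(f, argnums=(0, 1))))(x, y)

# Tangent direction of the circular contour through (x, y); valid only for this f
tangent_x, tangent_y = -y, x

# gradient . tangent should vanish everywhere (up to floating-point rounding)
dot = grad_x * tangent_x + grad_y * tangent_y
print(jnp.max(jnp.abs(dot)))
```

Note that orthogonality in data coordinates only shows up on screen when both axes use the same scale. Matplotlib's `quiver` also defaults to `angles="uv"`, which draws arrow directions in screen space, so an axes box that is not square (such as `figsize=(8, 6)` without `ax.set_aspect("equal")`) can make correct gradients look skewed relative to the contours.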