Efficient gradient computation with masks #18490

Unanswered

mohamad-amin asked this question in Q&A

Hey,

Consider I have the following script, which computes the loss and its gradient of an arbitrary model only on the portion of the data identified by the `mask` matrix, whose entries are 1 or 0. In this case, the entries `(i, j)` where `mask[i, j] == 0` don't contribute anything to the gradient. I'm wondering if there's an efficient way of avoiding the gradient computation for these entries. For instance, if `mask` is highly sparse (`mask.sum() << np.prod(mask.shape)`), the gradient computation could be much more efficient if we ignored the zero gradients.

Any help or feedback would be appreciated. Thanks!
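(The script referenced above isn't included in this extract. As a stand-in, here is a minimal hypothetical sketch of the kind of setup described; `model`, the masked-sum loss, and the random data are placeholder assumptions, not the original code.)

```python
# Hypothetical stand-in for the script described above (not the original code).
import jax
from jax import numpy as np

def model(x):
    return x ** 2  # placeholder for an arbitrary model

def loss(x, mask):
    # Entries where mask[i, j] == 0 contribute nothing to the loss,
    # and therefore receive zero gradient.
    return np.sum(mask * model(x))

key = jax.random.key(0)
k1, k2 = jax.random.split(key)
x = jax.random.normal(k1, (8, 8))
# Highly sparse 0/1 mask: mask.sum() << np.prod(mask.shape)
mask = (jax.random.uniform(k2, x.shape) < 0.1).astype(x.dtype)

value, grad = jax.value_and_grad(loss)(x, mask)  # grad is 0 wherever mask == 0
```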
Replies: 1 comment

Here's an example where I attempt to do that. I don't know how well this will end up being optimized, but it does what you want by stopping gradient flow internally:

```python
import jax
from jax import numpy as np
from jaxtyping import Array, Scalar

key = jax.random.key(0)
x = np.linspace(1, 3, 5)
(indices,) = np.indices(x.shape)
# Indices whose gradient should be stopped (here: 4 of the 5 entries).
mask = jax.random.choice(key, indices, shape=(4,), replace=False)

def model(x: Array) -> Array:
    return x**2

def cost(x: Array) -> Scalar:
    return np.sum(x**2)

def f(x: Array, mask: Array | None = None) -> Scalar:
    if mask is not None:
        # Re-insert the masked entries through stop_gradient so that
        # no cotangent flows back to them.
        x = x.at[mask].set(jax.lax.stop_gradient(x[mask]))
    return cost(model(x))

print(jax.value_and_grad(f)(x))
print(jax.value_and_grad(f)(x, mask))
```

Output:

```
(Array(142.125, dtype=float32), Array([  4. ,  13.5,  32. ,  62.5, 108. ], dtype=float32))
(Array(142.125, dtype=float32), Array([ 0.,  0., 32.,  0.,  0.], dtype=float32))
```
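A small caveat, added here as my own note rather than something from the thread: `mask` in this reply is an array of indices to *freeze*, whereas the question describes a 0/1 matrix in which zeros mark the entries that shouldn't contribute. A hypothetical bridge between the two conventions:

```python
# Hypothetical bridge between the two conventions (names are mine, not from
# the thread): given a 0/1 mask where 0 means "no gradient", freeze exactly
# the zero entries.
mask01 = np.array([0, 0, 1, 0, 0])       # question's convention: 1 = keep gradient
frozen = np.flatnonzero(mask01 == 0)     # reply's convention: indices to stop
print(jax.value_and_grad(f)(x, frozen))  # nonzero gradient only where mask01 == 1
```

Also note that `stop_gradient` only zeroes the cotangents flowing back through the masked entries; it doesn't by itself skip any work, so whether the backward pass actually gets cheaper depends on how much of the dead computation XLA can prune away, which is the optimization question the reply leaves open.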