Is there a JAX function like torch.Tensor.masked_fill_?
#9363
Answered by tomhennigan

ayaka14732 asked this question in Q&A:
```python
>>> import numpy as np
>>> import torch
>>> a = torch.rand((3, 2, 5))
>>> a
tensor([[[0.6125, 0.2513, 0.4099, 0.8357, 0.6188],
         [0.0068, 0.5963, 0.3633, 0.4969, 0.2056]],

        [[0.7309, 0.3909, 0.8121, 0.9801, 0.0426],
         [0.4047, 0.1876, 0.5262, 0.4997, 0.4828]],

        [[0.5044, 0.8133, 0.0537, 0.4708, 0.0057],
         [0.6625, 0.6482, 0.2043, 0.5460, 0.2341]]])
>>> # The (2, 5) mask broadcasts over the leading dimension of `a`.
>>> mask = torch.tensor([[True, True, True, False, False],
...                      [True, True, False, False, False]])
>>> a.masked_fill_(mask, -np.inf)  # np.NINF was removed in NumPy 2.0
tensor([[[  -inf,   -inf,   -inf, 0.8357, 0.6188],
         [  -inf,   -inf, 0.3633, 0.4969, 0.2056]],

        [[  -inf,   -inf,   -inf, 0.9801, 0.0426],
         [  -inf,   -inf, 0.5262, 0.4997, 0.4828]],

        [[  -inf,   -inf,   -inf, 0.4708, 0.0057],
         [  -inf,   -inf, 0.2043, 0.5460, 0.2341]]])
```
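For reference, a minimal sketch of the same operation in JAX using `jnp.where` (a suggestion, not something taken from this thread). JAX arrays are immutable, so this mirrors the out-of-place `Tensor.masked_fill` rather than the in-place `masked_fill_`, and `jnp.where` broadcasts the (2, 5) mask across the (3, 2, 5) array the same way PyTorch does. The helper name `masked_fill_where` and its `(a, mask, fill)` argument order are hypothetical, chosen to echo PyTorch's `a.masked_fill(mask, value)`.

```python
import jax
import jax.numpy as jnp


def masked_fill_where(a, mask, fill):
    """Hypothetical helper: return a copy of `a` with `fill` wherever `mask` is True.

    `mask` only needs to be broadcastable to `a.shape`, as in PyTorch.
    The fill value is cast to `a.dtype` so the result keeps a's dtype.
    """
    return jnp.where(mask, jnp.asarray(fill, dtype=a.dtype), a)


# Same shapes as the PyTorch example: a is (3, 2, 5), mask is (2, 5).
a = jax.random.uniform(jax.random.PRNGKey(0), (3, 2, 5))
mask = jnp.array([[True, True, True, False, False],
                  [True, True, False, False, False]])
out = masked_fill_where(a, mask, -jnp.inf)
```

Because the function is pure, it can be wrapped in `jax.jit` unchanged.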
tomhennigan answered on Jan 28, 2022 (1 comment, 3 replies):
I'm not aware of a built-in, but you can do this with broadcast + select:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> def masked_fill(mask, a, fill):
...     # NB: lax.select(pred, on_true, on_false) keeps `a` where `mask` is True
...     # and writes `fill` where it is False, the inverse of torch's masked_fill_.
...     # Pass `~mask` to get PyTorch's behaviour.
...     return jax.lax.select(mask, a, jax.lax.broadcast(fill, a.shape))
...
>>> key = jax.random.PRNGKey(42)
>>> key1, key2 = jax.random.split(key)
>>> a = jax.random.normal(key1, (4, 4))
>>> mask = jax.random.bernoulli(key2, shape=a.shape)
>>> masked_fill(mask, a, jnp.inf)
DeviceArray([[        inf,         inf, -1.9382184 , -0.9676806 ],
             [-0.3920406 ,  0.6062071 ,  0.37990323,         inf],
             [ 1.3282976 ,  1.1882836 ,         inf,  1.2611461 ],
             [        inf,         inf,  0.09786294,         inf]], dtype=float32)
```
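A sketch of the same broadcast + select idea using PyTorch's convention (write `fill` where the mask is True), assuming you also want the mask to broadcast as it does in PyTorch. `lax.select` does not broadcast its predicate and needs both operands to share a dtype, so the mask is expanded with `jnp.broadcast_to` and the fill value is cast to `a.dtype` before broadcasting. The name `masked_fill_select` is hypothetical; it keeps the answer's `(mask, a, fill)` argument order.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def masked_fill_select(mask, a, fill):
    """Hypothetical helper: write `fill` where `mask` is True, keep `a` elsewhere."""
    # lax.select needs the predicate to match the operands' shape exactly.
    mask = jnp.broadcast_to(mask, a.shape)
    # Cast the fill value to a's dtype, then broadcast the scalar to a's shape.
    fill = lax.broadcast(jnp.asarray(fill, dtype=a.dtype), a.shape)
    # select(pred, on_true, on_false): fill where mask is True, `a` otherwise.
    return lax.select(mask, fill, a)


key1, key2 = jax.random.split(jax.random.PRNGKey(42))
a = jax.random.normal(key1, (4, 4))
mask = jax.random.bernoulli(key2, shape=a.shape)
out = masked_fill_select(mask, a, jnp.inf)

# A (4,) mask is broadcast across the rows of `a`, as PyTorch would do.
row_out = masked_fill_select(mask[0], a, 0.0)
```

With this convention, the answer's `masked_fill(~mask, a, fill)` and `masked_fill_select(mask, a, fill)` produce the same array.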
Answer selected by ayaka14732