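I'm not aware of a built-in, but you can do this with `jax.lax.broadcast` + `jax.lax.select`. Note that `select(pred, on_true, on_false)` takes elements from `on_true` where `pred` is True, so this version keeps `a` where `mask` is True and writes `fill` everywhere else: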

```python
>>> import jax
>>> import jax.numpy as jnp
>>> def masked_fill(mask, a, fill):
...   # Keep `a` where `mask` is True; use the broadcast `fill` value elsewhere.
...   return jax.lax.select(mask, a, jax.lax.broadcast(fill, a.shape))
...
>>> key = jax.random.PRNGKey(42)
>>> key1, key2 = jax.random.split(key)
>>> a = jax.random.normal(key1, (4, 4))
>>> mask = jax.random.bernoulli(key2, shape=a.shape)
>>> masked_fill(mask, a, jnp.inf)
DeviceArray([[        inf,         inf, -1.9382184 , -0.9676806 ],
             [-0.3920406 ,  0.6062071 ,  0.37990323,         inf],
             [ 1.3282976 ,  1.1882836 ,         inf,  1.2611461 ],
             [        inf,         inf,  0.09786294,         inf]], dtype=float32)
```
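If you want PyTorch-style semantics instead (`Tensor.masked_fill` writes the fill value where the mask is True), a minimal sketch is to swap the two branches of `select`. `jnp.where(mask, fill, a)` should give the same result in one call, since it broadcasts the scalar for you. The helper names below are hypothetical, and casting `fill` to `a.dtype` is an added assumption so both branches of `select` have matching dtypes.

```python
import jax
import jax.numpy as jnp


def masked_fill_where_true(mask, a, fill):
    """Hypothetical helper: write `fill` where `mask` is True, keep `a` elsewhere.

    Same broadcast + select approach as above, with the two branches swapped.
    `fill` is cast to `a.dtype` (an assumption) so both select branches match.
    """
    filled = jax.lax.broadcast(jnp.asarray(fill, dtype=a.dtype), a.shape)
    return jax.lax.select(mask, filled, a)


def masked_fill_where(mask, a, fill):
    """Hypothetical helper: same result via jnp.where, which broadcasts the scalar."""
    return jnp.where(mask, jnp.asarray(fill, dtype=a.dtype), a)


a = jnp.arange(4.0).reshape(2, 2)  # [[0., 1.], [2., 3.]]
mask = jnp.array([[True, False], [False, True]])

print(masked_fill_where_true(mask, a, jnp.inf))  # inf on the diagonal, a elsewhere
print(masked_fill_where(mask, a, jnp.inf))       # same values
```

The `jnp.where` form is probably the simpler choice in everyday code; the `lax` form just makes the broadcast explicit.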

Answer selected by ayaka14732