

Hi, thanks for the question! When you mention Dead Code Elimination, are you talking about eliminating computations over masked sections of the array? I'm not sure whether this is possible.

To make this concrete, I think what you're essentially asking is whether you can write a program like this:

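```python
import jax
import jax.numpy as jnp

# Split one PRNG key into two independent keys
key1, key2 = jax.random.split(jax.random.PRNGKey(1701))

x = jax.random.uniform(key1, (10,))
# Draw the mask with key2 so it is independent of x
mask = jax.random.randint(key2, (10,), 0, 2).astype(bool)

def f(x, mask):
  # sin(x) where mask is True, 0 elsewhere
  return jnp.where(mask, jnp.sin(x), 0)

f(x, mask)
```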

and, when running it, rely on XLA not computing jnp.sin for the entries that will eventually be zeroed out. Is that correct?
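One way to check this empirically is to look at what JAX traces and what XLA compiles. The sketch below uses `jax.make_jaxpr` to print the traced program and the ahead-of-time API (`jax.jit(f).lower(...).compile().as_text()`) to print the optimized HLO. In the jaxpr, `sin` is applied to the whole `f32[10]` input and the mask is only consulted afterwards in the select; whether XLA then skips any of that work is something to read off the compiled HLO for your own backend and JAX version, not something this sketch asserts.

```python
import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.PRNGKey(1701))
x = jax.random.uniform(key1, (10,))
mask = jax.random.randint(key2, (10,), 0, 2).astype(bool)

def f(x, mask):
  return jnp.where(mask, jnp.sin(x), 0)

# Traced program: `sin` operates on the full f32[10] array, and the mask
# is only used afterwards to select between sin(x) and 0.
jaxpr_text = str(jax.make_jaxpr(f)(x, mask))
print(jaxpr_text)

# Optimized HLO for the current backend: if the `sine` op still covers
# all 10 elements, no mask-dependent elimination has taken place.
hlo_text = jax.jit(f).lower(x, mask).compile().as_text()
print(hlo_text)

out = f(x, mask)
```

Note that the mask is a runtime value, so the compiler cannot know at compile time which entries will be discarded; the HLO is the place to confirm what actually gets executed.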

I took a…


@pharringtonp19
Answer selected by pharringtonp19
Category
Q&A