-
Hi, I'm trying to compute the gradient of a function that runs on data containing NaN values and uses `jnp.nanstd`. I ran the following experiment:

```python
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(42)
x = jax.random.normal(rng, (10,))
_, rng = jax.random.split(rng)
params = {"w": jax.random.normal(rng, (10,))}
# put some NaN values (jax.ops.index_update is deprecated; use .at[].set())
x = x.at[0].set(jnp.nan)
```

Then run:

```python
def test(params, x):
    y = params["w"] * x
    score = jnp.nanstd(y)
    return score

jax.value_and_grad(test)(params, x)
```

This gives:
However:

```python
def test(params, x):
    y = params["w"] * x
    mean = jnp.nanmean(y)
    mean2 = jnp.nanmean(y ** 2)
    score = jnp.sqrt(mean2 - mean ** 2)
    return score

jax.value_and_grad(test)(params, x)
```

gives:
Finally, the third way to do it:

```python
def test(params, x):
    y = params["w"] * x
    mean = jnp.nanmean(y)
    var = jnp.nanmean((y - mean) ** 2)
    score = jnp.sqrt(var)
    return score

jax.value_and_grad(test)(params, x)
```

gives:
Is this a bug in `jnp.nanstd`? Do you have any recommendations for such use cases?
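For readers with a similar use case: one common pattern is to mask the NaNs out with `jnp.where` *before* they enter the computation, so no NaN ever reaches the backward pass. A minimal sketch (the `masked_std` helper is illustrative, not part of the original post, and assumes at least two valid entries):

```python
import jax
import jax.numpy as jnp

def masked_std(w, x):
    # Replace NaN inputs by 0 up front; the mask keeps track of which
    # entries are valid so that sums and counts ignore the filled-in zeros.
    mask = ~jnp.isnan(x)
    x_safe = jnp.where(mask, x, 0.0)
    y = w * x_safe                     # y is NaN-free, so grads stay finite
    n = jnp.sum(mask)
    mean = jnp.sum(jnp.where(mask, y, 0.0)) / n
    var = jnp.sum(jnp.where(mask, (y - mean) ** 2, 0.0)) / n
    return jnp.sqrt(var)
```

On the valid entries this matches `jnp.nanstd(w * x)`, but the gradient with respect to `w` contains no NaNs because the NaN inputs never enter the differentiated graph.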
Answered by jakevdp, Oct 7, 2021
-
Thanks for the report – I think this is a bug, so I'm going to convert this discussion to an issue: see #8128.
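For context, the usual mechanism behind NaNs leaking into gradients (independent of `nanstd`'s exact implementation) is that the chain rule multiplies cotangents, and `0 * nan` is still `nan`. A minimal sketch of the leak:

```python
import jax
import jax.numpy as jnp

x = jnp.array([jnp.nan, 1.0])

def f(w):
    # Masking happens *after* the multiply, so w[0] * nan = nan is already
    # in the graph. The where() contributes a zero cotangent for the masked
    # entry, but 0 * nan = nan still propagates to the gradient of w[0].
    return jnp.where(jnp.isnan(x), 0.0, w * x).sum()

g = jax.grad(f)(jnp.ones(2))  # g[0] is nan, g[1] is 1.0
```

Masking the inputs before the multiply (rather than the products after it) avoids this, which is why hand-rolled `where`-based reductions can behave better than post-hoc NaN handling.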
Answer selected by jakevdp
-
Thanks for taking a look!