I think I understand what is happening in your code, and if I'm not mistaken, gradients aren't really well-defined for this setup. I tried distilling the pattern in your code down to:

import jax
import jax.random as jrnd

def mh(key, w, x, y):
  u = jrnd.uniform(key)
  return jax.lax.cond(u < jax.nn.sigmoid(w), lambda _: x, lambda _: y, None)

jax.grad(mh, argnums=1)(jrnd.PRNGKey(4), 0., 1., 2.)  # always evaluates to 0.0

In this example, the second argument w determines the probability that u is compared against. Unfortunately, the outputs x and y are disconnected from w: regardless of which branch the cond takes, the returned value is independent of w. The gradient with respect to w will therefore always be 0.
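For contrast, here's a small sketch (the variant and the name mh_connected are my own illustration, not from the original snippet): if the branch outputs themselves depend on w, cond does carry a gradient through whichever branch is taken; only the discrete branch choice contributes nothing.

import jax
import jax.random as jrnd

def mh_connected(key, w, x, y):
  u = jrnd.uniform(key)
  p = jax.nn.sigmoid(w)
  # Each branch output now depends on w through p, so a gradient can flow
  # through the selected branch even though the choice itself is discrete.
  return jax.lax.cond(u < p, lambda _: p * x, lambda _: p * y, None)

jax.grad(mh_connected, argnums=1)(jrnd.PRNGKey(4), 0., 1., 2.)  # nonzero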
