I am working on some code that is supposed to implement MH sampling at one step. I've seen a few threads about MCMC sampling in the Issues, but nothing looked like it would solve my problem, which is that all my gradients are zero, even though I don't think they should be. I'm really stuck and could use another pair of eyes, particularly from people who think in JAX. I apologize for the length, but this is the smallest example I have that doesn't work the way I expect.

import jax
import jax.numpy as jnp
from jax import random as jrnd
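# on recent JAX releases this module lives at jax.example_libraries.optimizers
# (it was originally jax.experimental.optimizers)
import jax.example_libraries.optimizers as jopt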
# config values
N = 10 # input layer size
M = 20 # hidden layer size
S = 1000 # number of samples to draw
# initial random key
key = jrnd.PRNGKey(0)
# initialize the parameters to random
key, *subkeys = jrnd.split(key, 5)
sigma = jrnd.bernoulli(subkeys[0], shape=(N,))
a0 = jrnd.normal(subkeys[1], shape=(N,))
b0 = jrnd.normal(subkeys[2], shape=(M,))
W0 = jrnd.normal(subkeys[3], shape=((N, M)))
a0 = a0 / jnp.sum(a0**2)
b0 = b0 / jnp.sum(b0**2)
W0 = W0 / jnp.sum(W0**2)
weights0 = (a0, b0, W0)

@jax.jit
def logPsi(sigma, weights):
    a, b, W = weights
    s = 2 * sigma - 1
    F = 2. * jnp.cosh(b + jnp.dot(s, W))
    return jnp.dot(s, a) + jnp.sum(jnp.log(F))  # this is our log-probability function

# a static constant, we just need it to shuffle later
to_shuffle = jnp.concatenate((jnp.ones((1,), dtype=bool), jnp.zeros((N - 1,), dtype=bool)))

@jax.jit
def sample(key, sigma, weights):
    # metropolis sampling step: flip a value in sigma, then compute probabilities and
    # decide to accept or reject
    key, *subkeys = jrnd.split(key, 3)
    perm = jrnd.permutation(subkeys[0], to_shuffle)
    newsigma = perm ^ sigma
    old = logPsi(sigma, weights)
    new = logPsi(newsigma, weights)
    prob = jnp.exp(new - old)
    u = jrnd.uniform(subkeys[1])
    # the weights *should* come into play here, where we decide if a random
    # uniform value u is less than the acceptance probability. since those probabilities
    # are calculated via the weights, this shouldn't have zero gradient!
    return jax.lax.cond(u < prob, lambda op: op[1], lambda op: op[0], (sigma, newsigma))

@jax.jit
def energy(sigma):
    # fake energy function -- just for testing
    return jnp.sum(sigma, dtype=jnp.float32)

@jax.jit
def inner(carry, key):
    # body function for the scan below: just sample once and return
    sigma, weights = carry
    sigma = sample(key, sigma, weights)
    return (sigma, weights), energy(sigma)

@jax.jit
def objective(weights, key, sigma):
    # loss function: return the mean energy (a dummy quantity at this point)
    key, *subkeys = jrnd.split(key, S + 1)
    subkeys = jnp.array(subkeys)
    (sigma, _), Es = jax.lax.scan(inner, (sigma, weights), subkeys)
    return jnp.mean(Es), (key, sigma)

# below is boilerplate optimization code
valgrad = jax.value_and_grad(objective, has_aux=True)
lr = 1e-2 # learning rate
opt_init, opt_update, get_params = jopt.adam(lr)
opt_state = opt_init(weights0)

def step(step, opt_state, key, sigma):
    value_aux, grads = valgrad(get_params(opt_state), key, sigma)
    # if we print gradients here, all zeros
    opt_state = opt_update(step, grads, opt_state)
    return value_aux, opt_state

for i in range(10):
    (value, (key, sigma)), opt_state = step(i, opt_state, key, sigma)

Since we're minimizing energy and energy is just a sum of |
I think I understand what is happening in your code and, if I'm not mistaken, I don't think gradients are well defined for this setup. I tried distilling the pattern you have in your code down:

def mh(key, w, x, y):
    u = jrnd.uniform(key)
    return jax.lax.cond(u < jax.nn.sigmoid(w), lambda _: x, lambda _: y, None)

jax.grad(mh, argnums=1)(jrnd.PRNGKey(4), 0., 1., 2.)

In this example, the second argument `w` is used to compute the probability that `u` is compared against. Unfortunately, though, the outputs `x` and `y` are disconnected from `w`; that is, regardless of the branch chosen in the `cond`, we return a value that's independent of the value of `w`. This will always return a gradient of 0.