diff --git a/.travis.yml b/.travis.yml index e30a180bd..1aca57f52 100644 --- a/.travis.yml +++ b/.travis.yml @@ -17,7 +17,7 @@ install: # Keep track of pyro-api master branch - pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install jaxlib - - pip install jax + - pip install -e git+https://github.com/google/jax.git@2512ec6ebebf1b26d2aefbe618b4c147c251d194#egg=jax - pip install .[examples,test] - pip freeze diff --git a/examples/neutra.py b/examples/neutra.py index ccae859d6..fde3e4de1 100644 --- a/examples/neutra.py +++ b/examples/neutra.py @@ -36,9 +36,10 @@ from numpyro.infer.util import initialize_model, transformed_potential_energy -# XXX: upstream logsumexp throws NaN under fast-math mode + MCMC's progress_bar=True def logsumexp(x, axis=0): - return np.log(np.sum(np.exp(x), axis=axis)) + # TODO: remove when https://github.com/google/jax/pull/2260 merged upstream + x_max = lax.stop_gradient(np.max(x, axis=axis, keepdims=True)) + return np.log(np.sum(np.exp(x - x_max), axis=axis)) + x_max.squeeze(axis=axis) class DualMoonDistribution(dist.Distribution): diff --git a/examples/ode.py b/examples/ode.py index 339dd1505..6695f5b35 100644 --- a/examples/ode.py +++ b/examples/ode.py @@ -21,7 +21,7 @@ import matplotlib import matplotlib.pyplot as plt -from jax.experimental.ode import build_odeint +from jax.experimental.ode import odeint import jax.numpy as np from jax.random import PRNGKey @@ -33,21 +33,19 @@ matplotlib.use('Agg') # noqa: E402 -def dz_dt(z, t, alpha, beta, gamma, delta): +def dz_dt(z, t, theta): """ Lotka–Volterra equations. Real positive parameters `alpha`, `beta`, `gamma`, `delta` describes the interaction of two species. """ u = z[0] v = z[1] + alpha, beta, gamma, delta = theta[..., 0], theta[..., 1], theta[..., 2], theta[..., 3] du_dt = (alpha - beta * v) * u dv_dt = (-gamma + delta * u) * v return np.stack([du_dt, dv_dt]) -predator_prey_int = build_odeint(dz_dt, rtol=1e-5, atol=1e-3, mxstep=500) - - def model(N, y=None): """ :param int N: number of measurement times @@ -63,7 +61,7 @@ def model(N, y=None): dist.TruncatedNormal(low=0., loc=np.array([0.5, 0.05, 1.5, 0.05]), scale=np.array([0.5, 0.05, 0.5, 0.05]))) # integrate dz/dt, the result will have shape N x 2 - z = predator_prey_int(z_init, ts, *theta) + z = odeint(dz_dt, z_init, ts, theta, rtol=1e-5, atol=1e-3, mxstep=500) # measurement errors, we expect that measured hare has larger error than measured lynx sigma = numpyro.sample("sigma", dist.Exponential(np.array([1, 2]))) # measured populations (in log scale) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 6c5dabf7c..5c0c45a80 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -37,7 +37,6 @@ from numpyro.distributions.transforms import AffineTransform, ExpTransform, InvCholeskyTransform, PowerTransform from numpyro.distributions.util import ( cholesky_of_inverse, - cumsum, lazy_property, matrix_to_tril_vec, promote_shapes, @@ -227,7 +226,7 @@ def __init__(self, scale=1., num_steps=1, validate_args=None): def sample(self, key, sample_shape=()): shape = sample_shape + self.batch_shape + self.event_shape walks = random.normal(key, shape=shape) - return cumsum(walks) * np.expand_dims(self.scale, axis=-1) + return np.cumsum(walks, axis=-1) * np.expand_dims(self.scale, axis=-1) @validate_sample def log_prob(self, value): diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 40090ca88..5f9d313b4 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -13,8 +13,6 @@ from numpyro.distributions import constraints from numpyro.distributions.util import ( - cumprod, - cumsum, get_dtype, matrix_to_tril_vec, signed_stick_breaking_tril, @@ -207,7 +205,7 @@ def __call__(self, x): def inv(self, y): # inverse stick-breaking - z1m_cumprod = 1 - cumsum(y * y) + z1m_cumprod = 1 - np.cumsum(y * y, axis=-1) pad_width = [(0, 0)] * y.ndim pad_width[-1] = (1, 0) z1m_cumprod_shifted = np.pad(z1m_cumprod[..., :-1], pad_width, @@ -224,7 +222,7 @@ def log_abs_det_jacobian(self, x, y, intermediates=None): # flatten lower triangular part of `y`. # stick_breaking_logdet = log(y / r) = log(z_cumprod) (modulo right shifted) - z1m_cumprod = 1 - cumsum(y * y) + z1m_cumprod = 1 - np.cumsum(y * y, axis=-1) # by taking diagonal=-2, we don't need to shift z_cumprod to the right # NB: diagonal=-2 works fine for (2 x 2) matrix, where we get an empty array z1m_cumprod_tril = matrix_to_tril_vec(z1m_cumprod, diagonal=-2) @@ -384,7 +382,7 @@ class OrderedTransform(Transform): def __call__(self, x): z = np.concatenate([x[..., :1], np.exp(x[..., 1:])], axis=-1) - return cumsum(z) + return np.cumsum(z, axis=-1) def inv(self, y): x = np.log(y[..., 1:] - y[..., :-1]) @@ -457,7 +455,7 @@ def __call__(self, x): x = x - np.log(x.shape[-1] - np.arange(x.shape[-1])) # convert to probabilities (relative to the remaining) of each fraction of the stick z = _clipped_expit(x) - z1m_cumprod = cumprod(1 - z) + z1m_cumprod = np.cumprod(1 - z, axis=-1) pad_width = [(0, 0)] * x.ndim pad_width[-1] = (0, 1) z_padded = np.pad(z, pad_width, mode="constant", constant_values=1.) @@ -468,7 +466,7 @@ def __call__(self, x): def inv(self, y): y_crop = y[..., :-1] - z1m_cumprod = np.clip(1 - cumsum(y_crop), a_min=np.finfo(y.dtype).tiny) + z1m_cumprod = np.clip(1 - np.cumsum(y_crop, axis=-1), a_min=np.finfo(y.dtype).tiny) # hence x = logit(z) = log(z / (1 - z)) = y[::-1] / z1m_cumprod x = np.log(y_crop / z1m_cumprod) return x + np.log(x.shape[-1] - np.arange(x.shape[-1])) diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index d5e5d2669..54f47ce1b 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -4,7 +4,7 @@ from functools import update_wrapper import math -from jax import custom_transforms, defjvp, jit, lax, random, vmap +from jax import jit, lax, random, vmap from jax.dtypes import canonicalize_dtype from jax.lib import xla_bridge import jax.numpy as np @@ -167,7 +167,7 @@ def _categorical(key, p, shape): # this implementation is fast when event shape is small, and slow otherwise # Ref: https://stackoverflow.com/a/34190035 shape = shape or p.shape[:-1] - s = cumsum(p) + s = np.cumsum(p, axis=-1) r = random.uniform(key, shape=shape + (1,)) # FIXME: replace this computation by using binary search as suggested in the above # reference. A while_loop + vmap for a reshaped 2D array would be enough. @@ -312,24 +312,6 @@ def binary_cross_entropy_with_logits(x, y): return np.clip(x, 0) + np.log1p(np.exp(-np.abs(x))) - x * y -@custom_transforms -def cumsum(x): - return np.cumsum(x, axis=-1) - - -defjvp(cumsum, lambda g, ans, x: np.cumsum(g, axis=-1)) - - -@custom_transforms -def cumprod(x): - return np.cumprod(x, axis=-1) - - -# XXX this implementation does not address the case x=0, hence the result in that case will be nan -# Ref: https://stackoverflow.com/questions/40916955/how-to-compute-gradient-of-cumprod-safely -defjvp(cumprod, lambda g, ans, x: np.cumsum(g / x, axis=-1) * ans) - - def promote_shapes(*args, shape=()): # adapted from lax.lax_numpy if len(args) < 2 and not shape: @@ -418,7 +400,7 @@ def signed_stick_breaking_tril(t): # we omit the step of computing s = z * z_cumprod by using the fact: # y = sign(r) * s = sign(r) * sqrt(z * z_cumprod) = r * sqrt(z_cumprod) z = r ** 2 - z1m_cumprod = cumprod(1 - z) + z1m_cumprod = np.cumprod(1 - z, axis=-1) z1m_cumprod_sqrt = np.sqrt(z1m_cumprod) pad_width = [(0, 0)] * z.ndim diff --git a/setup.py b/setup.py index 11946cada..9e6a7181f 100644 --- a/setup.py +++ b/setup.py @@ -34,8 +34,8 @@ author_email='npradhan@uber.com', install_requires=[ # TODO: pin to a specific version for the release (until JAX's API becomes stable) - 'jax>=0.1.57', - 'jaxlib>=0.1.37', + 'jax @ git+https://github.com/google/jax.git@2512ec6ebebf1b26d2aefbe618b4c147c251d194#egg=jax', + 'jaxlib>=0.1.43', 'tqdm', ], extras_require={ diff --git a/test/test_distributions_util.py b/test/test_distributions_util.py index 984410634..f91c49ee7 100644 --- a/test/test_distributions_util.py +++ b/test/test_distributions_util.py @@ -7,7 +7,7 @@ from numpy.testing import assert_allclose import pytest -from jax import jacobian, lax, random, vmap +from jax import lax, random, vmap import jax.numpy as np from jax.scipy.special import expit, xlog1py, xlogy @@ -15,8 +15,6 @@ binary_cross_entropy_with_logits, categorical, cholesky_update, - cumprod, - cumsum, multinomial, poisson, vec_to_tril_matrix, @@ -33,36 +31,6 @@ def test_binary_cross_entropy_with_logits(x, y): assert_allclose(actual, expect, rtol=1e-6) -@pytest.mark.parametrize('shape', [ - (3,), - (5, 3), -]) -def test_cumsum_jac(shape): - rng_key = random.PRNGKey(0) - x = random.normal(rng_key, shape=shape) - - def test_fn(x): - return np.stack([x[..., 0], x[..., 0] + x[..., 1], x[..., 0] + x[..., 1] + x[..., 2]], -1) - - assert_allclose(cumsum(x), test_fn(x)) - assert_allclose(jacobian(cumsum)(x), jacobian(test_fn)(x)) - - -@pytest.mark.parametrize('shape', [ - (3,), - (5, 3), -]) -def test_cumprod_jac(shape): - rng_key = random.PRNGKey(0) - x = random.uniform(rng_key, shape=shape) - - def test_fn(x): - return np.stack([x[..., 0], x[..., 0] * x[..., 1], x[..., 0] * x[..., 1] * x[..., 2]], -1) - - assert_allclose(cumprod(x), test_fn(x)) - assert_allclose(jacobian(cumprod)(x), jacobian(test_fn)(x), atol=1e-7) - - @pytest.mark.parametrize('prim', [ xlogy, xlog1py, @@ -86,19 +54,6 @@ def test_binop_batch_rule(prim): assert_allclose(actual_bx_y[i], prim(bx[i], y)) -@pytest.mark.parametrize('prim', [ - cumsum, - cumprod, -]) -def test_unop_batch_rule(prim): - rng_key = random.PRNGKey(0) - bx = random.normal(rng_key, (3, 5)) - - actual = vmap(prim)(bx) - for i in range(3): - assert_allclose(actual[i], prim(bx[i])) - - @pytest.mark.parametrize('p, shape', [ (np.array([0.1, 0.9]), ()), (np.array([0.2, 0.8]), (2,)),