Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ install:
# Keep track of pyro-api master branch
- pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
- pip install jaxlib
- pip install jax
- pip install -e git+https://github.com/google/jax.git@2512ec6ebebf1b26d2aefbe618b4c147c251d194#egg=jax
- pip install .[examples,test]
- pip freeze

Expand Down
5 changes: 3 additions & 2 deletions examples/neutra.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@
from numpyro.infer.util import initialize_model, transformed_potential_energy


# XXX: upstream logsumexp throws NaN under fast-math mode + MCMC's progress_bar=True
def logsumexp(x, axis=0):
return np.log(np.sum(np.exp(x), axis=axis))
# TODO: remove when https://github.com/google/jax/pull/2260 merged upstream
x_max = lax.stop_gradient(np.max(x, axis=axis, keepdims=True))
return np.log(np.sum(np.exp(x - x_max), axis=axis)) + x_max.squeeze(axis=axis)


class DualMoonDistribution(dist.Distribution):
Expand Down
10 changes: 4 additions & 6 deletions examples/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import matplotlib
import matplotlib.pyplot as plt

from jax.experimental.ode import build_odeint
from jax.experimental.ode import odeint
import jax.numpy as np
from jax.random import PRNGKey

Expand All @@ -33,21 +33,19 @@
matplotlib.use('Agg') # noqa: E402


def dz_dt(z, t, alpha, beta, gamma, delta):
def dz_dt(z, t, theta):
"""
Lotka–Volterra equations. Real positive parameters `alpha`, `beta`, `gamma`, `delta`
describes the interaction of two species.
"""
u = z[0]
v = z[1]
alpha, beta, gamma, delta = theta[..., 0], theta[..., 1], theta[..., 2], theta[..., 3]
du_dt = (alpha - beta * v) * u
dv_dt = (-gamma + delta * u) * v
return np.stack([du_dt, dv_dt])


predator_prey_int = build_odeint(dz_dt, rtol=1e-5, atol=1e-3, mxstep=500)


def model(N, y=None):
"""
:param int N: number of measurement times
Expand All @@ -63,7 +61,7 @@ def model(N, y=None):
dist.TruncatedNormal(low=0., loc=np.array([0.5, 0.05, 1.5, 0.05]),
scale=np.array([0.5, 0.05, 0.5, 0.05])))
# integrate dz/dt, the result will have shape N x 2
z = predator_prey_int(z_init, ts, *theta)
z = odeint(dz_dt, z_init, ts, theta, rtol=1e-5, atol=1e-3, mxstep=500)
# measurement errors, we expect that measured hare has larger error than measured lynx
sigma = numpyro.sample("sigma", dist.Exponential(np.array([1, 2])))
# measured populations (in log scale)
Expand Down
3 changes: 1 addition & 2 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from numpyro.distributions.transforms import AffineTransform, ExpTransform, InvCholeskyTransform, PowerTransform
from numpyro.distributions.util import (
cholesky_of_inverse,
cumsum,
lazy_property,
matrix_to_tril_vec,
promote_shapes,
Expand Down Expand Up @@ -227,7 +226,7 @@ def __init__(self, scale=1., num_steps=1, validate_args=None):
def sample(self, key, sample_shape=()):
shape = sample_shape + self.batch_shape + self.event_shape
walks = random.normal(key, shape=shape)
return cumsum(walks) * np.expand_dims(self.scale, axis=-1)
return np.cumsum(walks, axis=-1) * np.expand_dims(self.scale, axis=-1)

@validate_sample
def log_prob(self, value):
Expand Down
12 changes: 5 additions & 7 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@

from numpyro.distributions import constraints
from numpyro.distributions.util import (
cumprod,
cumsum,
get_dtype,
matrix_to_tril_vec,
signed_stick_breaking_tril,
Expand Down Expand Up @@ -207,7 +205,7 @@ def __call__(self, x):

def inv(self, y):
# inverse stick-breaking
z1m_cumprod = 1 - cumsum(y * y)
z1m_cumprod = 1 - np.cumsum(y * y, axis=-1)
pad_width = [(0, 0)] * y.ndim
pad_width[-1] = (1, 0)
z1m_cumprod_shifted = np.pad(z1m_cumprod[..., :-1], pad_width,
Expand All @@ -224,7 +222,7 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
# flatten lower triangular part of `y`.

# stick_breaking_logdet = log(y / r) = log(z_cumprod) (modulo right shifted)
z1m_cumprod = 1 - cumsum(y * y)
z1m_cumprod = 1 - np.cumsum(y * y, axis=-1)
# by taking diagonal=-2, we don't need to shift z_cumprod to the right
# NB: diagonal=-2 works fine for (2 x 2) matrix, where we get an empty array
z1m_cumprod_tril = matrix_to_tril_vec(z1m_cumprod, diagonal=-2)
Expand Down Expand Up @@ -384,7 +382,7 @@ class OrderedTransform(Transform):

def __call__(self, x):
z = np.concatenate([x[..., :1], np.exp(x[..., 1:])], axis=-1)
return cumsum(z)
return np.cumsum(z, axis=-1)

def inv(self, y):
x = np.log(y[..., 1:] - y[..., :-1])
Expand Down Expand Up @@ -457,7 +455,7 @@ def __call__(self, x):
x = x - np.log(x.shape[-1] - np.arange(x.shape[-1]))
# convert to probabilities (relative to the remaining) of each fraction of the stick
z = _clipped_expit(x)
z1m_cumprod = cumprod(1 - z)
z1m_cumprod = np.cumprod(1 - z, axis=-1)
pad_width = [(0, 0)] * x.ndim
pad_width[-1] = (0, 1)
z_padded = np.pad(z, pad_width, mode="constant", constant_values=1.)
Expand All @@ -468,7 +466,7 @@ def __call__(self, x):

def inv(self, y):
y_crop = y[..., :-1]
z1m_cumprod = np.clip(1 - cumsum(y_crop), a_min=np.finfo(y.dtype).tiny)
z1m_cumprod = np.clip(1 - np.cumsum(y_crop, axis=-1), a_min=np.finfo(y.dtype).tiny)
# hence x = logit(z) = log(z / (1 - z)) = y[::-1] / z1m_cumprod
x = np.log(y_crop / z1m_cumprod)
return x + np.log(x.shape[-1] - np.arange(x.shape[-1]))
Expand Down
24 changes: 3 additions & 21 deletions numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import update_wrapper
import math

from jax import custom_transforms, defjvp, jit, lax, random, vmap
from jax import jit, lax, random, vmap
from jax.dtypes import canonicalize_dtype
from jax.lib import xla_bridge
import jax.numpy as np
Expand Down Expand Up @@ -167,7 +167,7 @@ def _categorical(key, p, shape):
# this implementation is fast when event shape is small, and slow otherwise
# Ref: https://stackoverflow.com/a/34190035
shape = shape or p.shape[:-1]
s = cumsum(p)
s = np.cumsum(p, axis=-1)
r = random.uniform(key, shape=shape + (1,))
# FIXME: replace this computation by using binary search as suggested in the above
# reference. A while_loop + vmap for a reshaped 2D array would be enough.
Expand Down Expand Up @@ -312,24 +312,6 @@ def binary_cross_entropy_with_logits(x, y):
return np.clip(x, 0) + np.log1p(np.exp(-np.abs(x))) - x * y


@custom_transforms
def cumsum(x):
return np.cumsum(x, axis=-1)


defjvp(cumsum, lambda g, ans, x: np.cumsum(g, axis=-1))


@custom_transforms
def cumprod(x):
return np.cumprod(x, axis=-1)


# XXX this implementation does not address the case x=0, hence the result in that case will be nan
# Ref: https://stackoverflow.com/questions/40916955/how-to-compute-gradient-of-cumprod-safely
defjvp(cumprod, lambda g, ans, x: np.cumsum(g / x, axis=-1) * ans)


def promote_shapes(*args, shape=()):
# adapted from lax.lax_numpy
if len(args) < 2 and not shape:
Expand Down Expand Up @@ -418,7 +400,7 @@ def signed_stick_breaking_tril(t):
# we omit the step of computing s = z * z_cumprod by using the fact:
# y = sign(r) * s = sign(r) * sqrt(z * z_cumprod) = r * sqrt(z_cumprod)
z = r ** 2
z1m_cumprod = cumprod(1 - z)
z1m_cumprod = np.cumprod(1 - z, axis=-1)
z1m_cumprod_sqrt = np.sqrt(z1m_cumprod)

pad_width = [(0, 0)] * z.ndim
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
author_email='[email protected]',
install_requires=[
# TODO: pin to a specific version for the release (until JAX's API becomes stable)
'jax>=0.1.57',
'jaxlib>=0.1.37',
'jax @ git+https://github.com/google/jax.git@2512ec6ebebf1b26d2aefbe618b4c147c251d194#egg=jax',
'jaxlib>=0.1.43',
'tqdm',
],
extras_require={
Expand Down
47 changes: 1 addition & 46 deletions test/test_distributions_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,14 @@
from numpy.testing import assert_allclose
import pytest

from jax import jacobian, lax, random, vmap
from jax import lax, random, vmap
import jax.numpy as np
from jax.scipy.special import expit, xlog1py, xlogy

from numpyro.distributions.util import (
binary_cross_entropy_with_logits,
categorical,
cholesky_update,
cumprod,
cumsum,
multinomial,
poisson,
vec_to_tril_matrix,
Expand All @@ -33,36 +31,6 @@ def test_binary_cross_entropy_with_logits(x, y):
assert_allclose(actual, expect, rtol=1e-6)


@pytest.mark.parametrize('shape', [
(3,),
(5, 3),
])
def test_cumsum_jac(shape):
rng_key = random.PRNGKey(0)
x = random.normal(rng_key, shape=shape)

def test_fn(x):
return np.stack([x[..., 0], x[..., 0] + x[..., 1], x[..., 0] + x[..., 1] + x[..., 2]], -1)

assert_allclose(cumsum(x), test_fn(x))
assert_allclose(jacobian(cumsum)(x), jacobian(test_fn)(x))


@pytest.mark.parametrize('shape', [
(3,),
(5, 3),
])
def test_cumprod_jac(shape):
rng_key = random.PRNGKey(0)
x = random.uniform(rng_key, shape=shape)

def test_fn(x):
return np.stack([x[..., 0], x[..., 0] * x[..., 1], x[..., 0] * x[..., 1] * x[..., 2]], -1)

assert_allclose(cumprod(x), test_fn(x))
assert_allclose(jacobian(cumprod)(x), jacobian(test_fn)(x), atol=1e-7)


@pytest.mark.parametrize('prim', [
xlogy,
xlog1py,
Expand All @@ -86,19 +54,6 @@ def test_binop_batch_rule(prim):
assert_allclose(actual_bx_y[i], prim(bx[i], y))


@pytest.mark.parametrize('prim', [
cumsum,
cumprod,
])
def test_unop_batch_rule(prim):
rng_key = random.PRNGKey(0)
bx = random.normal(rng_key, (3, 5))

actual = vmap(prim)(bx)
for i in range(3):
assert_allclose(actual[i], prim(bx[i]))


@pytest.mark.parametrize('p, shape', [
(np.array([0.1, 0.9]), ()),
(np.array([0.2, 0.8]), (2,)),
Expand Down