Skip to content

Commit 9772c94

Browse files
authored
update jax to master 0.1.63 (#565)
* update jax to master 0.1.53 * remove custom cumsum, cumprod * use stop_gradient for logsumexp
1 parent 017ac03 commit 9772c94

File tree

8 files changed

+20
-87
lines changed

8 files changed

+20
-87
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ install:
1717
# Keep track of pyro-api master branch
1818
- pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
1919
- pip install jaxlib
20-
- pip install jax
20+
- pip install -e git+https://github.com/google/jax.git@2512ec6ebebf1b26d2aefbe618b4c147c251d194#egg=jax
2121
- pip install .[examples,test]
2222
- pip freeze
2323

examples/neutra.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@
3636
from numpyro.infer.util import initialize_model, transformed_potential_energy
3737

3838

39-
# XXX: upstream logsumexp throws NaN under fast-math mode + MCMC's progress_bar=True
4039
def logsumexp(x, axis=0):
41-
return np.log(np.sum(np.exp(x), axis=axis))
40+
# TODO: remove when https://github.com/google/jax/pull/2260 merged upstream
41+
x_max = lax.stop_gradient(np.max(x, axis=axis, keepdims=True))
42+
return np.log(np.sum(np.exp(x - x_max), axis=axis)) + x_max.squeeze(axis=axis)
4243

4344

4445
class DualMoonDistribution(dist.Distribution):

examples/ode.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import matplotlib
2222
import matplotlib.pyplot as plt
2323

24-
from jax.experimental.ode import build_odeint
24+
from jax.experimental.ode import odeint
2525
import jax.numpy as np
2626
from jax.random import PRNGKey
2727

@@ -33,21 +33,19 @@
3333
matplotlib.use('Agg') # noqa: E402
3434

3535

36-
def dz_dt(z, t, alpha, beta, gamma, delta):
36+
def dz_dt(z, t, theta):
3737
"""
3838
Lotka–Volterra equations. Real positive parameters `alpha`, `beta`, `gamma`, `delta`
3939
describes the interaction of two species.
4040
"""
4141
u = z[0]
4242
v = z[1]
43+
alpha, beta, gamma, delta = theta[..., 0], theta[..., 1], theta[..., 2], theta[..., 3]
4344
du_dt = (alpha - beta * v) * u
4445
dv_dt = (-gamma + delta * u) * v
4546
return np.stack([du_dt, dv_dt])
4647

4748

48-
predator_prey_int = build_odeint(dz_dt, rtol=1e-5, atol=1e-3, mxstep=500)
49-
50-
5149
def model(N, y=None):
5250
"""
5351
:param int N: number of measurement times
@@ -63,7 +61,7 @@ def model(N, y=None):
6361
dist.TruncatedNormal(low=0., loc=np.array([0.5, 0.05, 1.5, 0.05]),
6462
scale=np.array([0.5, 0.05, 0.5, 0.05])))
6563
# integrate dz/dt, the result will have shape N x 2
66-
z = predator_prey_int(z_init, ts, *theta)
64+
z = odeint(dz_dt, z_init, ts, theta, rtol=1e-5, atol=1e-3, mxstep=500)
6765
# measurement errors, we expect that measured hare has larger error than measured lynx
6866
sigma = numpyro.sample("sigma", dist.Exponential(np.array([1, 2])))
6967
# measured populations (in log scale)

numpyro/distributions/continuous.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
from numpyro.distributions.transforms import AffineTransform, ExpTransform, InvCholeskyTransform, PowerTransform
3838
from numpyro.distributions.util import (
3939
cholesky_of_inverse,
40-
cumsum,
4140
lazy_property,
4241
matrix_to_tril_vec,
4342
promote_shapes,
@@ -227,7 +226,7 @@ def __init__(self, scale=1., num_steps=1, validate_args=None):
227226
def sample(self, key, sample_shape=()):
228227
shape = sample_shape + self.batch_shape + self.event_shape
229228
walks = random.normal(key, shape=shape)
230-
return cumsum(walks) * np.expand_dims(self.scale, axis=-1)
229+
return np.cumsum(walks, axis=-1) * np.expand_dims(self.scale, axis=-1)
231230

232231
@validate_sample
233232
def log_prob(self, value):

numpyro/distributions/transforms.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313

1414
from numpyro.distributions import constraints
1515
from numpyro.distributions.util import (
16-
cumprod,
17-
cumsum,
1816
get_dtype,
1917
matrix_to_tril_vec,
2018
signed_stick_breaking_tril,
@@ -207,7 +205,7 @@ def __call__(self, x):
207205

208206
def inv(self, y):
209207
# inverse stick-breaking
210-
z1m_cumprod = 1 - cumsum(y * y)
208+
z1m_cumprod = 1 - np.cumsum(y * y, axis=-1)
211209
pad_width = [(0, 0)] * y.ndim
212210
pad_width[-1] = (1, 0)
213211
z1m_cumprod_shifted = np.pad(z1m_cumprod[..., :-1], pad_width,
@@ -224,7 +222,7 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
224222
# flatten lower triangular part of `y`.
225223

226224
# stick_breaking_logdet = log(y / r) = log(z_cumprod) (modulo right shifted)
227-
z1m_cumprod = 1 - cumsum(y * y)
225+
z1m_cumprod = 1 - np.cumsum(y * y, axis=-1)
228226
# by taking diagonal=-2, we don't need to shift z_cumprod to the right
229227
# NB: diagonal=-2 works fine for (2 x 2) matrix, where we get an empty array
230228
z1m_cumprod_tril = matrix_to_tril_vec(z1m_cumprod, diagonal=-2)
@@ -384,7 +382,7 @@ class OrderedTransform(Transform):
384382

385383
def __call__(self, x):
386384
z = np.concatenate([x[..., :1], np.exp(x[..., 1:])], axis=-1)
387-
return cumsum(z)
385+
return np.cumsum(z, axis=-1)
388386

389387
def inv(self, y):
390388
x = np.log(y[..., 1:] - y[..., :-1])
@@ -457,7 +455,7 @@ def __call__(self, x):
457455
x = x - np.log(x.shape[-1] - np.arange(x.shape[-1]))
458456
# convert to probabilities (relative to the remaining) of each fraction of the stick
459457
z = _clipped_expit(x)
460-
z1m_cumprod = cumprod(1 - z)
458+
z1m_cumprod = np.cumprod(1 - z, axis=-1)
461459
pad_width = [(0, 0)] * x.ndim
462460
pad_width[-1] = (0, 1)
463461
z_padded = np.pad(z, pad_width, mode="constant", constant_values=1.)
@@ -468,7 +466,7 @@ def __call__(self, x):
468466

469467
def inv(self, y):
470468
y_crop = y[..., :-1]
471-
z1m_cumprod = np.clip(1 - cumsum(y_crop), a_min=np.finfo(y.dtype).tiny)
469+
z1m_cumprod = np.clip(1 - np.cumsum(y_crop, axis=-1), a_min=np.finfo(y.dtype).tiny)
472470
# hence x = logit(z) = log(z / (1 - z)) = y[::-1] / z1m_cumprod
473471
x = np.log(y_crop / z1m_cumprod)
474472
return x + np.log(x.shape[-1] - np.arange(x.shape[-1]))

numpyro/distributions/util.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from functools import update_wrapper
55
import math
66

7-
from jax import custom_transforms, defjvp, jit, lax, random, vmap
7+
from jax import jit, lax, random, vmap
88
from jax.dtypes import canonicalize_dtype
99
from jax.lib import xla_bridge
1010
import jax.numpy as np
@@ -167,7 +167,7 @@ def _categorical(key, p, shape):
167167
# this implementation is fast when event shape is small, and slow otherwise
168168
# Ref: https://stackoverflow.com/a/34190035
169169
shape = shape or p.shape[:-1]
170-
s = cumsum(p)
170+
s = np.cumsum(p, axis=-1)
171171
r = random.uniform(key, shape=shape + (1,))
172172
# FIXME: replace this computation by using binary search as suggested in the above
173173
# reference. A while_loop + vmap for a reshaped 2D array would be enough.
@@ -312,24 +312,6 @@ def binary_cross_entropy_with_logits(x, y):
312312
return np.clip(x, 0) + np.log1p(np.exp(-np.abs(x))) - x * y
313313

314314

315-
@custom_transforms
316-
def cumsum(x):
317-
return np.cumsum(x, axis=-1)
318-
319-
320-
defjvp(cumsum, lambda g, ans, x: np.cumsum(g, axis=-1))
321-
322-
323-
@custom_transforms
324-
def cumprod(x):
325-
return np.cumprod(x, axis=-1)
326-
327-
328-
# XXX this implementation does not address the case x=0, hence the result in that case will be nan
329-
# Ref: https://stackoverflow.com/questions/40916955/how-to-compute-gradient-of-cumprod-safely
330-
defjvp(cumprod, lambda g, ans, x: np.cumsum(g / x, axis=-1) * ans)
331-
332-
333315
def promote_shapes(*args, shape=()):
334316
# adapted from lax.lax_numpy
335317
if len(args) < 2 and not shape:
@@ -418,7 +400,7 @@ def signed_stick_breaking_tril(t):
418400
# we omit the step of computing s = z * z_cumprod by using the fact:
419401
# y = sign(r) * s = sign(r) * sqrt(z * z_cumprod) = r * sqrt(z_cumprod)
420402
z = r ** 2
421-
z1m_cumprod = cumprod(1 - z)
403+
z1m_cumprod = np.cumprod(1 - z, axis=-1)
422404
z1m_cumprod_sqrt = np.sqrt(z1m_cumprod)
423405

424406
pad_width = [(0, 0)] * z.ndim

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434
author_email='[email protected]',
3535
install_requires=[
3636
# TODO: pin to a specific version for the release (until JAX's API becomes stable)
37-
'jax>=0.1.57',
38-
'jaxlib>=0.1.37',
37+
'jax @ git+https://github.com/google/jax.git@2512ec6ebebf1b26d2aefbe618b4c147c251d194#egg=jax',
38+
'jaxlib>=0.1.43',
3939
'tqdm',
4040
],
4141
extras_require={

test/test_distributions_util.py

Lines changed: 1 addition & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,14 @@
77
from numpy.testing import assert_allclose
88
import pytest
99

10-
from jax import jacobian, lax, random, vmap
10+
from jax import lax, random, vmap
1111
import jax.numpy as np
1212
from jax.scipy.special import expit, xlog1py, xlogy
1313

1414
from numpyro.distributions.util import (
1515
binary_cross_entropy_with_logits,
1616
categorical,
1717
cholesky_update,
18-
cumprod,
19-
cumsum,
2018
multinomial,
2119
poisson,
2220
vec_to_tril_matrix,
@@ -33,36 +31,6 @@ def test_binary_cross_entropy_with_logits(x, y):
3331
assert_allclose(actual, expect, rtol=1e-6)
3432

3533

36-
@pytest.mark.parametrize('shape', [
37-
(3,),
38-
(5, 3),
39-
])
40-
def test_cumsum_jac(shape):
41-
rng_key = random.PRNGKey(0)
42-
x = random.normal(rng_key, shape=shape)
43-
44-
def test_fn(x):
45-
return np.stack([x[..., 0], x[..., 0] + x[..., 1], x[..., 0] + x[..., 1] + x[..., 2]], -1)
46-
47-
assert_allclose(cumsum(x), test_fn(x))
48-
assert_allclose(jacobian(cumsum)(x), jacobian(test_fn)(x))
49-
50-
51-
@pytest.mark.parametrize('shape', [
52-
(3,),
53-
(5, 3),
54-
])
55-
def test_cumprod_jac(shape):
56-
rng_key = random.PRNGKey(0)
57-
x = random.uniform(rng_key, shape=shape)
58-
59-
def test_fn(x):
60-
return np.stack([x[..., 0], x[..., 0] * x[..., 1], x[..., 0] * x[..., 1] * x[..., 2]], -1)
61-
62-
assert_allclose(cumprod(x), test_fn(x))
63-
assert_allclose(jacobian(cumprod)(x), jacobian(test_fn)(x), atol=1e-7)
64-
65-
6634
@pytest.mark.parametrize('prim', [
6735
xlogy,
6836
xlog1py,
@@ -86,19 +54,6 @@ def test_binop_batch_rule(prim):
8654
assert_allclose(actual_bx_y[i], prim(bx[i], y))
8755

8856

89-
@pytest.mark.parametrize('prim', [
90-
cumsum,
91-
cumprod,
92-
])
93-
def test_unop_batch_rule(prim):
94-
rng_key = random.PRNGKey(0)
95-
bx = random.normal(rng_key, (3, 5))
96-
97-
actual = vmap(prim)(bx)
98-
for i in range(3):
99-
assert_allclose(actual[i], prim(bx[i]))
100-
101-
10257
@pytest.mark.parametrize('p, shape', [
10358
(np.array([0.1, 0.9]), ()),
10459
(np.array([0.2, 0.8]), (2,)),

0 commit comments

Comments
 (0)