Skip to content

Commit fe01f02

Browse files
authored
Fix jax.tree_util future warnings (#1464)
* Fix jax.tree_util future warnings * fix docker link * Filter out warnings in contrib modules * fix lint * bypass isort in cvae * Fix lint
1 parent 47f56c9 commit fe01f02

File tree

9 files changed

+21
-10
lines changed

9 files changed

+21
-10
lines changed

docker/release/Dockerfile

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@ FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04
77

88
# declare the image name
99
# note that this image uses Python 3.8
10-
ENV IMG_NAME=11.2.2-cudnn8-devel-ubuntu20.04 \
11-
JAXLIB_CUDA=111
10+
ENV IMG_NAME=11.2.2-cudnn8-devel-ubuntu20.04
1211

1312
# install python3 and pip on top of the base Ubuntu image
1413
RUN apt update && \
@@ -21,4 +20,4 @@ ENV PATH=/root/.local/bin:$PATH
2120
RUN pip3 install --user \
2221
# we pull wheels from google's api as per https://github.com/google/jax#installation
2322
# the pre-compiled wheels that google provides work for now. This may change in the future (and necessitate building from source)
24-
numpyro[cuda${JAXLIB_CUDA}] -f https://storage.googleapis.com/jax-releases/jax_releases.html
23+
numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

examples/cvae-flax/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55

66
from data import load_dataset
77
import matplotlib.pyplot as plt
8-
from models import BaselineNet, Decoder, Encoder, cvae_guide, cvae_model
98
from train_baseline import train_baseline
109
from train_cvae import train_cvae
1110

1211
from numpyro.examples.datasets import MNIST
1312

13+
from models import BaselineNet, Decoder, Encoder, cvae_guide, cvae_model # isort:skip
14+
1415

1516
def main(args):
1617
train_init, train_fetch = load_dataset(

examples/cvae-flax/train_baseline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
# Copyright Contributors to the Pyro project.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
from models import cross_entropy_loss
5-
64
from flax.training.train_state import TrainState
75
import jax
86
from jax import lax, numpy as jnp, random
97
import optax
108

9+
from models import cross_entropy_loss # isort:skip
10+
1111

1212
def create_train_state(model, x, learning_rate_fn):
1313
params = model.init(random.PRNGKey(0), x)

numpyro/contrib/control_flow/scan.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
from collections import OrderedDict
55
from functools import partial
66

7-
from jax import device_put, lax, random, tree_flatten, tree_map, tree_unflatten
7+
from jax import device_put, lax, random
88
import jax.numpy as jnp
9+
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
910

1011
from numpyro import handlers
1112
from numpyro.ops.pytree import PytreeTrace

numpyro/contrib/tfp/mcmc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
from collections import namedtuple
66
import inspect
77

8-
from jax import random, tree_map, vmap
8+
from jax import random, vmap
99
from jax.flatten_util import ravel_pytree
1010
import jax.numpy as jnp
11+
from jax.tree_util import tree_map
1112
import tensorflow_probability.substrates.jax as tfp
1213

1314
from numpyro.infer import init_to_uniform

numpyro/infer/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
import numpy as np
1111

1212
import jax
13-
from jax import device_get, jacfwd, lax, random, tree_flatten, value_and_grad
13+
from jax import device_get, jacfwd, lax, random, value_and_grad
1414
from jax.flatten_util import ravel_pytree
1515
from jax.lax import broadcast_shapes
1616
import jax.numpy as jnp
17-
from jax.tree_util import tree_map
17+
from jax.tree_util import tree_flatten, tree_map
1818

1919
import numpyro
2020
from numpyro.distributions import constraints

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ skip=docs
2222
filterwarnings = error
2323
ignore:numpy.ufunc size changed,:RuntimeWarning
2424
ignore:Using a non-tuple sequence:FutureWarning
25+
ignore:jax.tree_structure is deprecated:FutureWarning
2526
ignore:numpy.linalg support is experimental:UserWarning
2627
ignore:scipy.linalg support is experimental:UserWarning
2728
once:No GPU:UserWarning

test/contrib/test_module.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
import numpyro.distributions as dist
2424
from numpyro.infer import MCMC, NUTS
2525

26+
pytestmark = pytest.mark.filterwarnings(
27+
"ignore:jax.tree_.+ is deprecated:FutureWarning"
28+
)
29+
2630

2731
def haiku_model_by_shape(x, y):
2832
import haiku as hk

test/contrib/test_nested_sampling.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
import numpyro.distributions as dist
1414
from numpyro.distributions.transforms import AffineTransform, ExpTransform
1515

16+
pytestmark = pytest.mark.filterwarnings(
17+
"ignore:jax.tree_.+ is deprecated:FutureWarning"
18+
)
19+
1620

1721
# Test helper to extract a few central moments from samples.
1822
def get_moments(x):

0 commit comments

Comments
 (0)