Skip to content

Commit b766706

Browse files
authored
Mark xfail for leaked tracer tests (#1997)
* Mark xfail for leaked tracer tests * test_handlers::test_plate is renamed to test_jit_trace * remove tracer leak xfail for the current passing tests * remove all global jax live arrays * add issues for each tracer leak test * fix failing tests * fix jax core deprecation in provenance
1 parent 3b7d7f0 commit b766706

22 files changed

+210
-92
lines changed

.github/workflows/ci.yml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,20 @@ jobs:
7878
- name: Test x64
7979
run: |
8080
JAX_ENABLE_X64=1 pytest -vs test/test_distributions.py -k powerLaw
81+
- name: Test tracer leak
82+
if: matrix.python-version == '3.10'
83+
env:
84+
JAX_CHECK_TRACER_LEAKS: 1
85+
run: |
86+
pytest -vs test/contrib/einstein/test_steinvi.py::test_run_smoke -k ASVGD
87+
pytest -vs test/contrib/test_infer_discrete.py::test_scan_hmm_smoke
88+
pytest -vs test/infer/test_mcmc.py::test_chain_inside_jit
89+
pytest -vs test/infer/test_mcmc.py::test_chain_jit_args_smoke
90+
pytest -vs test/infer/test_mcmc.py::test_reuse_mcmc_run
91+
pytest -vs test/infer/test_mcmc.py::test_model_with_multiple_exec_paths
92+
pytest -vs test/infer/test_svi.py::test_mutable_state
93+
pytest -vs test/test_distributions.py::test_mean_var -k Gompertz
94+
8195
- name: Coveralls
8296
if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.10'
8397
uses: coverallsapp/github-action@v2

numpyro/distributions/discrete.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -944,7 +944,7 @@ def entropy(self):
944944
logq = -jax.nn.softplus(self.logits)
945945
logp = -jax.nn.softplus(-self.logits)
946946
p = jax.scipy.special.expit(self.logits)
947-
p_clip = jnp.clip(p, min=jnp.finfo(p).tiny)
947+
p_clip = jnp.clip(p, jnp.finfo(p).tiny)
948948
return -(1 - p) * logq / p_clip - logp
949949

950950

numpyro/ops/provenance.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import jax
55
from jax.api_util import flatten_fun, shaped_abstractify
6-
import jax.core as core
76
from jax.experimental.pjit import pjit_p
87
import jax.util as util
98

@@ -12,11 +11,21 @@
1211
except ImportError:
1312
import jax.linear_util as lu
1413

14+
try:
15+
from jax.extend.core import Literal
16+
except ImportError:
17+
from jax.core import Literal
18+
1519
try:
1620
from jax.extend.core.primitives import call_p, closed_call_p
1721
except ImportError:
1822
from jax.core import call_p, closed_call_p
1923

24+
try:
25+
from jax.api_util import debug_info
26+
except ImportError:
27+
debug_info = None
28+
2029
from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic
2130
from jax.interpreters.pxla import xla_pmap_p
2231

@@ -44,14 +53,29 @@ def eval_provenance(fn, **kwargs):
4453
"""
4554
# Flatten the function and its arguments
4655
args, in_tree = jax.tree.flatten(((), kwargs))
47-
wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fn), in_tree)
56+
fn_info = (
57+
dict(debug_info=debug_info("eval_provenance fn", fn, (), kwargs))
58+
if debug_info is not None
59+
else {}
60+
)
61+
wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fn, **fn_info), in_tree)
4862
# Abstract eval to get output pytree
4963
avals = util.safe_map(shaped_abstractify, args)
5064
# XXX: we split out the process of abstract evaluation and provenance tracking
5165
# for simplicity. In principle, they can be merged so that we only need to walk
5266
# through the equations once.
67+
68+
wrapped_info = (
69+
dict(
70+
debug_info=debug_info(
71+
"eval_provenance wrapped", wrapped_fun.call_wrapped, args, {}
72+
)
73+
)
74+
if debug_info is not None
75+
else {}
76+
)
5377
jaxpr, avals_out, _ = trace_to_jaxpr_dynamic(
54-
lu.wrap_init(wrapped_fun.call_wrapped, {}), avals
78+
lu.wrap_init(wrapped_fun.call_wrapped, {}, **wrapped_info), avals
5579
)
5680

5781
# get provenances of flatten kwargs
@@ -69,12 +93,12 @@ def track_deps_jaxpr(jaxpr, provenance_inputs):
6993
env = {}
7094

7195
def read(v):
72-
if isinstance(v, core.Literal):
96+
if isinstance(v, Literal):
7397
return frozenset()
7498
return env.get(v, frozenset())
7599

76100
def write(v, p):
77-
if isinstance(v, core.Literal):
101+
if isinstance(v, Literal):
78102
return
79103
env[v] = read(v) | p
80104

test/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,22 @@
33

44
import os
55

6+
import jax
67
from jax import config
78

89
from numpyro.util import set_rng_seed
910

1011
config.update("jax_platform_name", "cpu") # noqa: E702
1112

1213

14+
SETUP_STATE = {"is_first_test": True}
15+
16+
1317
def pytest_runtest_setup(item):
18+
if SETUP_STATE["is_first_test"]:
19+
SETUP_STATE["is_first_test"] = False
20+
assert len(jax.live_arrays()) == 0
21+
1422
if "JAX_ENABLE_X64" in os.environ:
1523
config.update("jax_enable_x64", True)
1624
set_rng_seed(0)

test/contrib/einstein/test_steinvi.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from collections import namedtuple
55
from functools import partial
6+
import os
67
import string
78

89
import numpy as np
@@ -119,6 +120,9 @@ def model(features, labels):
119120
@pytest.mark.parametrize("kernel", KERNELS)
120121
@pytest.mark.parametrize("problem", (uniform_normal, regression))
121122
@pytest.mark.parametrize("method", ("ASVGD", "SVGD", "SteinVI"))
123+
@pytest.mark.xfail(
124+
os.getenv("JAX_CHECK_TRACER_LEAKS") == "1", reason="Expected tracer leak"
125+
)
122126
def test_run_smoke(kernel, problem, method):
123127
true_coefs, data, model = problem()
124128
if method == "ASVGD":

test/contrib/hsgp/test_laplacian.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,12 +131,12 @@ def test_eigenfunctions(x: ArrayLike, ell: float | int, m: int | list[int]):
131131
(1, 2, False),
132132
([1, 1], 2, False),
133133
(np.array([1, 1])[..., None], 2, False),
134-
(jnp.array([1, 1])[..., None], 2, False),
134+
(np.array([1, 1])[..., None], 2, False),
135+
(np.array([1, 1]), 2, True),
135136
(np.array([1, 1]), 2, True),
136-
(jnp.array([1, 1]), 2, True),
137137
([1, 1], 1, True),
138138
(np.array([1, 1]), 1, True),
139-
(jnp.array([1, 1]), 1, True),
139+
(np.array([1, 1]), 1, True),
140140
],
141141
ids=[
142142
"ell-float",

test/contrib/stochastic_support/test_dcc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919

2020
@pytest.mark.parametrize(
2121
"branch_dist",
22-
[dist.Normal(0, 1), dist.Gamma(1, 1)],
22+
[lambda: dist.Normal(0, 1), lambda: dist.Gamma(1, 1)],
2323
)
2424
@pytest.mark.xfail(raises=RuntimeError)
2525
def test_continuous_branching(branch_dist):
2626
rng_key = random.PRNGKey(0)
2727

2828
def model():
29-
model1 = numpyro.sample("model1", branch_dist, infer={"branching": True})
29+
model1 = numpyro.sample("model1", branch_dist(), infer={"branching": True})
3030
mean = 1.0 if model1 == 0 else 2.0
3131
numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)
3232

test/contrib/test_enum_elbo.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,7 +1247,6 @@ def test_elbo_enumerate_plates_6(scale):
12471247

12481248
@config_enumerate
12491249
@handlers.scale(scale=scale)
1250-
@handlers.trace
12511250
def model_iplate_iplate(data, params):
12521251
probs_a = pyro.param(
12531252
"probs_a", params["probs_a"], constraint=constraints.simplex
@@ -1305,7 +1304,6 @@ def model_iplate_plate(data, params):
13051304

13061305
@config_enumerate
13071306
@handlers.scale(scale=scale)
1308-
@handlers.trace
13091307
def model_plate_iplate(data, params):
13101308
probs_a = pyro.param(
13111309
"probs_a", params["probs_a"], constraint=constraints.simplex
@@ -1423,7 +1421,6 @@ def test_elbo_enumerate_plates_7(scale):
14231421

14241422
@config_enumerate
14251423
@handlers.scale(scale=scale)
1426-
@handlers.trace
14271424
def model_iplate_iplate(data, params):
14281425
probs_a = pyro.param(
14291426
"probs_a", params["probs_a"], constraint=constraints.simplex
@@ -1489,7 +1486,6 @@ def model_iplate_plate(data, params):
14891486

14901487
@config_enumerate
14911488
@handlers.scale(scale=scale)
1492-
@handlers.trace
14931489
def model_plate_iplate(data, params):
14941490
probs_a = pyro.param(
14951491
"probs_a", params["probs_a"], constraint=constraints.simplex

test/contrib/test_infer_discrete.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
import logging
5+
import os
56

67
import numpy as np
78
from numpy.testing import assert_allclose
@@ -95,6 +96,10 @@ def hmm(data, hidden_dim=10):
9596
],
9697
)
9798
@pytest.mark.parametrize("temperature", [0, 1])
99+
@pytest.mark.xfail(
100+
os.getenv("JAX_CHECK_TRACER_LEAKS") == "1",
101+
reason="Expected tracer leak: https://github.com/pyro-ppl/numpyro/issues/1998",
102+
)
98103
def test_scan_hmm_smoke(length, temperature):
99104
# This should match the example in the infer_discrete docstring.
100105
def hmm(data, hidden_dim=10):

test/contrib/test_tfp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def model(data):
179179
(
180180
"ReplicaExchangeMC",
181181
dict(
182-
inverse_temperatures=0.5 ** jnp.arange(4), make_kernel_fn=make_kernel_fn
182+
inverse_temperatures=0.5 ** np.arange(4), make_kernel_fn=make_kernel_fn
183183
),
184184
),
185185
],

0 commit comments

Comments
 (0)