Skip to content

Commit 109632f

Browse files
authored
fix tracer leak in loss function with mutable state (#2001)
1 parent b766706 commit 109632f

File tree

3 files changed

+3
-6
lines changed

3 files changed

+3
-6
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,11 @@ jobs:
3939
run: |
4040
make lint
4141
- name: Build documentation
42+
if: matrix.python-version != '3.9'
4243
run: |
4344
make docs
4445
- name: Test documentation
46+
if: matrix.python-version != '3.9'
4547
run: |
4648
make doctest
4749
python -m doctest -v README.md
@@ -89,7 +91,6 @@ jobs:
8991
pytest -vs test/infer/test_mcmc.py::test_chain_jit_args_smoke
9092
pytest -vs test/infer/test_mcmc.py::test_reuse_mcmc_run
9193
pytest -vs test/infer/test_mcmc.py::test_model_with_multiple_exec_paths
92-
pytest -vs test/infer/test_svi.py::test_mutable_state
9394
pytest -vs test/test_distributions.py::test_mean_var -k Gompertz
9495
9596
- name: Coveralls

numpyro/infer/svi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def _make_loss_fn(
5151
def loss_fn(params):
5252
params = constrain_fn(params)
5353
if mutable_state is not None:
54-
params.update(mutable_state)
54+
params.update(jax.lax.stop_gradient(mutable_state))
5555
result = elbo.loss_with_mutable_state(
5656
rng_key, params, model, guide, *args, **kwargs, **static_kwargs
5757
)

test/infer/test_svi.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
from functools import partial
5-
import os
65

76
import numpy as np
87
from numpy.testing import assert_allclose
@@ -520,9 +519,6 @@ def guide():
520519
@pytest.mark.parametrize("stable_update", [True, False])
521520
@pytest.mark.parametrize("num_particles", [1, 10])
522521
@pytest.mark.parametrize("elbo", [Trace_ELBO, TraceMeanField_ELBO])
523-
@pytest.mark.xfail(
524-
os.getenv("JAX_CHECK_TRACER_LEAKS") == "1", reason="Expected tracer leak"
525-
)
526522
def test_mutable_state(stable_update, num_particles, elbo):
527523
def model():
528524
x = numpyro.sample("x", dist.Normal(-1, 1))

0 commit comments

Comments
 (0)