Skip to content

Commit 774ddff

Browse files
jessegrabowskiricardoV94
authored andcommitted
jax.tree_map -> jax.tree.map
Also ignore warning from upstream dependency
1 parent dc05dcc commit 774ddff

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

pymc_experimental/tests/test_blackjax_smc.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def test_blackjax_particles_from_pymc_population_univariate():
133133
model = fast_model()
134134
population = {"x": np.array([2, 3, 4])}
135135
blackjax_particles = blackjax_particles_from_pymc_population(model, population)
136-
jax.tree_map(np.testing.assert_allclose, blackjax_particles, [np.array([[2], [3], [4]])])
136+
jax.tree.map(np.testing.assert_allclose, blackjax_particles, [np.array([[2], [3], [4]])])
137137

138138

139139
def test_blackjax_particles_from_pymc_population_multivariate():
@@ -144,7 +144,7 @@ def test_blackjax_particles_from_pymc_population_multivariate():
144144

145145
population = {"x": np.array([0.34614613, 1.09163261, -0.44526825]), "z": np.array([1, 2, 3])}
146146
blackjax_particles = blackjax_particles_from_pymc_population(model, population)
147-
jax.tree_map(
147+
jax.tree.map(
148148
np.testing.assert_allclose,
149149
blackjax_particles,
150150
[np.array([[0.34614613], [1.09163261], [-0.44526825]]), np.array([[1], [2], [3]])],
@@ -168,7 +168,7 @@ def test_blackjax_particles_from_pymc_population_multivariable():
168168
population = {"x": np.array([[2, 3], [5, 6], [7, 9]]), "z": np.array([11, 12, 13])}
169169
blackjax_particles = blackjax_particles_from_pymc_population(model, population)
170170

171-
jax.tree_map(
171+
jax.tree.map(
172172
np.testing.assert_allclose,
173173
blackjax_particles,
174174
[np.array([[2, 3], [5, 6], [7, 9]]), np.array([[11], [12], [13]])],
@@ -196,7 +196,7 @@ def test_get_jaxified_logprior():
196196
"""
197197
logprior = get_jaxified_logprior(fast_model())
198198
for point in [-0.5, 0.0, 0.5]:
199-
jax.tree_map(
199+
jax.tree.map(
200200
np.testing.assert_allclose,
201201
jax.vmap(logprior)([np.array([point])]),
202202
np.log(scipy.stats.norm(0, 1).pdf(point)),
@@ -212,7 +212,7 @@ def test_get_jaxified_loglikelihood():
212212
"""
213213
loglikelihood = get_jaxified_loglikelihood(fast_model())
214214
for point in [-0.5, 0.0, 0.5]:
215-
jax.tree_map(
215+
jax.tree.map(
216216
np.testing.assert_allclose,
217217
jax.vmap(loglikelihood)([np.array([point])]),
218218
np.log(scipy.stats.norm(point, 1).pdf(0)),

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ filterwarnings =[
1717

1818
# JAX issues an over-eager warning if os.fork() is called when the JAX module is loaded, even if JAX isn't being used
1919
'ignore:os\.fork\(\) was called\.:RuntimeWarning',
20+
21+
# Warning coming from blackjax
22+
'ignore:jax\.tree_map is deprecated:DeprecationWarning',
2023
]
2124

2225
[tool.black]

0 commit comments

Comments
 (0)