@@ -133,7 +133,7 @@ def test_blackjax_particles_from_pymc_population_univariate():
133
133
model = fast_model ()
134
134
population = {"x" : np .array ([2 , 3 , 4 ])}
135
135
blackjax_particles = blackjax_particles_from_pymc_population (model , population )
136
- jax .tree_map (np .testing .assert_allclose , blackjax_particles , [np .array ([[2 ], [3 ], [4 ]])])
136
+ jax .tree . map (np .testing .assert_allclose , blackjax_particles , [np .array ([[2 ], [3 ], [4 ]])])
137
137
138
138
139
139
def test_blackjax_particles_from_pymc_population_multivariate ():
@@ -144,7 +144,7 @@ def test_blackjax_particles_from_pymc_population_multivariate():
144
144
145
145
population = {"x" : np .array ([0.34614613 , 1.09163261 , - 0.44526825 ]), "z" : np .array ([1 , 2 , 3 ])}
146
146
blackjax_particles = blackjax_particles_from_pymc_population (model , population )
147
- jax .tree_map (
147
+ jax .tree . map (
148
148
np .testing .assert_allclose ,
149
149
blackjax_particles ,
150
150
[np .array ([[0.34614613 ], [1.09163261 ], [- 0.44526825 ]]), np .array ([[1 ], [2 ], [3 ]])],
@@ -168,7 +168,7 @@ def test_blackjax_particles_from_pymc_population_multivariable():
168
168
population = {"x" : np .array ([[2 , 3 ], [5 , 6 ], [7 , 9 ]]), "z" : np .array ([11 , 12 , 13 ])}
169
169
blackjax_particles = blackjax_particles_from_pymc_population (model , population )
170
170
171
- jax .tree_map (
171
+ jax .tree . map (
172
172
np .testing .assert_allclose ,
173
173
blackjax_particles ,
174
174
[np .array ([[2 , 3 ], [5 , 6 ], [7 , 9 ]]), np .array ([[11 ], [12 ], [13 ]])],
@@ -196,7 +196,7 @@ def test_get_jaxified_logprior():
196
196
"""
197
197
logprior = get_jaxified_logprior (fast_model ())
198
198
for point in [- 0.5 , 0.0 , 0.5 ]:
199
- jax .tree_map (
199
+ jax .tree . map (
200
200
np .testing .assert_allclose ,
201
201
jax .vmap (logprior )([np .array ([point ])]),
202
202
np .log (scipy .stats .norm (0 , 1 ).pdf (point )),
@@ -212,7 +212,7 @@ def test_get_jaxified_loglikelihood():
212
212
"""
213
213
loglikelihood = get_jaxified_loglikelihood (fast_model ())
214
214
for point in [- 0.5 , 0.0 , 0.5 ]:
215
- jax .tree_map (
215
+ jax .tree . map (
216
216
np .testing .assert_allclose ,
217
217
jax .vmap (loglikelihood )([np .array ([point ])]),
218
218
np .log (scipy .stats .norm (point , 1 ).pdf (0 )),
0 commit comments