29
29
30
30
import pymc as pm
31
31
32
+ from pymc .sampling .jax import (
33
+ _get_batched_jittered_initial_points ,
34
+ _get_log_likelihood ,
35
+ _numpyro_nuts_defaults ,
36
+ _replace_shared_variables ,
37
+ _update_numpyro_nuts_kwargs ,
38
+ get_jaxified_graph ,
39
+ get_jaxified_logp ,
40
+ sample_blackjax_nuts ,
41
+ sample_numpyro_nuts ,
42
+ )
43
+
32
44
33
45
def test_old_import_route ():
34
46
import pymc .sampling .jax as new_sj
@@ -37,20 +49,6 @@ def test_old_import_route():
37
49
assert set (new_sj .__all__ ) <= set (dir (old_sj ))
38
50
39
51
40
- with pytest .warns (UserWarning , match = "module is experimental" ):
41
- from pymc .sampling .jax import (
42
- _get_batched_jittered_initial_points ,
43
- _get_log_likelihood ,
44
- _numpyro_nuts_defaults ,
45
- _replace_shared_variables ,
46
- _update_numpyro_nuts_kwargs ,
47
- get_jaxified_graph ,
48
- get_jaxified_logp ,
49
- sample_blackjax_nuts ,
50
- sample_numpyro_nuts ,
51
- )
52
-
53
-
54
52
@pytest .mark .parametrize (
55
53
"sampler" ,
56
54
[
@@ -182,6 +180,11 @@ def test_replace_shared_variables():
182
180
with pytest .raises (ValueError , match = "shared variables with default_update" ):
183
181
_replace_shared_variables ([x ])
184
182
183
+ shared_rng = pytensor .shared (np .random .default_rng (), name = "shared_rng" )
184
+ x = pytensor .tensor .random .normal (rng = shared_rng )
185
+ with pytest .raises (ValueError , match = "Graph contains shared RandomType variables" ):
186
+ _replace_shared_variables ([x ])
187
+
185
188
186
189
def test_get_jaxified_logp ():
187
190
with pm .Model () as m :
0 commit comments