Skip to content

Commit 4ed38b0

Browse files
ricardoV94twiecki
authored andcommitted
Raise error when trying to jaxify graphs with RNGs
1 parent 9172925 commit 4ed38b0

File tree

2 files changed

+24
-14
lines changed

2 files changed

+24
-14
lines changed

pymc/sampling/jax.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from functools import partial
1919
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
2020

21+
from pytensor.tensor.random.type import RandomType
22+
2123
from pymc.initial_point import StartDict
2224
from pymc.sampling.mcmc import _init_jitter
2325

@@ -81,6 +83,11 @@ def _replace_shared_variables(graph: List[TensorVariable]) -> List[TensorVariabl
8183

8284
shared_variables = [var for var in graph_inputs(graph) if isinstance(var, SharedVariable)]
8385

86+
if any(isinstance(var.type, RandomType) for var in shared_variables):
87+
raise ValueError(
88+
"Graph contains shared RandomType variables which cannot be safely replaced"
89+
)
90+
8491
if any(var.default_update is not None for var in shared_variables):
8592
raise ValueError(
8693
"Graph contains shared variables with default_update which cannot "

pymc/tests/sampling/test_jax.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,18 @@
2929

3030
import pymc as pm
3131

32+
from pymc.sampling.jax import (
33+
_get_batched_jittered_initial_points,
34+
_get_log_likelihood,
35+
_numpyro_nuts_defaults,
36+
_replace_shared_variables,
37+
_update_numpyro_nuts_kwargs,
38+
get_jaxified_graph,
39+
get_jaxified_logp,
40+
sample_blackjax_nuts,
41+
sample_numpyro_nuts,
42+
)
43+
3244

3345
def test_old_import_route():
3446
import pymc.sampling.jax as new_sj
@@ -37,20 +49,6 @@ def test_old_import_route():
3749
assert set(new_sj.__all__) <= set(dir(old_sj))
3850

3951

40-
with pytest.warns(UserWarning, match="module is experimental"):
41-
from pymc.sampling.jax import (
42-
_get_batched_jittered_initial_points,
43-
_get_log_likelihood,
44-
_numpyro_nuts_defaults,
45-
_replace_shared_variables,
46-
_update_numpyro_nuts_kwargs,
47-
get_jaxified_graph,
48-
get_jaxified_logp,
49-
sample_blackjax_nuts,
50-
sample_numpyro_nuts,
51-
)
52-
53-
5452
@pytest.mark.parametrize(
5553
"sampler",
5654
[
@@ -182,6 +180,11 @@ def test_replace_shared_variables():
182180
with pytest.raises(ValueError, match="shared variables with default_update"):
183181
_replace_shared_variables([x])
184182

183+
shared_rng = pytensor.shared(np.random.default_rng(), name="shared_rng")
184+
x = pytensor.tensor.random.normal(rng=shared_rng)
185+
with pytest.raises(ValueError, match="Graph contains shared RandomType variables"):
186+
_replace_shared_variables([x])
187+
185188

186189
def test_get_jaxified_logp():
187190
with pm.Model() as m:

0 commit comments

Comments
 (0)