Skip to content

Commit 7c597a8

Browse files
committed
copy_function_with_new_rngs warns with JAXLinker
1 parent 88308f2 commit 7c597a8

File tree

2 files changed

+32
-5
lines changed

2 files changed

+32
-5
lines changed

pymc/pytensorf.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from pytensor.graph.fg import FunctionGraph, Output
4040
from pytensor.graph.op import Op
41+
from pytensor.link.jax.linker import JAXLinker
4142
from pytensor.scalar.basic import Cast
4243
from pytensor.scan.op import Scan
4344
from pytensor.tensor.basic import _as_tensor_variable
@@ -1208,6 +1209,15 @@ def copy_function_with_new_rngs(
12081209
fn_ = fn.f if isinstance(fn, PointFunc) else fn
12091210
shared_rngs = [var for var in fn_.get_shared() if isinstance(var.type, RandomGeneratorType)]
12101211
n_shared_rngs = len(shared_rngs)
1212+
if n_shared_rngs > 0 and isinstance(fn_.maker.linker, JAXLinker):
1213+
# Reseeding RVs in JAX backend requires a different logic, becuase the SharedVariables
1214+
# used internally are not the ones that `function.get_shared()` returns.
1215+
warnings.warn(
1216+
"At the moment, it is not possible to set the random generator's key for "
1217+
"JAX linked functions. This means that the draws yielded by the random "
1218+
"variables that are requested by 'Deterministic' will not be reproducible."
1219+
)
1220+
return fn
12111221
swap = {
12121222
old_shared_rng: shared(rng, borrow=True)
12131223
for old_shared_rng, rng in zip(shared_rngs, rng_gen.spawn(n_shared_rngs), strict=True)

tests/sampling/test_mcmc.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -929,12 +929,29 @@ def trace_backend(request):
929929
return trace
930930

931931

932-
def test_random_deterministics(trace_backend):
932+
@pytest.fixture(scope="function", params=["FAST_COMPILE", "NUMBA", "JAX"])
933+
def pytensor_mode(request):
934+
return request.param
935+
936+
937+
def test_random_deterministics(trace_backend, pytensor_mode):
933938
with pm.Model() as m:
934939
x = pm.Bernoulli("x", p=0.5) * 0 # Force it to be zero
935940
pm.Deterministic("y", x + pm.Normal.dist())
936941

937-
idata1 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
938-
idata2 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
939-
940-
assert idata1.posterior.equals(idata2.posterior)
942+
if pytensor_mode == "JAX":
943+
expected_warning = (
944+
"At the moment, it is not possible to set the random generator's key for "
945+
"JAX linked functions. This means that the draws yielded by the random "
946+
"variables that are requested by 'Deterministic' will not be reproducible."
947+
)
948+
with pytest.warns(UserWarning, match=expected_warning):
949+
with pytensor.config.change_flags(mode=pytensor_mode):
950+
idata1 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
951+
idata2 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
952+
assert not idata1.posterior.equals(idata2.posterior)
953+
else:
954+
with pytensor.config.change_flags(mode=pytensor_mode):
955+
idata1 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
956+
idata2 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
957+
assert idata1.posterior.equals(idata2.posterior)

0 commit comments

Comments
 (0)