Skip to content

Commit 627d19a

Browse files
authored
Make seed handler stateless by default (#1983)
1 parent f7746d5 commit 627d19a

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

numpyro/handlers.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -800,6 +800,8 @@ class seed(Messenger):
800800
>>> assert x == y
801801
"""
802802

803+
stateful = False
804+
803805
def __init__(
804806
self,
805807
fn: Optional[Callable] = None,
@@ -835,6 +837,15 @@ def process_message(self, msg: Message) -> None:
835837
self.rng_key, rng_key_sample = random.split(self.rng_key)
836838
msg["kwargs"]["rng_key"] = rng_key_sample
837839

840+
def __call__(self, *args, **kwargs):
841+
if self.fn is not None and not self.stateful:
842+
cloned_seeded_fn = seed(
843+
self.fn, rng_seed=self.rng_key, hide_types=self.hide_types
844+
)
845+
cloned_seeded_fn.stateful = True
846+
return cloned_seeded_fn.__call__(*args, **kwargs)
847+
return super().__call__(*args, **kwargs)
848+
838849

839850
class substitute(Messenger):
840851
"""

test/test_handlers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,9 +340,12 @@ def model_subsample_2():
340340
model_subsample_2,
341341
],
342342
)
343-
def test_plate(model):
343+
def test_trace_jit(model):
344344
trace = handlers.trace(handlers.seed(model, random.PRNGKey(1))).get_trace()
345-
jit_trace = handlers.trace(jit(handlers.seed(model, random.PRNGKey(1)))).get_trace()
345+
with jax.check_tracer_leaks(False):
346+
jit_trace = handlers.trace(
347+
jit(handlers.seed(model, random.PRNGKey(1)))
348+
).get_trace()
346349
assert "z" in trace
347350
for name, site in trace.items():
348351
if site["type"] == "sample":

0 commit comments

Comments
 (0)