File tree Expand file tree Collapse file tree 2 files changed +16
-2
lines changed Expand file tree Collapse file tree 2 files changed +16
-2
lines changed Original file line number Diff line number Diff line change @@ -800,6 +800,8 @@ class seed(Messenger):
800
800
>>> assert x == y
801
801
"""
802
802
803
+ stateful = False
804
+
803
805
def __init__ (
804
806
self ,
805
807
fn : Optional [Callable ] = None ,
@@ -835,6 +837,15 @@ def process_message(self, msg: Message) -> None:
835
837
self .rng_key , rng_key_sample = random .split (self .rng_key )
836
838
msg ["kwargs" ]["rng_key" ] = rng_key_sample
837
839
840
+ def __call__ (self , * args , ** kwargs ):
841
+ if self .fn is not None and not self .stateful :
842
+ cloned_seeded_fn = seed (
843
+ self .fn , rng_seed = self .rng_key , hide_types = self .hide_types
844
+ )
845
+ cloned_seeded_fn .stateful = True
846
+ return cloned_seeded_fn .__call__ (* args , ** kwargs )
847
+ return super ().__call__ (* args , ** kwargs )
848
+
838
849
839
850
class substitute (Messenger ):
840
851
"""
Original file line number Diff line number Diff line change @@ -340,9 +340,12 @@ def model_subsample_2():
340
340
model_subsample_2 ,
341
341
],
342
342
)
343
- def test_plate (model ):
343
+ def test_trace_jit (model ):
344
344
trace = handlers .trace (handlers .seed (model , random .PRNGKey (1 ))).get_trace ()
345
- jit_trace = handlers .trace (jit (handlers .seed (model , random .PRNGKey (1 )))).get_trace ()
345
+ with jax .check_tracer_leaks (False ):
346
+ jit_trace = handlers .trace (
347
+ jit (handlers .seed (model , random .PRNGKey (1 )))
348
+ ).get_trace ()
346
349
assert "z" in trace
347
350
for name , site in trace .items ():
348
351
if site ["type" ] == "sample" :
You can’t perform that action at this time.
0 commit comments