@@ -73,35 +73,41 @@ def test_symbols(self):
7373
7474class RandomTest (jtu .JaxTestCase ):
7575
76- def test_key_make_with_custom_impl (self ):
77- shape = ( 4 , 2 , 7 )
78-
76+ def make_custom_impl (self , shape , seed = False , split = False , fold_in = False ,
77+ random_bits = False ):
78+ assert not split and not fold_in and not random_bits # not yet implemented
7979 def seed_rule (_ ):
8080 return jnp .ones (shape , dtype = jnp .dtype ('uint32' ))
8181
8282 def no_rule (* args , ** kwargs ):
8383 assert False , 'unreachable'
8484
85- impl = jex .random .define_prng_impl (
86- key_shape = shape , seed = seed_rule , split = no_rule , fold_in = no_rule ,
87- random_bits = no_rule )
85+ return jex .random .define_prng_impl (
86+ key_shape = shape , seed = seed_rule if seed else no_rule , split = no_rule ,
87+ fold_in = no_rule , random_bits = no_rule )
88+
89+ def test_key_make_with_custom_impl (self ):
90+ impl = self .make_custom_impl (shape = (4 , 2 , 7 ), seed = True )
8891 k = jax .random .key (42 , impl = impl )
8992 self .assertEqual (k .shape , ())
9093 self .assertEqual (impl , jax .random .key_impl (k ))
9194
9295 def test_key_wrap_with_custom_impl (self ):
93- def no_rule (* args , ** kwargs ):
94- assert False , 'unreachable'
95-
9696 shape = (4 , 2 , 7 )
97- impl = jex .random .define_prng_impl (
98- key_shape = shape , seed = no_rule , split = no_rule , fold_in = no_rule ,
99- random_bits = no_rule )
97+ impl = self .make_custom_impl (shape = shape )
10098 data = jnp .ones ((3 , * shape ), dtype = jnp .dtype ('uint32' ))
10199 k = jax .random .wrap_key_data (data , impl = impl )
102100 self .assertEqual (k .shape , (3 ,))
103101 self .assertEqual (impl , jax .random .key_impl (k ))
104102
103+ def test_key_impl_is_spec (self ):
104+ # this is counterpart to random_test.py:
105+ # KeyArrayTest.test_key_impl_builtin_is_string_name
106+ spec_ref = self .make_custom_impl (shape = (4 , 2 , 7 ), seed = True )
107+ key = jax .random .key (42 , impl = spec_ref )
108+ spec = jax .random .key_impl (key )
109+ self .assertEqual (repr (spec ), f"PRNGSpec({ spec_ref ._impl .name !r} )" )
110+
105111
106112class FfiTest (jtu .JaxTestCase ):
107113
0 commit comments