@@ -867,15 +867,55 @@ def test_random_concrete_shape_subtensor_tuple(self):
867867 jax_fn = compile_random_function ([x_pt ], out )
868868 assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
869869
870+ def test_random_scalar_shape_input (self ):
871+ dim0 = pt .scalar ("dim0" , dtype = int )
872+ dim1 = pt .scalar ("dim1" , dtype = int )
873+
874+ out = pt .random .normal (0 , 1 , size = dim0 )
875+ jax_fn = compile_random_function ([dim0 ], out )
876+ assert jax_fn (np .array (2 )).shape == (2 ,)
877+ assert jax_fn (np .array (3 )).shape == (3 ,)
878+
879+ out = pt .random .normal (0 , 1 , size = [dim0 , dim1 ])
880+ jax_fn = compile_random_function ([dim0 , dim1 ], out )
881+ assert jax_fn (np .array (2 ), np .array (3 )).shape == (2 , 3 )
882+ assert jax_fn (np .array (4 ), np .array (5 )).shape == (4 , 5 )
883+
870884 @pytest .mark .xfail (
871- reason = "`size_pt` should be specified as a static argument" , strict = True
885+ raises = TypeError , reason = "Cannot convert scalar input to integer"
872886 )
873- def test_random_concrete_shape_graph_input (self ):
874- rng = shared (np .random .default_rng (123 ))
875- size_pt = pt .scalar ()
876- out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
877- jax_fn = compile_random_function ([size_pt ], out )
878- assert jax_fn (10 ).shape == (10 ,)
887+ def test_random_scalar_shape_input_not_supported (self ):
888+ dim = pt .scalar ("dim" , dtype = int )
889+ out1 = pt .random .normal (0 , 1 , size = dim )
890+ # An operation that wouldn't work if we replaced 0d array by integer
891+ out2 = dim [...].set (1 )
892+ jax_fn = compile_random_function ([dim ], [out1 , out2 ])
893+
894+ res1 , res2 = jax_fn (np .array (2 ))
895+ assert res1 .shape == (2 ,)
896+ assert res2 == 1
897+
898+ @pytest .mark .xfail (
899+ raises = TypeError , reason = "Cannot convert scalar input to integer"
900+ )
901+ def test_random_scalar_shape_input_not_supported2 (self ):
902+ dim = pt .scalar ("dim" , dtype = int )
903+ # This could theoretically be supported
904+ # but would require knowing that * 2 is a safe operation for a python integer
905+ out = pt .random .normal (0 , 1 , size = dim * 2 )
906+ jax_fn = compile_random_function ([dim ], out )
907+ assert jax_fn (np .array (2 )).shape == (4 ,)
908+
909+ @pytest .mark .xfail (
910+ raises = TypeError , reason = "Cannot convert tensor input to shape tuple"
911+ )
912+ def test_random_vector_shape_graph_input (self ):
913+ shape = pt .vector ("shape" , shape = (2 ,), dtype = int )
914+ out = pt .random .normal (0 , 1 , size = shape )
915+
916+ jax_fn = compile_random_function ([shape ], out )
917+ assert jax_fn (np .array ([2 , 3 ])).shape == (2 , 3 )
918+ assert jax_fn (np .array ([4 , 5 ])).shape == (4 , 5 )
879919
880920 def test_constant_shape_after_graph_rewriting (self ):
881921 size = pt .vector ("size" , shape = (2 ,), dtype = int )
0 commit comments