@@ -809,94 +809,90 @@ def sample_fn(rng, size, dtype, *parameters):
809809 compare_jax_and_py (fgraph , [])
810810
811811
812- def test_random_concrete_shape ():
813- """JAX should compile when a `RandomVariable` is passed a concrete shape.
814-
815- There are three quantities that JAX considers as concrete:
816- 1. Constants known at compile time;
817- 2. The shape of an array.
818- 3. `static_argnums` parameters
819- This test makes sure that graphs with `RandomVariable`s compile when the
820- `size` parameter satisfies either of these criteria.
821-
822- """
823- rng = shared (np .random .default_rng (123 ))
824- x_pt = pt .dmatrix ()
825- out = pt .random .normal (0 , 1 , size = x_pt .shape , rng = rng )
826- jax_fn = compile_random_function ([x_pt ], out )
827- assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
828-
829-
830- def test_random_concrete_shape_from_param ():
831- rng = shared (np .random .default_rng (123 ))
832- x_pt = pt .dmatrix ()
833- out = pt .random .normal (x_pt , 1 , rng = rng )
834- jax_fn = compile_random_function ([x_pt ], out )
835- assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
836-
837-
838- def test_random_concrete_shape_subtensor ():
839- """JAX should compile when a concrete value is passed for the `size` parameter.
840-
841- This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
842- inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
843- inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
844- rewrite.
845-
846- JAX does not accept scalars as `size` or `shape` arguments, so this is a
847- slight improvement over their API.
848-
849- """
850- rng = shared (np .random .default_rng (123 ))
851- x_pt = pt .dmatrix ()
852- out = pt .random .normal (0 , 1 , size = x_pt .shape [1 ], rng = rng )
853- jax_fn = compile_random_function ([x_pt ], out )
854- assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
855-
856-
857- def test_random_concrete_shape_subtensor_tuple ():
858- """JAX should compile when a tuple of concrete values is passed for the `size` parameter.
859-
860- This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
861- inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
862- scalar inputs into tuples of concrete values using the
863- `jax_size_parameter_as_tuple` rewrite.
864-
865- """
866- rng = shared (np .random .default_rng (123 ))
867- x_pt = pt .dmatrix ()
868- out = pt .random .normal (0 , 1 , size = (x_pt .shape [0 ],), rng = rng )
869- jax_fn = compile_random_function ([x_pt ], out )
870- assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
871-
872-
873- @pytest .mark .xfail (
874- reason = "`size_pt` should be specified as a static argument" , strict = True
875- )
876- def test_random_concrete_shape_graph_input ():
877- rng = shared (np .random .default_rng (123 ))
878- size_pt = pt .scalar ()
879- out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
880- jax_fn = compile_random_function ([size_pt ], out )
881- assert jax_fn (10 ).shape == (10 ,)
882-
883-
884- def test_constant_shape_after_graph_rewriting ():
885- size = pt .vector ("size" , shape = (2 ,), dtype = int )
886- x = pt .random .normal (size = size )
887- assert x .type .shape == (None , None )
888-
889- with pytest .raises (TypeError ):
890- compile_random_function ([size ], x )([2 , 5 ])
891-
892- # Rebuild with strict=False so output type is not updated
893- # This reflects cases where size is constant folded during rewrites but the RV node is not recreated
894- new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = True )
895- assert new_x .type .shape == (None , None )
896- assert compile_random_function ([], new_x )().shape == (2 , 5 )
897-
898- # Rebuild with strict=True, so output type is updated
899- # This uses a different path in the dispatch implementation
900- new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = False )
901- assert new_x .type .shape == (2 , 5 )
902- assert compile_random_function ([], new_x )().shape == (2 , 5 )
812+ class TestRandomShapeInputs :
813+ def test_random_concrete_shape (self ):
814+ """JAX should compile when a `RandomVariable` is passed a concrete shape.
815+
816+ There are three quantities that JAX considers as concrete:
817+ 1. Constants known at compile time;
818+ 2. The shape of an array.
819+ 3. `static_argnums` parameters
820+ This test makes sure that graphs with `RandomVariable`s compile when the
821+ `size` parameter satisfies either of these criteria.
822+
823+ """
824+ rng = shared (np .random .default_rng (123 ))
825+ x_pt = pt .dmatrix ()
826+ out = pt .random .normal (0 , 1 , size = x_pt .shape , rng = rng )
827+ jax_fn = compile_random_function ([x_pt ], out )
828+ assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
829+
830+ def test_random_concrete_shape_from_param (self ):
831+ rng = shared (np .random .default_rng (123 ))
832+ x_pt = pt .dmatrix ()
833+ out = pt .random .normal (x_pt , 1 , rng = rng )
834+ jax_fn = compile_random_function ([x_pt ], out )
835+ assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
836+
837+ def test_random_concrete_shape_subtensor (self ):
838+ """JAX should compile when a concrete value is passed for the `size` parameter.
839+
840+ This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
841+ inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
842+ inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
843+ rewrite.
844+
845+ JAX does not accept scalars as `size` or `shape` arguments, so this is a
846+ slight improvement over their API.
847+
848+ """
849+ rng = shared (np .random .default_rng (123 ))
850+ x_pt = pt .dmatrix ()
851+ out = pt .random .normal (0 , 1 , size = x_pt .shape [1 ], rng = rng )
852+ jax_fn = compile_random_function ([x_pt ], out )
853+ assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
854+
855+ def test_random_concrete_shape_subtensor_tuple (self ):
856+ """JAX should compile when a tuple of concrete values is passed for the `size` parameter.
857+
858+ This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
859+ inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
860+ scalar inputs into tuples of concrete values using the
861+ `jax_size_parameter_as_tuple` rewrite.
862+
863+ """
864+ rng = shared (np .random .default_rng (123 ))
865+ x_pt = pt .dmatrix ()
866+ out = pt .random .normal (0 , 1 , size = (x_pt .shape [0 ],), rng = rng )
867+ jax_fn = compile_random_function ([x_pt ], out )
868+ assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
869+
870+ @pytest .mark .xfail (
871+ reason = "`size_pt` should be specified as a static argument" , strict = True
872+ )
873+ def test_random_concrete_shape_graph_input (self ):
874+ rng = shared (np .random .default_rng (123 ))
875+ size_pt = pt .scalar ()
876+ out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
877+ jax_fn = compile_random_function ([size_pt ], out )
878+ assert jax_fn (10 ).shape == (10 ,)
879+
880+ def test_constant_shape_after_graph_rewriting (self ):
881+ size = pt .vector ("size" , shape = (2 ,), dtype = int )
882+ x = pt .random .normal (size = size )
883+ assert x .type .shape == (None , None )
884+
885+ with pytest .raises (TypeError ):
886+ compile_random_function ([size ], x )([2 , 5 ])
887+
888+ # Rebuild with strict=False so output type is not updated
889+ # This reflects cases where size is constant folded during rewrites but the RV node is not recreated
890+ new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = True )
891+ assert new_x .type .shape == (None , None )
892+ assert compile_random_function ([], new_x )().shape == (2 , 5 )
893+
894+ # Rebuild with strict=True, so output type is updated
895+ # This uses a different path in the dispatch implementation
896+ new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = False )
897+ assert new_x .type .shape == (2 , 5 )
898+ assert compile_random_function ([], new_x )().shape == (2 , 5 )
0 commit comments