@@ -449,14 +449,33 @@ def test_random_concrete_shape():
449
449
"""
450
450
rng = shared (np .random .RandomState (123 ))
451
451
x_at = at .dmatrix ()
452
- f = at .random .normal (0 , 1 , size = (3 ,), rng = rng )
453
- g = at .random .normal (f , 1 , size = x_at .shape , rng = rng )
454
- g_fn = function ([x_at ], g , mode = jax_mode )
455
- _ = g_fn (np .ones ((2 , 3 )))
452
+ out = at .random .normal (0 , 1 , size = x_at .shape , rng = rng )
453
+ jax_fn = function ([x_at ], out , mode = jax_mode )
454
+ assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
456
455
457
- # This should compile, and `size_at` be passed to the list of `static_argnums`.
458
- with pytest .raises (NotImplementedError ):
459
- size_at = at .scalar ()
460
- g = at .random .normal (f , 1 , size = size_at , rng = rng )
461
- g_fn = function ([size_at ], g , mode = jax_mode )
462
- _ = g_fn (10 )
456
+
457
+ @pytest .mark .xfail (reason = "size argument specified as a tuple is a `DimShuffle` node" )
458
+ def test_random_concrete_shape_subtensor ():
459
+ rng = shared (np .random .RandomState (123 ))
460
+ x_at = at .dmatrix ()
461
+ out = at .random .normal (0 , 1 , size = x_at .shape [1 ], rng = rng )
462
+ jax_fn = function ([x_at ], out , mode = jax_mode )
463
+ assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
464
+
465
+
466
+ @pytest .mark .xfail (reason = "size argument specified as a tuple is a `MakeVector` node" )
467
+ def test_random_concrete_shape_subtensor_tuple ():
468
+ rng = shared (np .random .RandomState (123 ))
469
+ x_at = at .dmatrix ()
470
+ out = at .random .normal (0 , 1 , size = (x_at .shape [0 ],), rng = rng )
471
+ jax_fn = function ([x_at ], out , mode = jax_mode )
472
+ assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
473
+
474
+
475
+ @pytest .mark .xfail (reason = "`size_at` should be specified as a static argument" )
476
+ def test_random_concrete_shape_graph_input ():
477
+ rng = shared (np .random .RandomState (123 ))
478
+ size_at = at .scalar ()
479
+ out = at .random .normal (0 , 1 , size = size_at , rng = rng )
480
+ jax_fn = function ([size_at ], out , mode = jax_mode )
481
+ assert jax_fn (10 ).shape == (10 ,)
0 commit comments