@@ -73,23 +73,18 @@ def test_alloc_compilation_limitation():
7373 # Test that it works with concrete values (non-compiled context)
7474 output = f (5.0 , 3 , 4 )
7575 assert output .shape == (3 , 4 )
76- assert np .allclose (output , 5.0 )
76+ np .testing . assert_allclose (output , 5.0 )
7777
7878 # Test that compilation fails with helpful error
7979 compiled_f = pytensor .function ([x , s1 , s2 ], result , mode = compile_mode )
8080
81- with pytest .raises (ValueError ) as exc_info :
81+ with pytest .raises (
82+ ValueError ,
83+ match = "MLX compilation limitation: Alloc operations with dynamic shapes cannot be "
84+ "used inside compiled functions" ,
85+ ):
8286 compiled_f (5.0 , 3 , 4 )
8387
84- error_msg = str (exc_info .value )
85- assert "MLX compilation limitation" in error_msg
86- assert "Alloc operations with dynamic shapes" in error_msg
87- assert "cannot be used inside compiled functions" in error_msg
88- assert "Workarounds:" in error_msg
89- assert "Avoid using Alloc with dynamic shapes in compiled contexts" in error_msg
90- assert "Use static shapes when possible" in error_msg
91- assert "Move Alloc operations outside compiled functions" in error_msg
92-
9388
9489def test_alloc_static_shapes_compilation ():
9590 """Test that Alloc operations with static shapes work fine in compiled contexts."""
@@ -109,6 +104,36 @@ def test_alloc_static_shapes_compilation():
109104
110105 assert output_normal .shape == (3 , 4 )
111106 assert output_compiled .shape == (3 , 4 )
112- assert np .allclose (output_normal , 5.0 )
113- assert np .allclose (output_compiled , 5.0 )
114- assert np .allclose (output_normal , output_compiled )
107+ np .testing .assert_allclose (output_normal , 5.0 )
108+ np .testing .assert_allclose (output_compiled , 5.0 )
109+ np .testing .assert_allclose (output_normal , output_compiled )
110+
111+
112+ def test_empty_static_shape ():
113+ result = pt .empty ((3 , 4 ), dtype = "float32" )
114+
115+ f = pytensor .function ([], result , mode = "MLX" )
116+ output = f ()
117+
118+ assert output .shape == (3 , 4 )
119+ np .testing .assert_allclose (output , 0.0 )
120+
121+
122+ def test_empty_dynamic_shape ():
123+ s1 = pt .scalar ("s1" , dtype = "int64" )
124+ s2 = pt .scalar ("s2" , dtype = "int64" )
125+ result = pt .empty ((s1 , s2 ), dtype = "float32" )
126+
127+ f = pytensor .function ([s1 , s2 ], result , mode = mlx_mode_no_compile )
128+ output = f (3 , 4 )
129+
130+ assert output .shape == (3 , 4 )
131+ np .testing .assert_allclose (output , 0.0 )
132+
133+ f_compiled = pytensor .function ([s1 , s2 ], result , mode = compile_mode )
134+ with pytest .raises (
135+ ValueError ,
136+ match = "MLX compilation limitation: Alloc operations with dynamic shapes cannot be "
137+ "used inside compiled functions" ,
138+ ):
139+ f_compiled (3 , 4 )
0 commit comments