@@ -73,23 +73,18 @@ def test_alloc_compilation_limitation():
73
73
# Test that it works with concrete values (non-compiled context)
74
74
output = f (5.0 , 3 , 4 )
75
75
assert output .shape == (3 , 4 )
76
- assert np .allclose (output , 5.0 )
76
+ np .testing . assert_allclose (output , 5.0 )
77
77
78
78
# Test that compilation fails with helpful error
79
79
compiled_f = pytensor .function ([x , s1 , s2 ], result , mode = compile_mode )
80
80
81
- with pytest .raises (ValueError ) as exc_info :
81
+ with pytest .raises (
82
+ ValueError ,
83
+ match = "MLX compilation limitation: Alloc operations with dynamic shapes cannot be "
84
+ "used inside compiled functions" ,
85
+ ):
82
86
compiled_f (5.0 , 3 , 4 )
83
87
84
- error_msg = str (exc_info .value )
85
- assert "MLX compilation limitation" in error_msg
86
- assert "Alloc operations with dynamic shapes" in error_msg
87
- assert "cannot be used inside compiled functions" in error_msg
88
- assert "Workarounds:" in error_msg
89
- assert "Avoid using Alloc with dynamic shapes in compiled contexts" in error_msg
90
- assert "Use static shapes when possible" in error_msg
91
- assert "Move Alloc operations outside compiled functions" in error_msg
92
-
93
88
94
89
def test_alloc_static_shapes_compilation ():
95
90
"""Test that Alloc operations with static shapes work fine in compiled contexts."""
@@ -109,6 +104,36 @@ def test_alloc_static_shapes_compilation():
109
104
110
105
assert output_normal .shape == (3 , 4 )
111
106
assert output_compiled .shape == (3 , 4 )
112
- assert np .allclose (output_normal , 5.0 )
113
- assert np .allclose (output_compiled , 5.0 )
114
- assert np .allclose (output_normal , output_compiled )
107
+ np .testing .assert_allclose (output_normal , 5.0 )
108
+ np .testing .assert_allclose (output_compiled , 5.0 )
109
+ np .testing .assert_allclose (output_normal , output_compiled )
110
+
111
+
112
+ def test_empty_static_shape ():
113
+ result = pt .empty ((3 , 4 ), dtype = "float32" )
114
+
115
+ f = pytensor .function ([], result , mode = "MLX" )
116
+ output = f ()
117
+
118
+ assert output .shape == (3 , 4 )
119
+ np .testing .assert_allclose (output , 0.0 )
120
+
121
+
122
+ def test_empty_dynamic_shape ():
123
+ s1 = pt .scalar ("s1" , dtype = "int64" )
124
+ s2 = pt .scalar ("s2" , dtype = "int64" )
125
+ result = pt .empty ((s1 , s2 ), dtype = "float32" )
126
+
127
+ f = pytensor .function ([s1 , s2 ], result , mode = mlx_mode_no_compile )
128
+ output = f (3 , 4 )
129
+
130
+ assert output .shape == (3 , 4 )
131
+ np .testing .assert_allclose (output , 0.0 )
132
+
133
+ f_compiled = pytensor .function ([s1 , s2 ], result , mode = compile_mode )
134
+ with pytest .raises (
135
+ ValueError ,
136
+ match = "MLX compilation limitation: Alloc operations with dynamic shapes cannot be "
137
+ "used inside compiled functions" ,
138
+ ):
139
+ f_compiled (3 , 4 )
0 commit comments