@@ -16,7 +16,6 @@
 from pytensor.graph.basic import Variable
 from pytensor.link.mlx import MLXLinker
 from pytensor.raise_op import assert_op
-from pytensor.tensor.basic import Alloc
 
 
 mx = pytest.importorskip("mlx.core")
@@ -188,110 +187,6 @@ def test_scalar_from_tensor_pytensor_integration():
     assert isinstance(result, mx.array)
 
 
-def test_alloc_with_different_shape_types():
-    """Test Alloc works with different types of shape parameters.
-
-    This addresses the TypeError that occurred when shape parameters
-    contained MLX arrays instead of Python integers.
-    """
-    from pytensor.link.mlx.dispatch.core import (
-        mlx_funcify_Alloc,
-    )
-
-    # Create a mock node (we don't need a real node for this test)
-    class MockNode:
-        pass
-
-    alloc_func = mlx_funcify_Alloc(Alloc(), MockNode())
-    x = mx.array(5.0)
-
-    # Test with Python ints
-    result = alloc_func(x, 3, 4)
-    assert result.shape == (3, 4)
-    assert float(result[0, 0]) == 5.0
-
-    # Test with MLX arrays (this used to fail)
-    result = alloc_func(x, mx.array(3), mx.array(4))
-    assert result.shape == (3, 4)
-    assert float(result[0, 0]) == 5.0
-
-    # Test with mixed types
-    result = alloc_func(x, 3, mx.array(4))
-    assert result.shape == (3, 4)
-    assert float(result[0, 0]) == 5.0
-
-
-def test_alloc_pytensor_integration():
-    """Test Alloc in a PyTensor graph context."""
-    # Test basic constant shape allocation
-    x = pt.scalar("x", dtype="float32")
-    result = pt.alloc(x, 3, 4)
-
-    f = pytensor.function([x], result, mode="MLX")
-    output = f(5.0)
-
-    assert output.shape == (3, 4)
-    assert float(output[0, 0]) == 5.0
-
-
-def test_alloc_compilation_limitation():
-    """Test that Alloc operations with dynamic shapes provide helpful error in compiled contexts."""
-
-    # Create variables
-    x = pt.scalar("x", dtype="float32")
-    s1 = pt.scalar("s1", dtype="int64")
-    s2 = pt.scalar("s2", dtype="int64")
-
-    # Create Alloc operation with dynamic shapes
-    result = pt.alloc(x, s1, s2)
-
-    # Create function with non-compiled MLX mode
-    f = pytensor.function([x, s1, s2], result, mode=mlx_mode_no_compile)
-
-    # Test that it works with concrete values (non-compiled context)
-    output = f(5.0, 3, 4)
-    assert output.shape == (3, 4)
-    assert np.allclose(output, 5.0)
-
-    # Test that compilation fails with helpful error
-    compiled_f = pytensor.function([x, s1, s2], result, mode=compile_mode)
-
-    with pytest.raises(ValueError) as exc_info:
-        compiled_f(5.0, 3, 4)
-
-    error_msg = str(exc_info.value)
-    assert "MLX compilation limitation" in error_msg
-    assert "Alloc operations with dynamic shapes" in error_msg
-    assert "cannot be used inside compiled functions" in error_msg
-    assert "Workarounds:" in error_msg
-    assert "Avoid using Alloc with dynamic shapes in compiled contexts" in error_msg
-    assert "Use static shapes when possible" in error_msg
-    assert "Move Alloc operations outside compiled functions" in error_msg
-
-
-def test_alloc_static_shapes_compilation():
-    """Test that Alloc operations with static shapes work fine in compiled contexts."""
-    # Create a scenario with static shapes that should work
-    x = pt.scalar("x", dtype="float32")
-
-    # Use constant shape - this should work even in compilation
-    result = pt.alloc(x, 3, 4) # Static shapes
-
-    # Test both compiled and non-compiled modes
-    f_normal = pytensor.function([x], result, mode=mlx_mode_no_compile)
-    f_compiled = pytensor.function([x], result, mode=compile_mode)
-
-    # Both should work
-    output_normal = f_normal(5.0)
-    output_compiled = f_compiled(5.0)
-
-    assert output_normal.shape == (3, 4)
-    assert output_compiled.shape == (3, 4)
-    assert np.allclose(output_normal, 5.0)
-    assert np.allclose(output_compiled, 5.0)
-    assert np.allclose(output_normal, output_compiled)
-
-
 def test_mlx_float64_auto_casting():
     """Test MLX automatic casting of float64 to float32 with warnings."""
     import warnings
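For context on what the removed `test_alloc_with_different_shape_types` covered: it exercised shape arguments arriving as MLX scalar arrays rather than Python ints. Below is a minimal sketch of that normalization pattern, assuming `int()` unwraps scalar `mx.array` values; the helper names `normalize_shape` and `alloc_like` are hypothetical, not the actual pytensor dispatch code.

```python
# Hypothetical sketch (not the pytensor dispatch): normalizing Alloc-style
# shape arguments so mx.full accepts ints and MLX scalar arrays alike.
import mlx.core as mx


def normalize_shape(*dims):
    # mx.full expects Python ints; int() also unwraps scalar mx.array
    # values, the case that previously raised a TypeError.
    return tuple(int(d) for d in dims)


def alloc_like(value, *dims):
    # Broadcast `value` to the requested shape, mirroring pt.alloc.
    return mx.full(normalize_shape(*dims), value)


out = alloc_like(mx.array(5.0), 3, mx.array(4))  # mixed int / array dims
assert out.shape == (3, 4)
assert float(out[0, 0]) == 5.0
```

The deleted test asserted this behavior for pure-int, pure-array, and mixed shape arguments.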