Skip to content

Commit 433a2cb

Browse files
Move alloc tests to test_core.py
1 parent bad7c90 commit 433a2cb

File tree

2 files changed

+114
-105
lines changed

2 files changed

+114
-105
lines changed

tests/link/mlx/test_basic.py

Lines changed: 0 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from pytensor.graph.basic import Variable
1717
from pytensor.link.mlx import MLXLinker
1818
from pytensor.raise_op import assert_op
19-
from pytensor.tensor.basic import Alloc
2019

2120

2221
mx = pytest.importorskip("mlx.core")
@@ -188,110 +187,6 @@ def test_scalar_from_tensor_pytensor_integration():
188187
assert isinstance(result, mx.array)
189188

190189

191-
def test_alloc_with_different_shape_types():
192-
"""Test Alloc works with different types of shape parameters.
193-
194-
This addresses the TypeError that occurred when shape parameters
195-
contained MLX arrays instead of Python integers.
196-
"""
197-
from pytensor.link.mlx.dispatch.core import (
198-
mlx_funcify_Alloc,
199-
)
200-
201-
# Create a mock node (we don't need a real node for this test)
202-
class MockNode:
203-
pass
204-
205-
alloc_func = mlx_funcify_Alloc(Alloc(), MockNode())
206-
x = mx.array(5.0)
207-
208-
# Test with Python ints
209-
result = alloc_func(x, 3, 4)
210-
assert result.shape == (3, 4)
211-
assert float(result[0, 0]) == 5.0
212-
213-
# Test with MLX arrays (this used to fail)
214-
result = alloc_func(x, mx.array(3), mx.array(4))
215-
assert result.shape == (3, 4)
216-
assert float(result[0, 0]) == 5.0
217-
218-
# Test with mixed types
219-
result = alloc_func(x, 3, mx.array(4))
220-
assert result.shape == (3, 4)
221-
assert float(result[0, 0]) == 5.0
222-
223-
224-
def test_alloc_pytensor_integration():
225-
"""Test Alloc in a PyTensor graph context."""
226-
# Test basic constant shape allocation
227-
x = pt.scalar("x", dtype="float32")
228-
result = pt.alloc(x, 3, 4)
229-
230-
f = pytensor.function([x], result, mode="MLX")
231-
output = f(5.0)
232-
233-
assert output.shape == (3, 4)
234-
assert float(output[0, 0]) == 5.0
235-
236-
237-
def test_alloc_compilation_limitation():
238-
"""Test that Alloc operations with dynamic shapes provide helpful error in compiled contexts."""
239-
240-
# Create variables
241-
x = pt.scalar("x", dtype="float32")
242-
s1 = pt.scalar("s1", dtype="int64")
243-
s2 = pt.scalar("s2", dtype="int64")
244-
245-
# Create Alloc operation with dynamic shapes
246-
result = pt.alloc(x, s1, s2)
247-
248-
# Create function with non-compiled MLX mode
249-
f = pytensor.function([x, s1, s2], result, mode=mlx_mode_no_compile)
250-
251-
# Test that it works with concrete values (non-compiled context)
252-
output = f(5.0, 3, 4)
253-
assert output.shape == (3, 4)
254-
assert np.allclose(output, 5.0)
255-
256-
# Test that compilation fails with helpful error
257-
compiled_f = pytensor.function([x, s1, s2], result, mode=compile_mode)
258-
259-
with pytest.raises(ValueError) as exc_info:
260-
compiled_f(5.0, 3, 4)
261-
262-
error_msg = str(exc_info.value)
263-
assert "MLX compilation limitation" in error_msg
264-
assert "Alloc operations with dynamic shapes" in error_msg
265-
assert "cannot be used inside compiled functions" in error_msg
266-
assert "Workarounds:" in error_msg
267-
assert "Avoid using Alloc with dynamic shapes in compiled contexts" in error_msg
268-
assert "Use static shapes when possible" in error_msg
269-
assert "Move Alloc operations outside compiled functions" in error_msg
270-
271-
272-
def test_alloc_static_shapes_compilation():
273-
"""Test that Alloc operations with static shapes work fine in compiled contexts."""
274-
# Create a scenario with static shapes that should work
275-
x = pt.scalar("x", dtype="float32")
276-
277-
# Use constant shape - this should work even in compilation
278-
result = pt.alloc(x, 3, 4) # Static shapes
279-
280-
# Test both compiled and non-compiled modes
281-
f_normal = pytensor.function([x], result, mode=mlx_mode_no_compile)
282-
f_compiled = pytensor.function([x], result, mode=compile_mode)
283-
284-
# Both should work
285-
output_normal = f_normal(5.0)
286-
output_compiled = f_compiled(5.0)
287-
288-
assert output_normal.shape == (3, 4)
289-
assert output_compiled.shape == (3, 4)
290-
assert np.allclose(output_normal, 5.0)
291-
assert np.allclose(output_compiled, 5.0)
292-
assert np.allclose(output_normal, output_compiled)
293-
294-
295190
def test_mlx_float64_auto_casting():
296191
"""Test MLX automatic casting of float64 to float32 with warnings."""
297192
import warnings

tests/link/mlx/test_core.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import numpy as np
2+
import pytest
3+
4+
import pytensor
5+
from pytensor import tensor as pt
6+
from pytensor.tensor.basic import Alloc
7+
from tests.link.mlx.test_basic import compile_mode, mlx_mode_no_compile, mx
8+
9+
10+
def test_alloc_with_different_shape_types():
    """Check that the MLX Alloc dispatch accepts heterogeneous shape arguments.

    Shape entries may arrive as plain Python ints, as MLX scalar arrays, or
    as a mix of both; every variant must yield the same filled array.
    (Passing MLX arrays for the shape used to raise a TypeError.)
    """
    from pytensor.link.mlx.dispatch.core import mlx_funcify_Alloc

    # Minimal stand-in for an Apply node: only the attributes the
    # dispatcher might touch are provided.
    class FakeNode:
        def __init__(self):
            self.op = Alloc()
            self.inputs = None
            self.outputs = None

    alloc_func = mlx_funcify_Alloc(Alloc(), FakeNode())
    fill_value = mx.array(5.0)

    shape_variants = [
        (3, 4),                      # plain Python ints
        (mx.array(3), mx.array(4)),  # MLX arrays (previously failed)
        (3, mx.array(4)),            # mixed int / MLX array
    ]
    for rows, cols in shape_variants:
        out = alloc_func(fill_value, rows, cols)
        assert out.shape == (3, 4)
        assert float(out[0, 0]) == 5.0
def test_alloc_pytensor_integration():
    """Compile a graph containing a constant-shape ``pt.alloc`` under MLX mode."""
    scalar_in = pt.scalar("x", dtype="float32")
    allocated = pt.alloc(scalar_in, 3, 4)

    fn = pytensor.function([scalar_in], allocated, mode="MLX")
    out = fn(5.0)

    # Expect a (3, 4) array uniformly filled with the scalar input.
    assert out.shape == (3, 4)
    assert float(out[0, 0]) == 5.0
def test_alloc_compilation_limitation():
    """Dynamic-shape Alloc works eagerly but must raise a helpful error when compiled."""
    # Graph: fill an (s1, s2) array with x, where both dims are runtime inputs.
    x = pt.scalar("x", dtype="float32")
    s1 = pt.scalar("s1", dtype="int64")
    s2 = pt.scalar("s2", dtype="int64")
    result = pt.alloc(x, s1, s2)

    # Non-compiled MLX mode handles concrete dynamic shapes fine.
    eager_fn = pytensor.function([x, s1, s2], result, mode=mlx_mode_no_compile)
    eager_out = eager_fn(5.0, 3, 4)
    assert eager_out.shape == (3, 4)
    assert np.allclose(eager_out, 5.0)

    # Under the compiling mode the same graph must fail with an actionable message.
    compiled_fn = pytensor.function([x, s1, s2], result, mode=compile_mode)
    with pytest.raises(ValueError) as exc_info:
        compiled_fn(5.0, 3, 4)

    message = str(exc_info.value)
    expected_fragments = (
        "MLX compilation limitation",
        "Alloc operations with dynamic shapes",
        "cannot be used inside compiled functions",
        "Workarounds:",
        "Avoid using Alloc with dynamic shapes in compiled contexts",
        "Use static shapes when possible",
        "Move Alloc operations outside compiled functions",
    )
    for fragment in expected_fragments:
        assert fragment in message
def test_alloc_static_shapes_compilation():
    """Static-shape Alloc must give identical results with and without compilation."""
    x = pt.scalar("x", dtype="float32")
    # Constant (3, 4) shape: safe in both execution modes.
    graph = pt.alloc(x, 3, 4)

    # Evaluate the same graph under the non-compiled and compiled MLX modes.
    outputs = [
        pytensor.function([x], graph, mode=mode)(5.0)
        for mode in (mlx_mode_no_compile, compile_mode)
    ]

    for out in outputs:
        assert out.shape == (3, 4)
        assert np.allclose(out, 5.0)
    # Both execution paths must agree exactly on the result.
    assert np.allclose(outputs[0], outputs[1])

0 commit comments

Comments
 (0)