Skip to content

Commit 9527f6c

Browse files
cetagostinijessegrabowski
authored andcommitted
Changing synth test
1 parent 481e3ad commit 9527f6c

File tree

1 file changed

+81
-30
lines changed

1 file changed

+81
-30
lines changed

tests/link/mlx/test_basic.py

Lines changed: 81 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
"""
2+
Basic tests for the MLX backend.
3+
"""
14
from collections.abc import Callable, Iterable
25
from functools import partial
36

@@ -13,7 +16,6 @@
1316
from pytensor.link.mlx import MLXLinker
1417
from pytensor.link.mlx.dispatch.core import (
1518
mlx_funcify_Alloc,
16-
mlx_funcify_ScalarFromTensor,
1719
)
1820
from pytensor.tensor.basic import Alloc
1921

@@ -87,53 +89,102 @@ def compare_mlx_and_py(
8789
return pytensor_mlx_fn, mlx_res
8890

8991

90-
def test_scalar_from_tensor_with_scalars():
91-
"""Test ScalarFromTensor works with both MLX arrays and Python/NumPy scalars.
92+
def test_scalar_from_tensor_matrix_indexing():
93+
"""Test ScalarFromTensor with matrix element extraction."""
94+
# Matrix element extraction is a common real-world scenario
95+
matrix = pt.matrix("matrix", dtype="float32")
96+
element = matrix[0, 0] # Creates 0-d tensor
9297

93-
This addresses the AttributeError that occurred when Python integers were
94-
passed to ScalarFromTensor instead of MLX arrays.
95-
"""
96-
scalar_from_tensor_func = mlx_funcify_ScalarFromTensor(None)
98+
f = pytensor.function([matrix], element, mode="MLX")
9799

98-
# Test with MLX array
99-
mlx_array = mx.array([42])
100-
result = scalar_from_tensor_func(mlx_array)
101-
assert result == 42
100+
test_matrix = np.array([[42.0, 1.0], [2.0, 3.0]], dtype=np.float32)
101+
result = f(test_matrix)
102+
103+
assert float(result) == 42.0
104+
assert isinstance(result, mx.array)
105+
106+
107+
def test_scalar_from_tensor_reduction_operations():
108+
"""Test ScalarFromTensor with reduction operations that produce scalars."""
109+
# Test vector sum reduction
110+
vector = pt.vector("vector", dtype="float32")
111+
sum_result = pt.sum(vector)
112+
113+
f = pytensor.function([vector], sum_result, mode="MLX")
114+
test_vector = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
115+
result = f(test_vector)
116+
117+
assert float(result) == 10.0
118+
119+
# Test matrix mean reduction
120+
matrix = pt.matrix("matrix", dtype="float32")
121+
mean_result = pt.mean(matrix)
122+
123+
f2 = pytensor.function([matrix], mean_result, mode="MLX")
124+
test_matrix = np.array([[2.0, 4.0], [6.0, 8.0]], dtype=np.float32)
125+
result = f2(test_matrix)
126+
127+
assert float(result) == 5.0
102128

103-
# Test with Python int (this used to fail)
104-
python_int = 42
105-
result = scalar_from_tensor_func(python_int)
106-
assert result == 42
107129

108-
# Test with Python float
109-
python_float = 3.14
110-
result = scalar_from_tensor_func(python_float)
111-
assert abs(result - 3.14) < 1e-6
130+
def test_scalar_from_tensor_conditional_operations():
131+
"""Test ScalarFromTensor with conditional operations."""
132+
x = pt.scalar("x", dtype="float32")
133+
y = pt.scalar("y", dtype="float32")
134+
135+
# Switch operation may create 0-d tensors
136+
max_val = pt.switch(x > y, x, y)
137+
138+
f = pytensor.function([x, y], max_val, mode="MLX")
112139

113-
# Test with NumPy scalar
114-
numpy_scalar = np.int32(123)
115-
result = scalar_from_tensor_func(numpy_scalar)
116-
assert result == 123
140+
# Test both branches
141+
result1 = f(5.0, 3.0)
142+
assert float(result1) == 5.0
117143

118-
# Test with NumPy float scalar
119-
numpy_float = np.float32(2.71)
120-
result = scalar_from_tensor_func(numpy_float)
121-
assert abs(result - 2.71) < 1e-6
144+
result2 = f(2.0, 7.0)
145+
assert float(result2) == 7.0
146+
147+
148+
def test_scalar_from_tensor_multiple_dtypes():
149+
"""Test ScalarFromTensor with different data types."""
150+
# Test different dtypes that might require scalar extraction
151+
for dtype in ["float32", "int32", "int64"]:
152+
x = pt.vector("x", dtype=dtype)
153+
# Use max reduction to create 0-d tensor
154+
max_val = pt.max(x)
155+
156+
f = pytensor.function([x], max_val, mode="MLX", allow_input_downcast=True)
157+
158+
if dtype.startswith("float"):
159+
test_data = np.array([1.5, 3.7, 2.1], dtype=dtype)
160+
expected = 3.7
161+
else:
162+
test_data = np.array([10, 30, 20], dtype=dtype)
163+
expected = 30
164+
165+
result = f(test_data)
166+
assert abs(float(result) - expected) < 1e-5
122167

123168

124169
def test_scalar_from_tensor_pytensor_integration():
125-
"""Test ScalarFromTensor in a PyTensor graph context."""
170+
"""Test ScalarFromTensor in a complete PyTensor graph context.
171+
172+
This test uses symbolic variables (not constants) to ensure the MLX backend
173+
actually executes the ScalarFromTensor operation rather than having it
174+
optimized away during compilation.
175+
"""
126176
# Create a symbolic scalar input to actually test MLX execution
127177
x = pt.scalar("x", dtype="int64")
128178

129-
# Apply ScalarFromTensor
179+
# Apply ScalarFromTensor - this creates a graph that forces execution
130180
scalar_result = pt.scalar_from_tensor(x)
131181

132-
# Create function and test
182+
# Create function and test with actual MLX backend execution
133183
f = pytensor.function([x], scalar_result, mode="MLX")
134184
result = f(42)
135185

136186
assert result == 42
187+
assert isinstance(result, mx.array)
137188

138189

139190
def test_alloc_with_different_shape_types():

0 commit comments

Comments
 (0)