|
| 1 | +""" |
| 2 | +Basic tests for the MLX backend. |
| 3 | +""" |
1 | 4 | from collections.abc import Callable, Iterable
|
2 | 5 | from functools import partial
|
3 | 6 |
|
|
13 | 16 | from pytensor.link.mlx import MLXLinker
|
14 | 17 | from pytensor.link.mlx.dispatch.core import (
|
15 | 18 | mlx_funcify_Alloc,
|
16 |
| - mlx_funcify_ScalarFromTensor, |
17 | 19 | )
|
18 | 20 | from pytensor.tensor.basic import Alloc
|
19 | 21 |
|
@@ -87,53 +89,102 @@ def compare_mlx_and_py(
|
87 | 89 | return pytensor_mlx_fn, mlx_res
|
88 | 90 |
|
89 | 91 |
|
90 |
def test_scalar_from_tensor_matrix_indexing():
    """ScalarFromTensor should handle 0-d tensors produced by matrix indexing.

    Indexing a matrix with two integer indices yields a 0-d tensor — a
    common real-world source of ScalarFromTensor nodes.
    """
    mat = pt.matrix("matrix", dtype="float32")
    top_left = mat[0, 0]  # indexing with two ints -> 0-d tensor

    fn = pytensor.function([mat], top_left, mode="MLX")

    data = np.array([[42.0, 1.0], [2.0, 3.0]], dtype=np.float32)
    out = fn(data)

    # The extracted element comes back as an MLX array wrapping 42.0.
    assert float(out) == 42.0
    assert isinstance(out, mx.array)
def test_scalar_from_tensor_reduction_operations():
    """ScalarFromTensor should handle 0-d tensors produced by reductions."""
    # Sum over a vector collapses to a 0-d tensor.
    vec = pt.vector("vector", dtype="float32")
    total = pt.sum(vec)

    sum_fn = pytensor.function([vec], total, mode="MLX")
    vec_data = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
    assert float(sum_fn(vec_data)) == 10.0

    # Mean over a matrix also collapses to a 0-d tensor.
    mat = pt.matrix("matrix", dtype="float32")
    avg = pt.mean(mat)

    mean_fn = pytensor.function([mat], avg, mode="MLX")
    mat_data = np.array([[2.0, 4.0], [6.0, 8.0]], dtype=np.float32)
    assert float(mean_fn(mat_data)) == 5.0
def test_scalar_from_tensor_conditional_operations():
    """ScalarFromTensor should handle 0-d outputs of conditional graphs."""
    x = pt.scalar("x", dtype="float32")
    y = pt.scalar("y", dtype="float32")

    # switch(x > y, x, y) computes the elementwise max; on scalar inputs
    # the result is a 0-d tensor.
    larger = pt.switch(x > y, x, y)

    fn = pytensor.function([x, y], larger, mode="MLX")

    # Exercise both branches of the switch.
    assert float(fn(5.0, 3.0)) == 5.0
    assert float(fn(2.0, 7.0)) == 7.0
def test_scalar_from_tensor_multiple_dtypes():
    """ScalarFromTensor should handle a range of input dtypes."""
    # dtype -> (input vector, expected max); max() collapses to 0-d.
    cases = {
        "float32": (np.array([1.5, 3.7, 2.1], dtype="float32"), 3.7),
        "int32": (np.array([10, 30, 20], dtype="int32"), 30),
        "int64": (np.array([10, 30, 20], dtype="int64"), 30),
    }

    for dtype, (data, expected) in cases.items():
        x = pt.vector("x", dtype=dtype)
        largest = pt.max(x)

        fn = pytensor.function([x], largest, mode="MLX", allow_input_downcast=True)

        # Tolerance absorbs float32 rounding of the expected literal.
        assert abs(float(fn(data)) - expected) < 1e-5
def test_scalar_from_tensor_pytensor_integration():
    """Exercise ScalarFromTensor through a full compiled PyTensor graph.

    A symbolic (non-constant) input is used so the MLX backend genuinely
    executes the ScalarFromTensor op instead of constant-folding it away
    at compile time.
    """
    x = pt.scalar("x", dtype="int64")

    # Building the op explicitly forces it into the compiled graph.
    as_scalar = pt.scalar_from_tensor(x)

    fn = pytensor.function([x], as_scalar, mode="MLX")
    out = fn(42)

    # The MLX backend should hand back an mx.array holding the value.
    assert out == 42
    assert isinstance(out, mx.array)
139 | 190 | def test_alloc_with_different_shape_types():
|
|
0 commit comments