Skip to content

Commit 8300fd4

Browse files
cetagostini authored and jessegrabowski committed
Adding more operations for complex model
1 parent 3d144db commit 8300fd4

File tree

7 files changed

+303
-9
lines changed

7 files changed

+303
-9
lines changed

pytensor/link/mlx/dispatch/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def tensor_from_scalar(x):
177177
@mlx_funcify.register(ScalarFromTensor)
def mlx_funcify_ScalarFromTensor(op, **kwargs):
    """Build the MLX callable that implements ``ScalarFromTensor``.

    The returned function wraps its argument in ``mx.array`` before
    reshaping, so the input does not need to provide a ``reshape`` method
    of its own.
    """

    def scalar_from_tensor(x):
        # Flatten to 1-D, then pick the first (only) element.
        flattened = mx.array(x).reshape(-1)
        return flattened[0]

    return scalar_from_tensor
183183

pytensor/link/mlx/dispatch/math.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
Cast,
1919
Cos,
2020
Exp,
21+
IntDiv,
2122
Invert,
23+
IsNan,
2224
Log,
2325
Log1p,
2426
Mul,
@@ -34,7 +36,7 @@
3436
Switch,
3537
TrueDiv,
3638
)
37-
from pytensor.scalar.math import Sigmoid
39+
from pytensor.scalar.math import Erfc, Erfcx, Sigmoid, Softplus
3840
from pytensor.tensor.elemwise import Elemwise
3941
from pytensor.tensor.math import Dot
4042

@@ -113,6 +115,14 @@ def true_div(x, y):
113115
return true_div
114116

115117

118+
@mlx_funcify_Elemwise_scalar_op.register(IntDiv)
def _(scalar_op):
    """Return the MLX callable used for the ``IntDiv`` scalar op."""

    def int_div(x, y):
        # Floor division is delegated entirely to MLX.
        quotient = mx.floor_divide(x, y)
        return quotient

    return int_div
124+
125+
116126
@mlx_funcify_Elemwise_scalar_op.register(Pow)
117127
def _(scalar_op):
118128
def pow(x, y):
@@ -309,11 +319,51 @@ def sigmoid(x):
309319
@mlx_funcify_Elemwise_scalar_op.register(Invert)
def _(scalar_op):
    """Return the MLX callable used for the ``Invert`` scalar op."""

    def invert(x):
        # Use the explicit MLX function rather than the ``~`` operator.
        inverted = mx.bitwise_invert(x)
        return inverted

    return invert
315325

316326

327+
@mlx_funcify_Elemwise_scalar_op.register(IsNan)
def _(scalar_op):
    """Return the MLX callable used for the ``IsNan`` scalar op."""

    def isnan(x):
        # The NaN test is delegated entirely to MLX.
        nan_mask = mx.isnan(x)
        return nan_mask

    return isnan
333+
334+
335+
@mlx_funcify_Elemwise_scalar_op.register(Erfc)
def _(scalar_op):
    """Return the MLX callable used for the ``Erfc`` scalar op.

    The complementary error function is formed as ``1 - erf(x)``.
    """

    def erfc(x):
        # NOTE(review): subtracting from 1.0 loses relative precision once
        # erf(x) is close to 1 -- confirm this is acceptable for large x.
        erf_value = mx.erf(x)
        return 1.0 - erf_value

    return erfc
341+
342+
343+
@mlx_funcify_Elemwise_scalar_op.register(Erfcx)
def _(scalar_op):
    """Return the MLX callable used for the ``Erfcx`` scalar op.

    The scaled complementary error function is formed directly from its
    definition, ``exp(x**2) * (1 - erf(x))``.
    """

    def erfcx(x):
        # NOTE(review): for large x, exp(x * x) can overflow while
        # 1 - erf(x) rounds to zero, so this direct form may be inaccurate
        # there -- confirm the input range callers rely on.
        scale = mx.exp(x * x)
        complement = 1.0 - mx.erf(x)
        return scale * complement

    return erfcx
349+
350+
351+
@mlx_funcify_Elemwise_scalar_op.register(Softplus)
def _(scalar_op):
    """Return the MLX callable used for the ``Softplus`` scalar op.

    Softplus is ``log(1 + exp(x))``; it is evaluated piecewise for numerical
    stability, following the same logic as the original PyTensor
    implementation:

    * ``x < -37.0``  -> ``exp(x)``
    * ``x < 18.0``   -> ``log1p(exp(x))``
    * ``x < 33.3``   -> ``x + exp(-x)``
    * otherwise      -> ``x``
    """

    def softplus(x):
        # Build the branches from the innermost (largest x) outwards; each
        # ``where`` only selects between already-computed candidates.
        upper_branch = mx.where(x < 33.3, x + mx.exp(-x), x)
        middle_branch = mx.where(x < 18.0, mx.log1p(mx.exp(x)), upper_branch)
        return mx.where(x < -37.0, mx.exp(x), middle_branch)

    return softplus
365+
366+
317367
@mlx_funcify.register(Elemwise)
318368
def mlx_funcify_Elemwise(op, node, **kwargs):
319369
# Dispatch to the appropriate scalar op handler

pytensor/link/mlx/dispatch/shape.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import mlx.core as mx
2+
13
from pytensor.link.mlx.dispatch.basic import mlx_funcify
2-
from pytensor.tensor.shape import Shape, Shape_i, SpecifyShape
4+
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
35

46

57
@mlx_funcify.register(Shape)
@@ -30,3 +32,11 @@ def shape_i(x):
3032
return x.shape[op.i]
3133

3234
return shape_i
35+
36+
37+
@mlx_funcify.register(Reshape)
def mlx_funcify_Reshape(op, **kwargs):
    """Build the MLX callable that implements ``Reshape``."""

    def reshape(x, shp):
        # ``shp`` is forwarded untouched.
        # NOTE(review): presumably it must already be something
        # ``mx.reshape`` accepts as a shape -- confirm for array-valued shapes.
        reshaped = mx.reshape(x, shp)
        return reshaped

    return reshape

tests/link/mlx/test_basic.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
from collections.abc import Callable, Iterable
22
from functools import partial
33

4+
import mlx.core as mx
45
import numpy as np
5-
import pytest
66

7+
import pytensor
8+
from pytensor import tensor as pt
79
from pytensor.compile.function import function
810
from pytensor.compile.mode import MLX, Mode
911
from pytensor.graph import RewriteDatabaseQuery
1012
from pytensor.graph.basic import Variable
1113
from pytensor.link.mlx import MLXLinker
14+
from pytensor.link.mlx.dispatch.core import mlx_funcify_ScalarFromTensor
1215

1316

14-
mx = pytest.importorskip("mlx.core")
15-
1617
optimizer = RewriteDatabaseQuery(include=["mlx"], exclude=MLX._optimizer.exclude)
1718
mlx_mode = Mode(linker=MLXLinker(), optimizer=optimizer)
1819
py_mode = Mode(linker="py", optimizer=None)
@@ -78,3 +79,52 @@ def compare_mlx_and_py(
7879
assert_fn(mlx_res, py_res)
7980

8081
return pytensor_mlx_fn, mlx_res
82+
83+
84+
def test_scalar_from_tensor_with_scalars():
    """Check ScalarFromTensor with MLX arrays and with Python/NumPy scalars.

    Guards against the AttributeError that occurred when a Python integer,
    rather than an MLX array, was passed to ScalarFromTensor.
    """
    convert = mlx_funcify_ScalarFromTensor(None)

    # (input, expected value, compare exactly?) -- floats use a tolerance.
    cases = [
        (mx.array([42]), 42, True),  # MLX array
        (42, 42, True),  # Python int (this used to fail)
        (3.14, 3.14, False),  # Python float
        (np.int32(123), 123, True),  # NumPy integer scalar
        (np.float32(2.71), 2.71, False),  # NumPy float scalar
    ]

    for value, expected, exact in cases:
        result = convert(value)
        if exact:
            assert result == expected
        else:
            assert abs(result - expected) < 1e-6
116+
117+
118+
def test_scalar_from_tensor_pytensor_integration():
    """Compile and run a graph containing ScalarFromTensor in MLX mode."""
    # A constant 0-d tensor is the input to ScalarFromTensor.
    scalar_tensor = pt.as_tensor_variable(42)
    as_scalar = pt.scalar_from_tensor(scalar_tensor)

    # No graph inputs: the compiled function just evaluates the constant.
    compiled = pytensor.function([], as_scalar, mode="MLX")

    assert compiled() == 42

tests/link/mlx/test_elemwise.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import pytest
23

34
import pytensor.tensor as pt
@@ -11,3 +12,39 @@ def test_input(op) -> None:
1112
x_test = mx.array([1.0, 2.0, 3.0])
1213

1314
compare_mlx_and_py([x], out, [x_test])
15+
16+
17+
def test_new_elemwise_operations() -> None:
    """Exercise IntDiv, IsNan, Erfc, Erfcx and Softplus inside larger elemwise graphs."""
    # IntDiv combined with an addition.
    numerator = pt.vector("x")
    denominator = pt.vector("y")
    int_div_graph = pt.int_div(numerator, denominator) + 1
    int_div_inputs = [mx.array([10.0, 15.0, 20.0]), mx.array([3.0, 4.0, 6.0])]
    compare_mlx_and_py([numerator, denominator], int_div_graph, int_div_inputs)

    # IsNan followed by a cast and a multiplication.
    nan_input = pt.vector("z")
    isnan_graph = pt.isnan(nan_input).astype("float32") * 10
    compare_mlx_and_py([nan_input], isnan_graph, [mx.array([1.0, np.nan, 3.0])])

    # Erfc combined with a multiplication.
    erfc_input = pt.vector("w")
    erfc_graph = pt.erfc(erfc_input) * 2.0
    compare_mlx_and_py([erfc_input], erfc_graph, [mx.array([0.0, 0.5, 1.0])])

    # Erfcx combined with an addition.
    erfcx_input = pt.vector("v")
    erfcx_graph = pt.erfcx(erfcx_input) + 0.1
    compare_mlx_and_py([erfcx_input], erfcx_graph, [mx.array([0.0, 1.0, 2.0])])

    # Softplus combined with a subtraction.
    softplus_input = pt.vector("u")
    softplus_graph = pt.softplus(softplus_input) - 0.5
    compare_mlx_and_py(
        [softplus_input], softplus_graph, [mx.array([0.0, 1.0, -1.0])]
    )

tests/link/mlx/test_math.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def test_input(op) -> None:
7979
pytest.param(pt.eq, id="eq"),
8080
pytest.param(pt.neq, id="neq"),
8181
pytest.param(pt.true_div, id="true_div"),
82+
pytest.param(pt.int_div, id="int_div"),
8283
],
8384
)
8485
def test_elemwise_two_inputs(op) -> None:
@@ -90,6 +91,119 @@ def test_elemwise_two_inputs(op) -> None:
9091
compare_mlx_and_py([x, y], out, [x_test, y_test])
9192

9293

94+
def test_int_div_specific() -> None:
    """Compare MLX and Python results for int_div on hand-picked operands."""
    dividend = pt.vector("x")
    divisor = pt.vector("y")
    quotient = pt.int_div(dividend, divisor)

    # Positive and negative dividends that are not all multiples of the
    # divisor, so the floor-division rounding direction matters.
    dividend_values = mx.array([7.0, 8.0, 9.0, -7.0, -8.0])
    divisor_values = mx.array([3.0, 3.0, 3.0, 3.0, 3.0])

    compare_mlx_and_py(
        [dividend, divisor], quotient, [dividend_values, divisor_values]
    )
105+
106+
107+
def test_isnan() -> None:
    """Compare MLX and Python results for isnan on a vector containing NaNs."""
    values = pt.vector("x")
    nan_mask = pt.isnan(values)

    # Finite values, both signs of NaN, and both infinities.
    sample = mx.array([1.0, np.nan, 3.0, np.inf, -np.nan, 0.0, -np.inf])

    compare_mlx_and_py([values], nan_mask, [sample])
116+
117+
118+
def test_isnan_edge_cases() -> None:
    """Compare MLX and Python results for isnan on individual scalar inputs."""
    value = pt.scalar("x")
    nan_flag = pt.isnan(value)

    # Each case is fed through the same graph on its own.
    edge_values = [0.0, np.nan, np.inf, -np.inf, 1e-10, 1e10]

    for edge_value in edge_values:
        compare_mlx_and_py([value], nan_flag, [edge_value])
129+
130+
131+
def test_erfc() -> None:
    """Compare MLX and Python results for the complementary error function."""
    values = pt.vector("x")
    erfc_graph = pt.erfc(values)

    # Zero plus a spread of positive and negative arguments.
    sample = mx.array([0.0, 0.5, 1.0, -0.5, -1.0, 2.0, -2.0, 0.1])

    compare_mlx_and_py([values], erfc_graph, [sample])
140+
141+
142+
def test_erfc_extreme_values() -> None:
    """Compare MLX and Python erfc results where erfc approaches 0 or 2."""
    from functools import partial

    values = pt.vector("x")
    erfc_graph = pt.erfc(values)

    # Larger magnitudes, where erfc is close to its limits.
    sample = mx.array([-3.0, -2.5, 2.5, 3.0])

    # Looser tolerances: the two backends differ in numerical precision
    # for these arguments.
    loose_allclose = partial(np.testing.assert_allclose, rtol=1e-3, atol=1e-6)

    compare_mlx_and_py([values], erfc_graph, [sample], assert_fn=loose_allclose)
156+
157+
158+
def test_erfcx() -> None:
    """Compare MLX and Python results for the scaled complementary error function."""
    values = pt.vector("x")
    erfcx_graph = pt.erfcx(values)

    # Non-negative arguments only, where erfcx is most numerically stable.
    sample = mx.array([0.0, 0.5, 1.0, 1.5, 2.0, 2.5])

    compare_mlx_and_py([values], erfcx_graph, [sample])
167+
168+
169+
def test_erfcx_small_values() -> None:
    """Compare MLX and Python erfcx results for small positive arguments."""
    values = pt.vector("x")
    erfcx_graph = pt.erfcx(values)

    # Arguments close to zero.
    sample = mx.array([0.001, 0.01, 0.1, 0.2])

    compare_mlx_and_py([values], erfcx_graph, [sample])
178+
179+
180+
def test_softplus() -> None:
    """Compare MLX and Python results for softplus, i.e. log(1 + exp(x))."""
    values = pt.vector("x")
    softplus_graph = pt.softplus(values)

    # Moderate arguments of both signs.
    sample = mx.array([0.0, 1.0, 2.0, -1.0, -2.0, 10.0])

    compare_mlx_and_py([values], softplus_graph, [sample])
189+
190+
191+
def test_softplus_extreme_values() -> None:
    """Compare MLX and Python softplus results at extreme arguments (numerical stability)."""
    from functools import partial

    values = pt.vector("x")
    softplus_graph = pt.softplus(values)

    # Arguments chosen to hit the different branches of the implementation.
    sample = mx.array([-40.0, -50.0, 20.0, 30.0, 35.0, 50.0])

    # Looser tolerances: the two backends differ in numerical precision
    # for these arguments.
    loose_allclose = partial(np.testing.assert_allclose, rtol=1e-4, atol=1e-8)

    compare_mlx_and_py(
        [values], softplus_graph, [sample], assert_fn=loose_allclose
    )
205+
206+
93207
@pytest.mark.xfail(reason="Argmax not implemented yet")
94208
def test_mlx_max_and_argmax():
95209
# Test that a single output of a multi-output `Op` can be used as input to

0 commit comments

Comments
 (0)