Skip to content

Commit 5c97bc8

Browse files
committed
fix for all
1 parent a19cbc8 commit 5c97bc8

File tree

3 files changed

+15
-3
lines changed

3 files changed

+15
-3
lines changed

pytensor/link/mlx/dispatch/elemwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def prod(x):
3939
elif isinstance(op.scalar_op, AND):
4040

4141
def all(x):
42-
return mx.all(a=x, axis=op.axis)
42+
return x.all(axis=op.axis)
4343

4444
return all
4545
elif isinstance(op.scalar_op, OR):

pytensor/link/mlx/dispatch/math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
Cos,
1818
Exp,
1919
Log,
20+
Log1p,
2021
Mul,
2122
Neg,
2223
Pow,
@@ -29,7 +30,6 @@
2930
Sub,
3031
Switch,
3132
TrueDiv,
32-
Log1p
3333
)
3434
from pytensor.scalar.math import Sigmoid
3535
from pytensor.tensor.elemwise import Elemwise
@@ -199,7 +199,7 @@ def neg(x):
199199
elif isinstance(op.scalar_op, AND):
200200

201201
def all(x):
202-
return mx.all(a=x, axis=op.axis)
202+
return x.all(axis=op.axis)
203203

204204
return all
205205
elif isinstance(op.scalar_op, OR):

tests/link/mlx/test_elemwise.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import pytensor.tensor as pt
2+
from tests.link.mlx.test_basic import compare_mlx_and_py, mx
3+
4+
5+
def test_all() -> None:
6+
x = pt.vector("x")
7+
8+
out = pt.all(x > 0)
9+
10+
x_test = mx.array([-1.0, 2.0, 3.0])
11+
12+
compare_mlx_and_py([x], out, [x_test])

0 commit comments

Comments
 (0)