File tree Expand file tree Collapse file tree 3 files changed +15
-3
lines changed
pytensor/link/mlx/dispatch Expand file tree Collapse file tree 3 files changed +15
-3
lines changed Original file line number Diff line number Diff line change @@ -39,7 +39,7 @@ def prod(x):
3939 elif isinstance (op .scalar_op , AND ):
4040
4141 def all (x ):
42- return mx .all (a = x , axis = op .axis )
42+ return x .all (axis = op .axis )
4343
4444 return all
4545 elif isinstance (op .scalar_op , OR ):
Original file line number Diff line number Diff line change 1717 Cos ,
1818 Exp ,
1919 Log ,
20+ Log1p ,
2021 Mul ,
2122 Neg ,
2223 Pow ,
2930 Sub ,
3031 Switch ,
3132 TrueDiv ,
32- Log1p
3333)
3434from pytensor .scalar .math import Sigmoid
3535from pytensor .tensor .elemwise import Elemwise
@@ -199,7 +199,7 @@ def neg(x):
199199 elif isinstance (op .scalar_op , AND ):
200200
201201 def all (x ):
202- return mx .all (a = x , axis = op .axis )
202+ return x .all (axis = op .axis )
203203
204204 return all
205205 elif isinstance (op .scalar_op , OR ):
Original file line number Diff line number Diff line change 1+ import pytensor .tensor as pt
2+ from tests .link .mlx .test_basic import compare_mlx_and_py , mx
3+
4+
5+ def test_all () -> None :
6+ x = pt .vector ("x" )
7+
8+ out = pt .all (x > 0 )
9+
10+ x_test = mx .array ([- 1.0 , 2.0 , 3.0 ])
11+
12+ compare_mlx_and_py ([x ], out , [x_test ])
You can’t perform that action at this time.
0 commit comments