Skip to content

Commit 67bb8da

Browse files
committed
correct the function name
1 parent 516b595 commit 67bb8da

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

pytensor/link/mlx/dispatch/math.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import mlx.core as mx
22

33
from pytensor.link.mlx.dispatch import mlx_funcify
4-
4+
from pytensor.scalar.basic import Add, Cos, Exp, Log, Mul, Sin, Sub
55
from pytensor.tensor.elemwise import Elemwise
66
from pytensor.tensor.math import Dot
7-
from pytensor.scalar.basic import Add, Mul, Sub, Exp, Log, Sin, Cos
87

98

109
@mlx_funcify.register(Dot)
@@ -14,42 +13,50 @@ def dot(x, y):
1413

1514
return dot
1615

16+
1717
@mlx_funcify.register(Elemwise)
1818
def mlx_funcify_Elemwise(op, **kwargs):
1919
if isinstance(op.scalar_op, Add):
20+
2021
def add(x, y):
2122
return mx.add(x, y)
2223

2324
return add
2425
elif isinstance(op.scalar_op, Sub):
26+
2527
def sub(x, y):
26-
return mx.sub(x, y)
28+
return mx.subtract(x, y)
2729

2830
return sub
2931
elif isinstance(op.scalar_op, Mul):
32+
3033
def mul(x, y):
31-
return mx.mul(x, y)
34+
return mx.multiply(x, y)
3235

3336
return mul
3437
elif isinstance(op.scalar_op, Exp):
38+
3539
def exp(x):
3640
return mx.exp(x)
3741

3842
return exp
3943
elif isinstance(op.scalar_op, Log):
44+
4045
def log(x):
4146
return mx.log(x)
4247

4348
return log
4449
elif isinstance(op.scalar_op, Sin):
50+
4551
def sin(x):
4652
return mx.sin(x)
47-
53+
4854
return sin
4955
elif isinstance(op.scalar_op, Cos):
56+
5057
def cos(x):
5158
return mx.cos(x)
5259

5360
return cos
5461
else:
55-
raise NotImplementedError(f"MLX does not support {op.scalar_op}")
62+
raise NotImplementedError(f"MLX does not support {op.scalar_op}")

0 commit comments

Comments
 (0)