11import mlx .core as mx
22
33from pytensor .link .mlx .dispatch import mlx_funcify
4-
4+ from pytensor . scalar . basic import Add , Cos , Exp , Log , Mul , Sin , Sub
55from pytensor .tensor .elemwise import Elemwise
66from pytensor .tensor .math import Dot
7- from pytensor .scalar .basic import Add , Mul , Sub , Exp , Log , Sin , Cos
87
98
109@mlx_funcify .register (Dot )
@@ -14,42 +13,50 @@ def dot(x, y):
1413
1514 return dot
1615
16+
1717@mlx_funcify .register (Elemwise )
1818def mlx_funcify_Elemwise (op , ** kwargs ):
1919 if isinstance (op .scalar_op , Add ):
20+
2021 def add (x , y ):
2122 return mx .add (x , y )
2223
2324 return add
2425 elif isinstance (op .scalar_op , Sub ):
26+
2527 def sub (x , y ):
26- return mx .sub (x , y )
28+ return mx .subtract (x , y )
2729
2830 return sub
2931 elif isinstance (op .scalar_op , Mul ):
32+
3033 def mul (x , y ):
31- return mx .mul (x , y )
34+ return mx .multiply (x , y )
3235
3336 return mul
3437 elif isinstance (op .scalar_op , Exp ):
38+
3539 def exp (x ):
3640 return mx .exp (x )
3741
3842 return exp
3943 elif isinstance (op .scalar_op , Log ):
44+
4045 def log (x ):
4146 return mx .log (x )
4247
4348 return log
4449 elif isinstance (op .scalar_op , Sin ):
50+
4551 def sin (x ):
4652 return mx .sin (x )
47-
53+
4854 return sin
4955 elif isinstance (op .scalar_op , Cos ):
56+
5057 def cos (x ):
5158 return mx .cos (x )
5259
5360 return cos
5461 else :
55- raise NotImplementedError (f"MLX does not support { op .scalar_op } " )
62+ raise NotImplementedError (f"MLX does not support { op .scalar_op } " )
0 commit comments