11import mlx .core as mx
22
33from pytensor .link .mlx .dispatch import mlx_funcify
4- from pytensor .scalar .basic import Add , Cos , Exp , Log , Mul , Sin , Sub
4+ from pytensor .scalar .basic import (
5+ EQ ,
6+ GE ,
7+ GT ,
8+ LE ,
9+ LT ,
10+ NEQ ,
11+ Add ,
12+ Cos ,
13+ Exp ,
14+ Log ,
15+ Mul ,
16+ Pow ,
17+ Sin ,
18+ Sub ,
19+ Switch ,
20+ TrueDiv ,
21+ )
22+ from pytensor .scalar .math import Sigmoid
523from pytensor .tensor .elemwise import Elemwise
624from pytensor .tensor .math import Dot
7- from pytensor .scalar .math import Sigmoid
8- from pytensor .scalar .basic import Add , Mul , Sub , Exp , Log , Sin , Cos , LE , LT , GE , GT , EQ , NEQ
925
1026
1127@mlx_funcify .register (Dot )
@@ -19,6 +35,7 @@ def dot(x, y):
1935@mlx_funcify .register (Elemwise )
2036def mlx_funcify_Elemwise (op , ** kwargs ):
2137 if isinstance (op .scalar_op , Add ):
38+
2239 def add (* args ):
2340 result = args [0 ]
2441 for arg in args [1 :]:
@@ -33,6 +50,7 @@ def sub(x, y):
3350
3451 return sub
3552 elif isinstance (op .scalar_op , Mul ):
53+
3654 def mul (* args ):
3755 result = args [0 ]
3856 for arg in args [1 :]:
@@ -65,39 +83,64 @@ def cos(x):
6583
6684 return cos
6785 elif isinstance (op .scalar_op , Sigmoid ):
86+
6887 def sigmoid (x ):
6988 return mx .sigmoid (x )
7089
7190 return sigmoid
7291 elif isinstance (op .scalar_op , LE ):
92+
7393 def le (x , y ):
7494 return mx .less_equal (x , y )
7595
7696 return le
7797 elif isinstance (op .scalar_op , LT ):
98+
7899 def lt (x , y ):
79100 return mx .less (x , y )
80101
81102 return lt
82103 elif isinstance (op .scalar_op , GE ):
104+
83105 def ge (x , y ):
84106 return mx .greater_equal (x , y )
85107
86108 return ge
87109 elif isinstance (op .scalar_op , GT ):
110+
88111 def gt (x , y ):
89112 return mx .greater (x , y )
90113
91114 return gt
92115 elif isinstance (op .scalar_op , EQ ):
116+
93117 def eq (x , y ):
94118 return mx .equal (x , y )
95119
96120 return eq
97121 elif isinstance (op .scalar_op , NEQ ):
122+
98123 def neq (x , y ):
99124 return mx .not_equal (x , y )
100125
101126 return neq
127+ elif isinstance (op .scalar_op , Switch ):
128+
129+ def switch (cond , x , y ):
130+ return mx .where (cond , x , y )
131+
132+ return switch
133+ elif isinstance (op .scalar_op , Pow ):
134+
135+ def pow (x , y ):
136+ return mx .power (x , y )
137+
138+ return pow
139+ elif isinstance (op .scalar_op , TrueDiv ):
140+
141+ def true_div (x , y ):
142+ return mx .divide (x , y )
143+
144+ return true_div
102145 else :
103146 raise NotImplementedError (f"MLX does not support { op .scalar_op } " )
0 commit comments