
Commit fafedd6

Take your tomato, William

1 parent 877d79f

4 files changed: +51 additions, −7 deletions

pytensor/link/mlx/dispatch/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -7,4 +7,4 @@
 import pytensor.link.mlx.dispatch.shape
 import pytensor.link.mlx.dispatch.subtensor
 import pytensor.link.mlx.dispatch.core
-# isort: on
+# isort: on

pytensor/link/mlx/dispatch/elemwise.py

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 import mlx.core as mx
 
 from pytensor.link.mlx.dispatch.basic import mlx_funcify
-from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum, Switch
+from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum
 from pytensor.tensor.elemwise import CAReduce, DimShuffle
 from pytensor.tensor.special import Softmax, SoftmaxGrad

pytensor/link/mlx/dispatch/math.py

Lines changed: 46 additions & 3 deletions

@@ -1,11 +1,27 @@
 import mlx.core as mx
 
 from pytensor.link.mlx.dispatch import mlx_funcify
-from pytensor.scalar.basic import Add, Cos, Exp, Log, Mul, Sin, Sub
+from pytensor.scalar.basic import (
+    EQ,
+    GE,
+    GT,
+    LE,
+    LT,
+    NEQ,
+    Add,
+    Cos,
+    Exp,
+    Log,
+    Mul,
+    Pow,
+    Sin,
+    Sub,
+    Switch,
+    TrueDiv,
+)
+from pytensor.scalar.math import Sigmoid
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.math import Dot
-from pytensor.scalar.math import Sigmoid
-from pytensor.scalar.basic import Add, Mul, Sub, Exp, Log, Sin, Cos, LE, LT, GE, GT, EQ, NEQ
 
 
 @mlx_funcify.register(Dot)
@@ -19,6 +35,7 @@ def dot(x, y):
 @mlx_funcify.register(Elemwise)
 def mlx_funcify_Elemwise(op, **kwargs):
     if isinstance(op.scalar_op, Add):
+
         def add(*args):
             result = args[0]
             for arg in args[1:]:
@@ -33,6 +50,7 @@ def sub(x, y):
 
         return sub
     elif isinstance(op.scalar_op, Mul):
+
         def mul(*args):
             result = args[0]
             for arg in args[1:]:
@@ -65,39 +83,64 @@ def cos(x):
 
         return cos
     elif isinstance(op.scalar_op, Sigmoid):
+
         def sigmoid(x):
             return mx.sigmoid(x)
 
         return sigmoid
     elif isinstance(op.scalar_op, LE):
+
         def le(x, y):
             return mx.less_equal(x, y)
 
         return le
     elif isinstance(op.scalar_op, LT):
+
         def lt(x, y):
             return mx.less(x, y)
 
         return lt
     elif isinstance(op.scalar_op, GE):
+
         def ge(x, y):
             return mx.greater_equal(x, y)
 
         return ge
     elif isinstance(op.scalar_op, GT):
+
         def gt(x, y):
             return mx.greater(x, y)
 
         return gt
     elif isinstance(op.scalar_op, EQ):
+
         def eq(x, y):
             return mx.equal(x, y)
 
         return eq
     elif isinstance(op.scalar_op, NEQ):
+
         def neq(x, y):
             return mx.not_equal(x, y)
 
         return neq
+    elif isinstance(op.scalar_op, Switch):
+
+        def switch(cond, x, y):
+            return mx.where(cond, x, y)
+
+        return switch
+    elif isinstance(op.scalar_op, Pow):
+
+        def pow(x, y):
+            return mx.power(x, y)
+
+        return pow
+    elif isinstance(op.scalar_op, TrueDiv):
+
+        def true_div(x, y):
+            return mx.divide(x, y)
+
+        return true_div
     else:
         raise NotImplementedError(f"MLX does not support {op.scalar_op}")
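
For context, a minimal sketch (not part of the diff) of a graph that exercises the newly dispatched ops: switch lowers to mx.where, ** to mx.power, / to mx.divide, and > to mx.greater. The mode name "MLX" is an assumption about how this backend is registered.

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.vector("y")

# Elemwise graph covering Switch, GT, Pow, and TrueDiv
out = pt.switch(x > y, x**2, x / y)

# Hypothetical: assumes the MLX linker is exposed under the mode name "MLX"
fn = pytensor.function([x, y], out, mode="MLX")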

pytensor/link/mlx/dispatch/shape.py

Lines changed: 3 additions & 2 deletions

@@ -1,5 +1,6 @@
-from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
 from pytensor.link.mlx.dispatch.basic import mlx_funcify
+from pytensor.tensor.shape import SpecifyShape
+
 
 @mlx_funcify.register(SpecifyShape)
 def mlx_funcify_SpecifyShape(op, node, **kwargs):
@@ -12,4 +13,4 @@ def specifyshape(x, *shape):
            raise ValueError(f"Invalid shape: Expected {shape} but got {x.shape}")
        return x
 
-    return specifyshape
+    return specifyshape
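
As a usage note (not part of the diff), SpecifyShape nodes typically come from pytensor.tensor.specify_shape; a minimal sketch of the runtime check the dispatched closure performs:

import pytensor.tensor as pt

x = pt.vector("x")
# The MLX closure raises ValueError at runtime if x.shape != (3,)
x3 = pt.specify_shape(x, (3,))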
