Skip to content

Commit 877d79f

Browse files
committed
Changes
1 parent 82bb964 commit 877d79f

File tree

4 files changed

+164
-8
lines changed

4 files changed

+164
-8
lines changed

pytensor/link/mlx/dispatch/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,9 @@
22
from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify
33

44
import pytensor.link.mlx.dispatch.math
5-
# isort: on
5+
import pytensor.link.mlx.dispatch.basic
6+
import pytensor.link.mlx.dispatch.elemwise
7+
import pytensor.link.mlx.dispatch.shape
8+
import pytensor.link.mlx.dispatch.subtensor
9+
import pytensor.link.mlx.dispatch.core
10+
# isort: on

pytensor/link/mlx/dispatch/elemwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import mlx.core as mx
22

33
from pytensor.link.mlx.dispatch.basic import mlx_funcify
4-
from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum
4+
from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum, Switch
55
from pytensor.tensor.elemwise import CAReduce, DimShuffle
66
from pytensor.tensor.special import Softmax, SoftmaxGrad
77

pytensor/link/mlx/dispatch/math.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from pytensor.scalar.basic import Add, Cos, Exp, Log, Mul, Sin, Sub
55
from pytensor.tensor.elemwise import Elemwise
66
from pytensor.tensor.math import Dot
7+
from pytensor.scalar.math import Sigmoid
8+
from pytensor.scalar.basic import Add, Mul, Sub, Exp, Log, Sin, Cos, LE, LT, GE, GT, EQ, NEQ
79

810

911
@mlx_funcify.register(Dot)
@@ -17,9 +19,11 @@ def dot(x, y):
1719
@mlx_funcify.register(Elemwise)
1820
def mlx_funcify_Elemwise(op, **kwargs):
1921
if isinstance(op.scalar_op, Add):
20-
21-
def add(x, y):
22-
return mx.add(x, y)
22+
def add(*args):
23+
result = args[0]
24+
for arg in args[1:]:
25+
result = mx.add(result, arg)
26+
return result
2327

2428
return add
2529
elif isinstance(op.scalar_op, Sub):
@@ -29,9 +33,11 @@ def sub(x, y):
2933

3034
return sub
3135
elif isinstance(op.scalar_op, Mul):
32-
33-
def mul(x, y):
34-
return mx.multiply(x, y)
36+
def mul(*args):
37+
result = args[0]
38+
for arg in args[1:]:
39+
result = mx.multiply(result, arg)
40+
return result
3541

3642
return mul
3743
elif isinstance(op.scalar_op, Exp):
@@ -58,5 +64,40 @@ def cos(x):
5864
return mx.cos(x)
5965

6066
return cos
67+
elif isinstance(op.scalar_op, Sigmoid):
68+
def sigmoid(x):
69+
return mx.sigmoid(x)
70+
71+
return sigmoid
72+
elif isinstance(op.scalar_op, LE):
73+
def le(x, y):
74+
return mx.less_equal(x, y)
75+
76+
return le
77+
elif isinstance(op.scalar_op, LT):
78+
def lt(x, y):
79+
return mx.less(x, y)
80+
81+
return lt
82+
elif isinstance(op.scalar_op, GE):
83+
def ge(x, y):
84+
return mx.greater_equal(x, y)
85+
86+
return ge
87+
elif isinstance(op.scalar_op, GT):
88+
def gt(x, y):
89+
return mx.greater(x, y)
90+
91+
return gt
92+
elif isinstance(op.scalar_op, EQ):
93+
def eq(x, y):
94+
return mx.equal(x, y)
95+
96+
return eq
97+
elif isinstance(op.scalar_op, NEQ):
98+
def neq(x, y):
99+
return mx.not_equal(x, y)
100+
101+
return neq
61102
else:
62103
raise NotImplementedError(f"MLX does not support {op.scalar_op}")
pytensor/link/mlx/dispatch/subtensor.py (new file)

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from pytensor.link.mlx.dispatch.basic import mlx_funcify
2+
from pytensor.tensor.subtensor import (
3+
AdvancedIncSubtensor,
4+
AdvancedIncSubtensor1,
5+
AdvancedSubtensor,
6+
AdvancedSubtensor1,
7+
IncSubtensor,
8+
Subtensor,
9+
indices_from_subtensor,
10+
)
11+
from pytensor.tensor.type_other import MakeSlice
12+
13+
14+
# User-facing guidance for graphs that index with a boolean mask: the result
# shape would depend on runtime data, which MLX cannot compile.
# NOTE(review): neither message is referenced anywhere in this module yet —
# presumably raised by boolean-mask / dynamic-slice checks added later
# (mirrors the JAX dispatch module); confirm before removing.
BOOLEAN_MASK_ERROR = """MLX does not support resizing arrays with boolean
masks. In some cases, however, it is possible to re-express your model
in a form that MLX can compile:

>>> import pytensor.tensor as pt
>>> x_pt = pt.vector('x')
>>> y_pt = x_pt[x_pt > 0].sum()

can be re-expressed as:

>>> import pytensor.tensor as pt
>>> x_pt = pt.vector('x')
>>> y_pt = pt.where(x_pt > 0, x_pt, 0).sum()
"""

# User-facing guidance for slices whose length is only known at runtime.
DYNAMIC_SLICE_LENGTH_ERROR = """MLX does not support slicing arrays with a dynamic
slice length.
"""
32+
33+
34+
@mlx_funcify.register(Subtensor)
@mlx_funcify.register(AdvancedSubtensor)
@mlx_funcify.register(AdvancedSubtensor1)
def mlx_funcify_Subtensor(op, node, **kwargs):
    """Return a runtime indexing function for the ``Subtensor`` op family.

    Parameters
    ----------
    op
        The op being funcified.  Basic ``Subtensor`` carries a static
        ``idx_list`` describing its slice structure; the advanced variants
        have none, so ``idx_list`` falls back to ``None``.
    node
        Unused here; accepted for dispatch-signature uniformity.
    """
    # Static (compile-time) index template; None for advanced subtensor ops.
    idx_list = getattr(op, "idx_list", None)

    def subtensor(x, *ilists):
        # Combine runtime index inputs with the static idx_list template.
        indices = indices_from_subtensor(ilists, idx_list)
        # Unwrap a lone index so x[i] is used instead of x[(i,)].
        if len(indices) == 1:
            indices = indices[0]

        # Plain subscription instead of an explicit __getitem__ dunder call.
        return x[indices]

    return subtensor
48+
49+
50+
@mlx_funcify.register(IncSubtensor)
@mlx_funcify.register(AdvancedIncSubtensor1)
def mlx_funcify_IncSubtensor(op, node, **kwargs):
    """Return a runtime function that sets or increments ``x`` at indices.

    ``op.set_instead_of_inc`` selects assignment (``x[idx] = y``) versus
    accumulation (``x[idx] += y``); ``op.inplace`` decides whether ``x`` is
    copied first.
    """
    idx_list = getattr(op, "idx_list", None)
    set_instead_of_inc = getattr(op, "set_instead_of_inc", False)

    def incsubtensor(x, y, *ilist, idx_list=idx_list):
        # Combine runtime index inputs with the static idx_list template.
        indices = indices_from_subtensor(ilist, idx_list)
        if len(indices) == 1:
            indices = indices[0]

        # Copy unless the op is explicitly allowed to mutate its input.
        target = x if op.inplace else x.copy()
        if set_instead_of_inc:
            target[indices] = y
        else:
            target[indices] += y
        return target

    return incsubtensor
79+
80+
81+
@mlx_funcify.register(AdvancedIncSubtensor)
def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs):
    """Return a runtime set/increment function for advanced index updates.

    The runtime index inputs are applied directly as a tuple of advanced
    indices; ``op.set_instead_of_inc`` picks assignment over accumulation,
    and ``op.inplace`` decides whether ``x`` is copied first.
    """
    set_instead_of_inc = getattr(op, "set_instead_of_inc", False)

    def advancedincsubtensor(x, y, *ilist):
        # Copy unless the op is explicitly allowed to mutate its input.
        target = x if op.inplace else x.copy()
        if set_instead_of_inc:
            target[ilist] = y
        else:
            target[ilist] += y
        return target

    return advancedincsubtensor
103+
104+
105+
@mlx_funcify.register(MakeSlice)
def mlx_funcify_MakeSlice(op, **kwargs):
    """Return a function that packs its runtime inputs into a ``slice``."""

    def makeslice(*bounds):
        # slice(start, stop, step) — built from however many inputs arrive.
        return slice(*bounds)

    return makeslice

0 commit comments

Comments
 (0)