
Commit ab5037e

Add more MLX dispatches (#1684)
* Add mlx extra ops
* Add log_softmax dispatch for mlx
* Fix split bug, add tests
* Feedback
1 parent 17c675a commit ab5037e

File tree

9 files changed: +178 -7 lines changed

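A minimal sketch of how the new dispatches are reached from user code, assuming an MLX-enabled PyTensor install where mode="MLX" selects the MLX linker (the same mode used in the tests added below); this snippet is illustrative, not part of the commit:

import numpy as np
import pytensor
from pytensor import config
from pytensor.tensor.special import log_softmax
from pytensor.tensor.type import matrix

# Build a small graph that hits the new LogSoftmax dispatch
x = matrix("x")
out = log_softmax(x, axis=1)  # routed to mlx.nn.log_softmax under MLX mode
f = pytensor.function([x], out, mode="MLX")
f(np.arange(6, dtype=config.floatX).reshape(2, 3))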

pytensor/link/mlx/dispatch/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -10,4 +10,6 @@
 import pytensor.link.mlx.dispatch.signal
 import pytensor.link.mlx.dispatch.signal.conv
 import pytensor.link.mlx.dispatch.blockwise
+import pytensor.link.mlx.dispatch.extra_ops
+import pytensor.link.mlx.dispatch.sort
 # isort: on

pytensor/link/mlx/dispatch/core.py

Lines changed: 7 additions & 4 deletions
@@ -61,14 +61,17 @@ def split(x, axis, splits):
         # Resolve constants for significant performance improvement (14x speedup)
         if constant_axis is not None:
             axis = int(constant_axis)
+        else:
+            raise ValueError(
+                "Symbolic axis is not supported in MLX Split implementation."
+            )

         if constant_splits is not None:
-            splits = constant_splits
-            cumsum_splits = np.cumsum(splits[:-1])
+            splits_arr = mx.array(constant_splits)
         else:
-            # Dynamic case - use MLX operations
             splits_arr = mx.array(splits)
-            cumsum_splits = mx.cumsum(splits_arr[:-1]).tolist()
+
+        cumsum_splits = mx.cumsum(splits_arr[:-1]).tolist()

         # Validation checks
         if len(splits) != op.len_splits:
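To make the change above concrete, a small sketch assuming the same mode="MLX" setup (the tests added to tests/link/mlx/test_core.py below exercise these exact paths): constant axis and splits are supported, while a symbolic axis is now rejected with the ValueError introduced here.

import numpy as np
import pytensor
from pytensor import tensor as pt

# Constant axis and constant splits: handled by the MLX Split dispatch
x = pt.vector("x")
outs = pt.split(x, [2, 3], 2, axis=0)
f = pytensor.function([x], outs, mode="MLX")
f(np.arange(5, dtype="float32"))  # -> two arrays of length 2 and 3

# A symbolic axis (e.g. axis = pt.scalar("axis", dtype="int64")) is rejected with
# "Symbolic axis is not supported in MLX Split implementation." -- see the new test below.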

pytensor/link/mlx/dispatch/elemwise.py

Lines changed: 12 additions & 1 deletion
@@ -1,6 +1,7 @@
 from functools import singledispatch

 import mlx.core as mx
+import mlx.nn as mlx_nn
 import numpy as np

 from pytensor.link.mlx.dispatch.basic import mlx_funcify
@@ -40,7 +41,7 @@
 )
 from pytensor.scalar.math import Erfc, Erfcx, Sigmoid, Softplus
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.special import Softmax, SoftmaxGrad
+from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad


 @mlx_funcify.register(DimShuffle)
@@ -142,6 +143,16 @@ def softmax_grad(dy, sm):
     return softmax_grad


+@mlx_funcify.register(LogSoftmax)
+def mlx_funcify_LogSoftmax(op, **kwargs):
+    axis = op.axis
+
+    def log_softmax(x):
+        return mlx_nn.log_softmax(x, axis=axis)
+
+    return log_softmax
+
+
 @mlx_funcify.register(Softplus)
 def mlx_funcify_Softplus(op, **kwargs):
     def softplus(x):
pytensor/link/mlx/dispatch/extra_ops.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+import mlx.core as mx
+
+from pytensor.link.mlx.dispatch.basic import mlx_funcify
+from pytensor.tensor.extra_ops import CumOp, Repeat
+
+
+@mlx_funcify.register(CumOp)
+def mlx_funcify_CumOp(op, **kwargs):
+    axis = op.axis
+    mode = op.mode
+
+    def cumop(x, axis=axis, mode=mode):
+        match mode:
+            case "add":
+                return mx.cumsum(x, axis=axis)
+            case "mul":
+                return mx.cumprod(x, axis=axis)
+            case _:
+                raise NotImplementedError(f"CumOp mode {mode} not implemented in MLX")
+
+    return cumop
+
+
+@mlx_funcify.register(Repeat)
+def jax_funcify_Repeat(op, **kwargs):
+    axis = op.axis
+
+    def repeat(x, repeats, axis=axis):
+        if not isinstance(repeats, int):
+            raise NotImplementedError(
+                "MLX repeat does not support sequence-valued repeat argument."
+            )
+        return mx.repeat(x, repeats, axis=axis)
+
+    return repeat
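A short usage sketch for the new CumOp and Repeat dispatches, again assuming mode="MLX" (the new tests/link/mlx/test_extra_ops.py below checks the same ops against the Python backend); note that only scalar repeats are supported, sequence-valued repeats raise NotImplementedError:

import numpy as np
import pytensor
from pytensor import config
from pytensor.tensor import extra_ops as pt_extra_ops
from pytensor.tensor.type import matrix

a = matrix("a")
f = pytensor.function(
    [a],
    [
        pt_extra_ops.cumsum(a, axis=0),    # CumOp mode "add" -> mx.cumsum
        pt_extra_ops.cumprod(a, axis=1),   # CumOp mode "mul" -> mx.cumprod
        pt_extra_ops.repeat(a, 3, axis=1), # scalar repeats only
    ],
    mode="MLX",
)
f(np.arange(6, dtype=config.floatX).reshape(3, 2))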

pytensor/link/mlx/dispatch/sort.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+import warnings
+
+import mlx.core as mx
+
+from pytensor.link.mlx.dispatch.basic import mlx_funcify
+from pytensor.tensor.sort import ArgSortOp, SortOp
+
+
+@mlx_funcify.register(SortOp)
+def mlx_funcify_Sort(op, **kwargs):
+    kind = op.kind
+    if kind != "quicksort":
+        warnings.warn(
+            message=f"MLX sort does not support the kind argument (got kind={kind}). The argument will be "
+            f"ignored.",
+            category=UserWarning,
+        )
+
+    def sort(x, axis):
+        return mx.sort(x, axis=axis)
+
+    return sort
+
+
+@mlx_funcify.register(ArgSortOp)
+def mlx_funcify_ArgSort(op, **kwargs):
+    kind = op.kind
+    if kind != "quicksort":
+        warnings.warn(
+            message=f"MLX argsort does not support the kind argument (got kind={kind}). The argument will be "
+            f"ignored.",
+            category=UserWarning,
+        )
+
+    def argsort(x, axis):
+        return mx.argsort(x, axis=axis)
+
+    return argsort

tests/link/mlx/test_core.py

Lines changed: 28 additions & 1 deletion
@@ -2,9 +2,15 @@
 import pytest

 import pytensor
+from pytensor import config
 from pytensor import tensor as pt
 from pytensor.tensor.basic import Alloc
-from tests.link.mlx.test_basic import compile_mode, mlx_mode_no_compile, mx
+from tests.link.mlx.test_basic import (
+    compare_mlx_and_py,
+    compile_mode,
+    mlx_mode_no_compile,
+    mx,
+)


 def test_alloc_with_different_shape_types():
@@ -137,3 +143,24 @@ def test_empty_dynamic_shape():
         "used inside compiled functions",
     ):
         f_compiled(3, 4)
+
+
+def test_split_const_axis_const_splits_compiled():
+    x = pt.vector("x")
+    splits = [2, 3]
+    outs = pt.split(x, splits, len(splits), axis=0)
+    compare_mlx_and_py([x], outs, [np.arange(5, dtype="float32")])
+
+
+def test_split_dynamic_axis_const_splits():
+    x = pt.matrix("x")
+    axis = pt.scalar("axis", dtype="int64")
+    splits = [1, 2, 3]
+    outs = pt.split(x, splits, len(splits), axis=axis)
+
+    test_input = np.arange(12).astype(config.floatX).reshape(2, 6)
+
+    with pytest.raises(
+        ValueError, match="Symbolic axis is not supported in MLX Split implementation"
+    ):
+        compare_mlx_and_py([x, axis], outs, [test_input, np.array(1)])

tests/link/mlx/test_elemwise.py

Lines changed: 10 additions & 1 deletion
@@ -30,7 +30,7 @@
 from pytensor.tensor.math import max as pt_max
 from pytensor.tensor.math import min as pt_min
 from pytensor.tensor.math import sum as pt_sum
-from pytensor.tensor.special import SoftmaxGrad, softmax
+from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
 from pytensor.tensor.type import matrix, vector, vectors
 from tests.link.mlx.test_basic import compare_mlx_and_py

@@ -97,6 +97,15 @@ def test_softmax_grad(axis):
     compare_mlx_and_py([dy, sm], [out], [dy_test_value, sm_test_value])


+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_logsoftmax(axis):
+    x = matrix("x")
+    x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    out = log_softmax(x, axis=axis)
+
+    compare_mlx_and_py([x], [out], [x_test_value])
+
+
 @pytest.mark.parametrize("size", [(10, 10), (1000, 1000)])
 @pytest.mark.parametrize("axis", [0, 1])
 def test_logsumexp_benchmark(size, axis, benchmark):

tests/link/mlx/test_extra_ops.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+import numpy as np
+import pytest
+
+from pytensor.configdefaults import config
+from pytensor.tensor import extra_ops as pt_extra_ops
+from pytensor.tensor.type import matrix
+from tests.link.mlx.test_basic import compare_mlx_and_py
+
+
+mx = pytest.importorskip("mlx.core")
+
+
+def test_extra_ops():
+    a = matrix("a")
+    a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))
+
+    out = pt_extra_ops.cumsum(a, axis=0)
+    compare_mlx_and_py([a], [out], [a_test])
+
+    out = pt_extra_ops.cumprod(a, axis=1)
+    compare_mlx_and_py([a], [out], [a_test])
+
+    out = pt_extra_ops.repeat(a, 3, axis=1)
+    compare_mlx_and_py([a], [out], [a_test])

tests/link/mlx/test_sort.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+import numpy as np
+import pytest
+
+from pytensor.tensor.sort import argsort, sort
+from pytensor.tensor.type import matrix
+from tests.link.mlx.test_basic import compare_mlx_and_py
+
+
+@pytest.mark.parametrize("axis", [None, -1])
+@pytest.mark.parametrize("func", (sort, argsort))
+def test_sort(func, axis):
+    x = matrix("x", shape=(2, 2), dtype="float64")
+    out = func(x, axis=axis)
+    arr = np.array([[1.0, 4.0], [5.0, 2.0]])
+    compare_mlx_and_py([x], [out], [arr])
+
+
+def test_sort_invalid_kind_warning():
+    x = matrix("x", shape=(2, 2), dtype="float64")
+    z = sort(x, axis=-1, kind="mergesort")
+    with pytest.warns(UserWarning, match="MLX sort does not support the kind argument"):
+        z.eval({x: np.array([[3.0, 1.0], [2.0, 4.0]])}, mode="MLX")
