
Commit 9d3eca8

Merge branch 'mlx-poc' of https://github.com/williambdean/pytensor into pr/1365

2 parents: 294c271 + 0812c55

File tree

pytensor/link/mlx/dispatch/basic.py
pytensor/link/mlx/dispatch/elemwise.py
pytensor/link/mlx/dispatch/subtensor.py
tests/link/mlx/test_math.py
tests/link/mlx/test_shape.py

5 files changed: +146 −44 lines changed

pytensor/link/mlx/dispatch/basic.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,4 +1,5 @@
 import warnings
+from copy import deepcopy
 from functools import singledispatch
 from types import NoneType
 
@@ -58,7 +59,7 @@ def mlx_funcify_FunctionGraph(
 @mlx_funcify.register(DeepCopyOp)
 def mlx_funcify_DeepCopyOp(op, **kwargs):
     def deepcopyop(x):
-        return x.copy()
+        return deepcopy(x)
 
     return deepcopyop
 
```
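The substantive change swaps `x.copy()` for `copy.deepcopy(x)`. A plausible motivation (not stated in the commit) is that `mlx.core` arrays do not expose a NumPy-style `.copy()` method, while `deepcopy` works on any object that supports it, which this dispatch relies on for `mx.array`. A minimal sketch of the resulting behavior, assuming `mlx` is installed:

```python
# Sketch only: mirrors the `deepcopyop` closure registered above.
# Assumes `mlx` is installed; `a` and `b` are illustrative names.
from copy import deepcopy

import mlx.core as mx

def deepcopyop(x):
    # `deepcopy` returns an independent copy of the input array
    return deepcopy(x)

a = mx.array([1.0, 2.0, 3.0])
b = deepcopyop(a)
b[0] = 9.0           # mutating the copy...
print(a[0].item())   # ...leaves the original untouched: 1.0
```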

pytensor/link/mlx/dispatch/elemwise.py

Lines changed: 47 additions & 39 deletions
```diff
@@ -2,7 +2,6 @@
 
 from pytensor.link.mlx.dispatch.basic import mlx_funcify
 from pytensor.scalar import Softplus
-from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum
 from pytensor.tensor.elemwise import CAReduce, DimShuffle
 from pytensor.tensor.special import Softmax, SoftmaxGrad
 
@@ -24,44 +23,53 @@ def dimshuffle(x):
 
 @mlx_funcify.register(CAReduce)
 def mlx_funcify_CAReduce(op, **kwargs):
-    if isinstance(op.scalar_op, Add):
-
-        def sum(x):
-            return mx.sum(x, axis=op.axis)
-
-        return sum
-    elif isinstance(op.scalar_op, Mul):
-
-        def prod(x):
-            return mx.prod(x, axis=op.axis)
-
-        return prod
-    elif isinstance(op.scalar_op, AND):
-
-        def all(x):
-            return x.all(axis=op.axis)
-
-        return all
-    elif isinstance(op.scalar_op, OR):
-
-        def any(x):
-            return mx.any(x, axis=op.axis)
-
-        return any
-    elif isinstance(op.scalar_op, ScalarMaximum):
-
-        def max(x):
-            return mx.max(x, axis=op.axis)
-
-        return max
-    elif isinstance(op.scalar_op, ScalarMinimum):
-
-        def min(x):
-            return mx.min(x, axis=op.axis)
-
-        return min
-    else:
-        raise NotImplementedError(f"MLX does not support Elemwise {op.scalar_op}")
+    axis = op.axis
+    op_nfunc_spec = getattr(op, "nfunc_spec", None)
+    scalar_nfunc_spec = getattr(op.scalar_op, "nfunc_spec", None)
+    scalar_op_name = getattr(op.scalar_op, "name", None)
+    scalar_op_identity = getattr(op.scalar_op, "identity", None)
+    acc_dtype = getattr(op, "acc_dtype", None)
+
+    def careduce(x):
+        nonlocal \
+            axis, \
+            op_nfunc_spec, \
+            scalar_nfunc_spec, \
+            scalar_op_name, \
+            scalar_op_identity, \
+            acc_dtype
+
+        if axis is None:
+            axis = list(range(x.ndim))
+
+        if acc_dtype is None:
+            acc_dtype = x.dtype.type
+
+        if op_nfunc_spec:
+            mlx_op = getattr(mx, op_nfunc_spec[0])
+            return mlx_op(x, axis=axis)
+            return mlx_op(x, axis=axis).astype(acc_dtype)
+
+        # The PyTensor `Op` didn't tell us which NumPy equivalent to use (or
+        # there isn't one), so we use this fallback approach
+        if scalar_nfunc_spec:
+            scalar_fn_name = scalar_nfunc_spec[0]
+        elif scalar_op_name:
+            scalar_fn_name = scalar_op_name
+
+        to_reduce = sorted(axis, reverse=True)
+
+        if to_reduce:
+            raise NotImplementedError("Not implemented yet")
+            # In this case, we need to use the `jax.lax` function (if there
+            # is one), and not the `jnp` version.
+            mlx_op = getattr(mx, scalar_fn_name)
+            init_value = mx.array(scalar_op_identity, dtype=acc_dtype)
+            return mx.reduce(x, init_value, mlx_op, to_reduce).astype(acc_dtype)
+        else:
+            return x
+
+    return careduce
 
 
 @mlx_funcify.register(Softmax)
```
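This rewrite replaces the per-`scalar_op` `isinstance` ladder with a generic path that resolves the reduction from the Op's `nfunc_spec`, i.e. its declared NumPy equivalent, echoing the JAX dispatcher (the leftover `jax.lax`/`jnp` comments give that away). Note that as committed, the second `return` in the `op_nfunc_spec` branch is unreachable, so the `acc_dtype` cast is skipped there, and the `to_reduce` fallback raises before its body runs. A standalone sketch of the lookup, with an assumed `nfunc_spec` value for illustration:

```python
# Sketch of the nfunc_spec-driven dispatch, outside of PyTensor.
# The ("sum", 1, 1) spec is an assumption here; PyTensor Ops advertise
# their NumPy equivalent in this (name, n_inputs, n_outputs) form.
import mlx.core as mx

nfunc_spec = ("sum", 1, 1)
axis = [0]

mlx_op = getattr(mx, nfunc_spec[0])     # resolves to mx.sum
x = mx.array([[1.0, 2.0], [3.0, 4.0]])
print(mlx_op(x, axis=axis))             # array([4, 6], dtype=float32)
```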

pytensor/link/mlx/dispatch/subtensor.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -1,3 +1,5 @@
+from copy import deepcopy
+
 from pytensor.link.mlx.dispatch.basic import mlx_funcify
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
@@ -24,6 +26,7 @@ def subtensor(x, *ilists):
 
     return subtensor
 
+
 @mlx_funcify.register(AdvancedSubtensor)
 @mlx_funcify.register(AdvancedSubtensor1)
 def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
@@ -48,15 +51,15 @@ def mlx_funcify_IncSubtensor(op, node, **kwargs):
 
         def mlx_fn(x, indices, y):
             if not op.inplace:
-                x = x.copy()
+                x = deepcopy(x)
             x[indices] = y
             return x
 
     else:
 
         def mlx_fn(x, indices, y):
             if not op.inplace:
-                x = x.copy()
+                x = deepcopy(x)
             x[indices] += y
             return x
 
@@ -76,15 +79,15 @@ def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs):
 
         def mlx_fn(x, indices, y):
             if not op.inplace:
-                x = x.copy()
+                x = deepcopy(x)
             x[indices] = y
             return x
 
     else:
 
         def mlx_fn(x, indices, y):
             if not op.inplace:
-                x = x.copy()
+                x = deepcopy(x)
             x[indices] += y
             return x
 
```
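As in `basic.py`, the `x.copy()` calls become `deepcopy(x)`. The surrounding pattern is the standard out-of-place update: copy unless the Op is marked `inplace`, then index-assign (or increment). A hedged standalone sketch, assuming `mx.array` supports NumPy-style item assignment (which the dispatch above also assumes):

```python
# Sketch of the out-of-place set-subtensor pattern; names are illustrative.
from copy import deepcopy

import mlx.core as mx

def set_subtensor(x, indices, y, inplace=False):
    if not inplace:
        x = deepcopy(x)  # protect the caller's array
    x[indices] = y
    return x

a = mx.array([1.0, 2.0, 3.0])
b = set_subtensor(a, 0, 9.0)
print(a[0].item(), b[0].item())  # 1.0 9.0 -- `a` is unchanged
```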

tests/link/mlx/test_math.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -3,6 +3,7 @@
 
 import pytensor
 import pytensor.tensor as pt
+from pytensor.tensor.math import Argmax, Max
 from tests.link.mlx.test_basic import compare_mlx_and_py, mx
 
 
@@ -87,3 +88,14 @@ def test_elemwise_two_inputs(op) -> None:
     x_test = mx.array([1.0, 2.0, 3.0])
     y_test = mx.array([4.0, 5.0, 6.0])
     compare_mlx_and_py([x, y], out, [x_test, y_test])
+
+
+@pytest.mark.xfail(reason="Argmax not implemented yet")
+def test_mlx_max_and_argmax():
+    # Test that a single output of a multi-output `Op` can be used as input to
+    # another `Op`
+    x = pt.dvector()
+    mx = Max([0])(x)
+    amx = Argmax([0])(x)
+    out = mx * amx
+    compare_mlx_and_py([x], [out], [np.r_[1, 2]])
```
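The new test is marked `xfail` until `Argmax` gets an MLX dispatcher. For reference, the same graph runs on PyTensor's default backend; a hedged sketch (default-mode compilation, not MLX):

```python
# What Max([0]) / Argmax([0]) compute, checked on the default backend.
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.math import Argmax, Max

x = pt.dvector("x")
fn = pytensor.function([x], [Max([0])(x), Argmax([0])(x)])
print(fn(np.r_[1.0, 2.0]))  # [array(2.), array(1)]
```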

tests/link/mlx/test_shape.py

Lines changed: 78 additions & 0 deletions
```diff
@@ -0,0 +1,78 @@
+import numpy as np
+import pytest
+
+import pytensor.tensor as pt
+from pytensor.compile.ops import DeepCopyOp, ViewOp
+from pytensor.configdefaults import config
+from pytensor.tensor.shape import Shape, Shape_i, reshape
+from pytensor.tensor.type import iscalar, vector
+from tests.link.mlx.test_basic import compare_mlx_and_py
+
+
+@pytest.mark.xfail(reason="Shape Op is not supported yet")
+def test_mlx_shape_ops():
+    x_np = np.zeros((20, 3))
+    x = Shape()(pt.as_tensor_variable(x_np))
+
+    compare_mlx_and_py([], [x], [], must_be_device_array=False)
+
+    x = Shape_i(1)(pt.as_tensor_variable(x_np))
+
+    compare_mlx_and_py([], [x], [], must_be_device_array=False)
+
+
+@pytest.mark.xfail(reason="Shape Op is not supported yet")
+def test_mlx_specify_shape():
+    in_pt = pt.matrix("in")
+    x = pt.specify_shape(in_pt, (4, None))
+    compare_mlx_and_py([in_pt], [x], [np.ones((4, 5)).astype(config.floatX)])
+
+    # When used to assert two arrays have similar shapes
+    in_pt = pt.matrix("in")
+    shape_pt = pt.matrix("shape")
+    x = pt.specify_shape(in_pt, shape_pt.shape)
+
+    compare_mlx_and_py(
+        [in_pt, shape_pt],
+        [x],
+        [np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)],
+    )
+
+
+@pytest.mark.xfail(reason="Reshape Op is not supported yet")
+def test_mlx_Reshape_constant():
+    a = vector("a")
+    x = reshape(a, (2, 2))
+    compare_mlx_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
+
+
+@pytest.mark.xfail(reason="Reshape Op is not supported yet")
+def test_mlx_Reshape_concrete_shape():
+    """MLX should compile when a concrete value is passed for the `shape` parameter."""
+    a = vector("a")
+    x = reshape(a, a.shape)
+    compare_mlx_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
+
+    x = reshape(a, (a.shape[0] // 2, a.shape[0] // 2))
+    compare_mlx_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
+
+
+@pytest.mark.xfail(reason="`shape_pt` should be specified as a static argument")
+def test_mlx_Reshape_shape_graph_input():
+    a = vector("a")
+    shape_pt = iscalar("b")
+    x = reshape(a, (shape_pt, shape_pt))
+    compare_mlx_and_py(
+        [a, shape_pt], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]
+    )
+
+
+@pytest.mark.xfail(reason="ViewOp Op is not supported yet")
+def test_mlx_compile_ops():
+    x = DeepCopyOp()(pt.as_tensor_variable(1.1))
+    compare_mlx_and_py([], [x], [])
+
+    x_np = np.zeros((20, 1, 1))
+    x = ViewOp()(pt.as_tensor_variable(x_np))
+
+    compare_mlx_and_py([], [x], [])
```
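This whole file is new, and every test is an expected failure until the shape-related Ops (`Shape`, `Shape_i`, `SpecifyShape`, `Reshape`, `ViewOp`) get MLX dispatchers. For reference, here is how one of the exercised helpers behaves on the default backend; a minimal hedged sketch of `pt.specify_shape`:

```python
# `pt.specify_shape` on the default backend, mirroring test_mlx_specify_shape.
import numpy as np

import pytensor
import pytensor.tensor as pt

in_pt = pt.matrix("in")
out = pt.specify_shape(in_pt, (4, None))  # pin the first dimension to 4
fn = pytensor.function([in_pt], out)

print(fn(np.ones((4, 5))).shape)  # (4, 5)
# fn(np.ones((3, 5))) would raise, since the first dim is asserted to be 4
```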
