
Commit 934306f

Authored by williambdean, cetagostini, ricardoV94, and jessegrabowski
Add MLX backend (#1365)
Co-authored-by: Carlos Trujillo <[email protected]>
Co-authored-by: Ricardo Vieira <[email protected]>
Co-authored-by: Jesse Grabowski <[email protected]>
1 parent 92c3b49 commit 934306f

25 files changed: +2886 −4 lines

.github/workflows/test.yml

Lines changed: 14 additions & 1 deletion
@@ -81,6 +81,7 @@ jobs:
         install-numba: [0]
         install-jax: [0]
         install-torch: [0]
+        install-mlx: [0]
         install-xarray: [0]
         part:
           - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/xtensor"
@@ -106,6 +107,7 @@ jobs:
           install-numba: 0
           install-jax: 0
           install-torch: 0
+          install-mlx: 0
           install-xarray: 0
         - install-numba: 1
           os: "ubuntu-latest"
@@ -149,7 +151,16 @@ jobs:
           fast-compile: 0
           float32: 0
           part: "tests/xtensor"
-        - os: macos-15
+        - os: "macos-15"
+          python-version: "3.11"
+          fast-compile: 0
+          float32: 0
+          install-mlx: 1
+          install-numba: 0
+          install-jax: 0
+          install-torch: 0
+          part: "tests/link/mlx"
+        - os: "macos-15"
           python-version: "3.13"
           fast-compile: 0
           float32: 0
@@ -194,6 +205,7 @@ jobs:
         if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
         if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi
         if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
+        if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" mlx; fi
         if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi
 
         pip install -e ./
@@ -210,6 +222,7 @@ jobs:
         INSTALL_JAX: ${{ matrix.install-jax }}
         INSTALL_TORCH: ${{ matrix.install-torch}}
         INSTALL_XARRAY: ${{ matrix.install-xarray }}
+        INSTALL_MLX: ${{ matrix.install-mlx }}
         OS: ${{ matrix.os}}
 
       - name: Run tests
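The new matrix entry runs only the MLX test directory, on macOS runners where MLX is available. Presumably the same setup can be reproduced locally on an Apple-silicon machine by mirroring the workflow: micromamba install -c conda-forge mlx, then pytest tests/link/mlx.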

.gitignore

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@ __pycache__
 \#*\#
 build
 compiled/*.cpp
-core.*
 cutils_ext.cpp
 dist
 doc/.build/

doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb

Lines changed: 436 additions & 0 deletions
Large diffs are not rendered by default.

pytensor/compile/mode.py

Lines changed: 17 additions & 0 deletions
@@ -27,6 +27,7 @@
 from pytensor.link.basic import Linker, PerformLinker
 from pytensor.link.c.basic import CLinker, OpWiseCLinker
 from pytensor.link.jax.linker import JAXLinker
+from pytensor.link.mlx.linker import MLXLinker
 from pytensor.link.numba.linker import NumbaLinker
 from pytensor.link.pytorch.linker import PytorchLinker
 from pytensor.link.vm import VMLinker
@@ -50,6 +51,7 @@
     "jax": JAXLinker(),
     "pytorch": PytorchLinker(),
     "numba": NumbaLinker(),
+    "mlx": MLXLinker(),
 }
 
 
@@ -504,13 +506,28 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
     ),
 )
 
+MLX = Mode(
+    MLXLinker(),
+    RewriteDatabaseQuery(
+        include=["fast_run"],
+        exclude=[
+            "cxx_only",
+            "BlasOpt",
+            "fusion",
+            "inplace",
+            "scan_save_mem_prealloc",
+        ],
+    ),
+)
+
 
 predefined_modes = {
     "FAST_COMPILE": FAST_COMPILE,
     "FAST_RUN": FAST_RUN,
     "JAX": JAX,
     "NUMBA": NUMBA,
     "PYTORCH": PYTORCH,
+    "MLX": MLX,
 }
 
 _CACHED_RUNTIME_MODES: dict[str, Mode] = {}
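With the linker and mode registered, MLX can be selected by name when compiling a function. A minimal sketch (assumes mlx is installed; the graph itself is illustrative, not from the commit):

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.exp(x).sum()

# "MLX" resolves through predefined_modes to the Mode defined above
fn = pytensor.function([x], y, mode="MLX")
fn([0.0, 1.0, 2.0])  # should execute through MLXLinker on MLX arrays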

pytensor/link/mlx/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from pytensor.link.mlx.linker import MLXLinker
pytensor/link/mlx/dispatch/__init__.py

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+# isort: off
+from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify
+
+import pytensor.link.mlx.dispatch.math
+import pytensor.link.mlx.dispatch.basic
+import pytensor.link.mlx.dispatch.elemwise
+import pytensor.link.mlx.dispatch.shape
+import pytensor.link.mlx.dispatch.subtensor
+import pytensor.link.mlx.dispatch.core
+import pytensor.link.mlx.dispatch.signal
+import pytensor.link.mlx.dispatch.signal.conv
+import pytensor.link.mlx.dispatch.blockwise
+# isort: on
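These submodule imports are load-bearing: each module registers its conversions at import time via @mlx_funcify.register(...), so importing the package once makes every dispatcher available:

# importing the package triggers all Op registrations as a side effect
import pytensor.link.mlx.dispatch  # noqa: F401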
pytensor/link/mlx/dispatch/basic.py

Lines changed: 101 additions & 0 deletions

@@ -0,0 +1,101 @@
+import warnings
+from copy import deepcopy
+from functools import singledispatch
+from types import NoneType
+
+import mlx.core as mx
+import numpy as np
+
+from pytensor.compile.ops import DeepCopyOp
+from pytensor.graph import Constant
+from pytensor.graph.fg import FunctionGraph
+from pytensor.link.utils import fgraph_to_python
+from pytensor.raise_op import Assert, CheckAndRaise
+
+
+@singledispatch
+def mlx_typify(data, **kwargs):
+    raise NotImplementedError(f"mlx_typify is not implemented for {type(data)}")
+
+
+@mlx_typify.register(np.ndarray)
+def mlx_typify_tensor(data, dtype=None, **kwargs):
+    return mx.array(data, dtype=dtype)
+
+
+@mlx_typify.register(slice)
+@mlx_typify.register(NoneType)
+@mlx_typify.register(mx.array)
+def mlx_typify_no_conversion_needed(data, **kwargs):
+    return data
+
+
+@mlx_typify.register(int)
+@mlx_typify.register(float)
+def mlx_typify_python_scalar(data, **kwargs):
+    return mx.array(data)
+
+
+@mlx_typify.register(bool)
+@mlx_typify.register(np.bool_)
+def mlx_typify_bool(data, **kwargs):
+    return bool(data)
+
+
+@mlx_typify.register(np.integer)
+@mlx_typify.register(np.floating)
+@mlx_typify.register(np.complexfloating)
+def mlx_typify_numpy_scalar(data, **kwargs):
+    return mx.array(data)
+
+
+@singledispatch
+def mlx_funcify(op, node=None, storage_map=None, **kwargs):
+    """Create a MLX compatible function from an PyTensor `Op`."""
+    raise NotImplementedError(
+        f"No MLX conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/1350` for progress or to request we prioritize this operation"
+    )
+
+
+@mlx_funcify.register(FunctionGraph)
+def mlx_funcify_FunctionGraph(
+    fgraph,
+    node=None,
+    fgraph_name="mlx_funcified_fgraph",
+    conversion_func=mlx_funcify,
+    **kwargs,
+):
+    built_kwargs = {"conversion_func": conversion_func, **kwargs}
+    return fgraph_to_python(
+        fgraph,
+        conversion_func,
+        type_conversion_fn=mlx_typify,
+        fgraph_name=fgraph_name,
+        **built_kwargs,
+    )
+
+
+@mlx_funcify.register(DeepCopyOp)
+def mlx_funcify_DeepCopyOp(op, **kwargs):
+    def deepcopyop(x):
+        return deepcopy(x)
+
+    return deepcopyop
+
+
+@mlx_funcify.register(Assert)
+@mlx_funcify.register(CheckAndRaise)
+def mlx_funcify_CheckAndRaise(op, node, **kwargs):
+    conds = node.inputs[1:]
+    if any(isinstance(cond, Constant) and not bool(cond.data) for cond in conds):
+        raise op.exc_type(op.msg)
+
+    warnings.warn(
+        f"""Skipping `{type(op).__name__}` Op (assertion: {op.msg}) as MLX tracing would remove it.""",
+        stacklevel=2,
+    )
+
+    def assert_fn(x, *inputs):
+        return x
+
+    return assert_fn
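These two singledispatch functions are the backend's extension points: mlx_typify coerces raw inputs to MLX values, and mlx_funcify maps each Op to a Python callable over mx.arrays. A hypothetical third-party registration, following the DeepCopyOp pattern above (MyOp and the body are illustrative, not part of the commit):

import mlx.core as mx

from pytensor.link.mlx.dispatch import mlx_funcify
from mypackage.ops import MyOp  # hypothetical custom Op

@mlx_funcify.register(MyOp)
def mlx_funcify_MyOp(op, node=None, **kwargs):
    # the returned callable operates directly on mx.array inputs
    def my_op(x):
        return mx.log1p(mx.abs(x))

    return my_op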
pytensor/link/mlx/dispatch/blockwise.py

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+import mlx.core as mx
+
+from pytensor.link.mlx.dispatch import mlx_funcify
+from pytensor.tensor.blockwise import Blockwise
+
+
+@mlx_funcify.register(Blockwise)
+def funcify_Blockwise(op: Blockwise, node, **kwargs):
+    # 2) Otherwise, get the core python function for this Blockwise
+    core_node = op._create_dummy_core_node(node.inputs)
+    core_f = mlx_funcify(op.core_op, core_node)
+
+    # 3) Determine how many inputs correspond to batch dimensions
+    n_batch = op.batch_ndim(node)
+
+    # 4) Handle case where no vectorization is needed
+    if n_batch == 0:
+        return core_f
+
+    # 5) Vectorize using mx.vmap over any batched inputs
+    in_axes: list[int | None] = []
+    for inp, sig in zip(node.inputs, op.inputs_sig):
+        batch_ndim = inp.type.ndim - len(sig)
+        if batch_ndim == 0:
+            in_axes.append(None)
+            continue
+
+        batch_bcast = inp.type.broadcastable[:batch_ndim]
+        # If all batch dims are broadcastable (size 1), treat input as static
+        in_axes.append(0 if not all(batch_bcast) else None)
+
+    if not any(axis == 0 for axis in in_axes):
+        return core_f
+
+    return mx.vmap(core_f, in_axes=tuple(in_axes))
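The in_axes values follow mx.vmap semantics: 0 maps the function over an input's leading axis, while None passes that input through unbatched. A standalone sketch of the same mechanism (the matrix-vector example is illustrative):

import mlx.core as mx

def core_matvec(m, v):
    return m @ v  # core op on unbatched inputs: (3, 4) @ (4,) -> (3,)

batched_m = mx.ones((8, 3, 4))  # batch of eight (3, 4) matrices
shared_v = mx.ones((4,))        # single vector shared across the batch

# map over axis 0 of the first input only, as funcify_Blockwise does
batched_f = mx.vmap(core_matvec, in_axes=(0, None))
print(batched_f(batched_m, shared_v).shape)  # (8, 3)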
