Commit d25f214: "mlx poc"
Parent: e299023
File tree: 6 files changed, +218 -8 lines

pytensor/compile/mode.py

Lines changed: 19 additions & 0 deletions

@@ -27,6 +27,7 @@
 from pytensor.link.basic import Linker, PerformLinker
 from pytensor.link.c.basic import CLinker, OpWiseCLinker
 from pytensor.link.jax.linker import JAXLinker
+from pytensor.link.mlx.linker import MLXLinker
 from pytensor.link.numba.linker import NumbaLinker
 from pytensor.link.pytorch.linker import PytorchLinker
 from pytensor.link.vm import VMLinker
@@ -50,6 +51,7 @@
     "jax": JAXLinker(),
     "pytorch": PytorchLinker(),
     "numba": NumbaLinker(),
+    "mlx": MLXLinker(),
 }
 
 
@@ -494,13 +496,28 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
     ),
 )
 
+MLX = Mode(
+    MLXLinker(),
+    RewriteDatabaseQuery(
+        include=["fast_run"],
+        exclude=[
+            "cxx_only",
+            "BlasOpt",
+            "fusion",
+            "inplace",
+            "scan_save_mem_prealloc",
+        ],
+    ),
+)
 
 predefined_modes = {
     "FAST_COMPILE": FAST_COMPILE,
     "FAST_RUN": FAST_RUN,
     "JAX": JAX,
     "NUMBA": NUMBA,
     "PYTORCH": PYTORCH,
+    "MLX": MLX,
 }
 
 _CACHED_RUNTIME_MODES: dict[str, Mode] = {}
@@ -585,6 +602,8 @@ def get_target_language(mode=None) -> tuple[Literal["py", "c", "numba", "jax"],
         return ("py",)
     if isinstance(linker, CLinker):
         return ("c",)
+    if isinstance(linker, MLXLinker):
+        return ("py",)
 
     if isinstance(linker, VMLinker | OpWiseCLinker):
         return ("c", "py") if config.cxx else ("py",)

pytensor/link/mlx/dispatch/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+# isort: off
+from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify
+
+import pytensor.link.mlx.dispatch.math
+# isort: on

pytensor/link/mlx/dispatch/basic.py

Lines changed: 61 additions & 0 deletions

@@ -0,0 +1,61 @@
+from functools import singledispatch
+from types import NoneType
+
+import mlx.core as mx
+import numpy as np
+
+from pytensor.compile.ops import DeepCopyOp
+from pytensor.graph.fg import FunctionGraph
+from pytensor.link.utils import fgraph_to_python
+
+
+@singledispatch
+def mlx_typify(data, **kwargs):
+    raise NotImplementedError(f"mlx_typify is not implemented for {type(data)}")
+
+
+@mlx_typify.register(np.ndarray)
+@mlx_typify.register(mx.array)
+def mlx_typify_tensor(data, dtype=None, **kwargs):
+    return mx.array(data, dtype=dtype)
+
+
+@mlx_typify.register(slice)
+@mlx_typify.register(NoneType)
+@mlx_typify.register(np.number)
+def mlx_typify_no_conversion_needed(data, **kwargs):
+    return data
+
+
+@singledispatch
+def mlx_funcify(op, node=None, storage_map=None, **kwargs):
+    """Create an MLX-compatible function from a PyTensor `Op`."""
+    raise NotImplementedError(
+        f"No MLX conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/1350` for progress or to request that we prioritize this operation"
+    )
+
+
+@mlx_funcify.register(FunctionGraph)
+def mlx_funcify_FunctionGraph(
+    fgraph,
+    node=None,
+    fgraph_name="mlx_funcified_fgraph",
+    conversion_func=mlx_funcify,
+    **kwargs,
+):
+    built_kwargs = {"conversion_func": conversion_func, **kwargs}
+    return fgraph_to_python(
+        fgraph,
+        conversion_func,
+        type_conversion_fn=mlx_typify,
+        fgraph_name=fgraph_name,
+        **built_kwargs,
+    )
+
+
+@mlx_funcify.register(DeepCopyOp)
+def mlx_funcify_DeepCopyOp(op, **kwargs):
+    def deepcopyop(x):
+        return x.copy()
+
+    return deepcopyop
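
New operations hook in through the same `mlx_funcify` singledispatch. A hypothetical registration for PyTensor's `Shape` op (not part of this commit; shown only to illustrate the extension pattern):

    import mlx.core as mx

    from pytensor.link.mlx.dispatch import mlx_funcify
    from pytensor.tensor.shape import Shape

    # Each registration returns a closure that fgraph_to_python splices
    # into the Python source generated for the graph.
    @mlx_funcify.register(Shape)
    def mlx_funcify_Shape(op, **kwargs):
        def shape(x):
            return mx.array(x.shape)

        return shape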

pytensor/link/mlx/dispatch/math.py

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+import mlx.core as mx
+
+from pytensor.link.mlx.dispatch import mlx_funcify
+from pytensor.tensor.math import Dot
+
+
+@mlx_funcify.register(Dot)
+def mlx_funcify_Dot(op, **kwargs):
+    def dot(x, y):
+        return mx.matmul(x, y)
+
+    return dot
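
Standalone, the dispatched closure behaves like `numpy.matmul` on MLX arrays. A quick illustrative check (assumes MLX is available):

    import mlx.core as mx

    from pytensor.link.mlx.dispatch import mlx_funcify
    from pytensor.tensor.math import Dot

    dot = mlx_funcify(Dot())
    out = dot(mx.ones((2, 3)), mx.ones((3, 2)))  # 2x2 array, every entry 3.0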

pytensor/link/mlx/linker.py

Lines changed: 113 additions & 0 deletions

@@ -0,0 +1,113 @@
+from pytensor.link.basic import JITLinker
+from pytensor.link.utils import unique_name_generator
+
+
+class MLXLinker(JITLinker):
+    """A `Linker` that JIT-compiles NumPy-based operations using Apple's MLX."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.gen_functors = []
+
+    def fgraph_convert(
+        self,
+        fgraph,
+        order,
+        input_storage,
+        output_storage,
+        storage_map,
+        **kwargs,
+    ):
+        """Convert a PyTensor FunctionGraph to an MLX-compatible function.
+
+        Parameters
+        ----------
+        fgraph : FunctionGraph
+            The function graph to convert
+        order : list
+            The order in which to compute the nodes
+        input_storage : list
+            Storage for the input variables
+        output_storage : list
+            Storage for the output variables
+        storage_map : dict
+            Map from variables to their storage
+
+        Returns
+        -------
+        callable
+            An MLX-compatible function
+        """
+        from pytensor.link.mlx.dispatch import mlx_funcify
+
+        # We want to have globally unique names
+        # across the entire pytensor graph, not
+        # just the subgraph
+        generator = unique_name_generator(["mlx_linker"])
+
+        # Collect the functors generated for each node
+        # (pattern borrowed from the PyTorch linker)
+        def conversion_func_register(*args, **kwargs):
+            functor = mlx_funcify(*args, **kwargs)
+            name = kwargs["unique_name"](functor)
+            self.gen_functors.append((f"_{name}", functor))
+            return functor
+
+        built_kwargs = {
+            "unique_name": generator,
+            "conversion_func": conversion_func_register,
+            **kwargs,
+        }
+        return mlx_funcify(
+            fgraph,
+            input_storage=input_storage,
+            storage_map=storage_map,
+            **built_kwargs,
+        )
+
+    def jit_compile(self, fn):
+        """JIT compile an MLX function.
+
+        Parameters
+        ----------
+        fn : callable
+            The function to compile
+
+        Returns
+        -------
+        callable
+            The compiled function
+        """
+        import mlx.core as mx
+
+        return mx.compile(fn)
+
+    def create_thunk_inputs(self, storage_map):
+        """Create inputs for the MLX thunk.
+
+        Parameters
+        ----------
+        storage_map : dict
+            Map from variables to their storage
+
+        Returns
+        -------
+        list
+            The inputs for the thunk
+        """
+        from numpy.random import Generator, RandomState
+
+        from pytensor.link.mlx.dispatch import mlx_typify
+
+        thunk_inputs = []
+        for n in self.fgraph.inputs:
+            sinput = storage_map[n]
+            # Handle random number generators specially
+            if isinstance(sinput[0], RandomState | Generator):
+                new_value = mlx_typify(
+                    sinput[0], dtype=getattr(sinput[0], "dtype", None)
+                )
+                sinput[0] = new_value
+            thunk_inputs.append(sinput)
+
+        return thunk_inputs
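
`mx.compile` traces the funcified graph on first call and reuses the compiled version on subsequent calls. A standalone sketch of the call `jit_compile` makes (illustrative; assumes the `mlx` package is installed):

    import mlx.core as mx

    def f(x, y):
        return mx.matmul(x, y) + 1.0

    compiled = mx.compile(f)  # the same call MLXLinker.jit_compile makes
    out = compiled(mx.ones((2, 2)), mx.ones((2, 2)))
    mx.eval(out)  # MLX is lazy; force evaluation of the result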

pytensor/link/pytorch/linker.py

Lines changed: 8 additions & 8 deletions

@@ -31,16 +31,16 @@ def conversion_func_register(*args, **kwargs):
             **kwargs,
         }
         return pytorch_funcify(
-            fgraph, input_storage=input_storage, storage_map=storage_map, **built_kwargs
+            fgraph,
+            input_storage=input_storage,
+            storage_map=storage_map,
+            **built_kwargs,
         )
 
     def jit_compile(self, fn):
-        import torch
+        import mlx.core as mx
 
-        # flag that tend to help our graphs
-        torch._dynamo.config.capture_dynamic_output_shape_ops = True
-
-        from pytensor.link.pytorch.dispatch import pytorch_typify
+        from pytensor.link.mlx.dispatch import mlx_typify
 
         class wrapper:
             """
@@ -54,7 +54,7 @@ class wrapper:
             """
 
             def __init__(self, fn, gen_functors):
-                self.fn = torch.compile(fn)
+                self.fn = mx.compile(fn)
                 self.gen_functors = gen_functors.copy()
 
             def __call__(self, *inputs, **kwargs):
@@ -65,7 +65,7 @@ def __call__(self, *inputs, **kwargs):
                     setattr(pytensor.link.utils, n[1:], fn)
 
                 # Torch does not accept numpy inputs and may return GPU objects
-                outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs)
+                outs = self.fn(*(mlx_typify(inp) for inp in inputs), **kwargs)
 
                 # unset attrs
                 for n, _ in self.gen_functors:
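
The wrapper converts inputs eagerly because compiled MLX functions operate on `mx.array` values rather than NumPy arrays; the same `mlx_typify` dispatch can be exercised directly (illustrative):

    import mlx.core as mx
    import numpy as np

    from pytensor.link.mlx.dispatch import mlx_typify

    # NumPy arrays are copied into MLX arrays; slices, None, and NumPy
    # scalars pass through unchanged.
    assert isinstance(mlx_typify(np.zeros(3)), mx.array)
    assert mlx_typify(slice(None)) == slice(None)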
