
Commit 4d5b34b

Pushing code
1 parent bc98e09 commit 4d5b34b

5 files changed: +88, -33 lines

pytensor/link/mlx/dispatch/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -7,4 +7,6 @@
 import pytensor.link.mlx.dispatch.shape
 import pytensor.link.mlx.dispatch.subtensor
 import pytensor.link.mlx.dispatch.core
+import pytensor.link.mlx.dispatch.signal
+import pytensor.link.mlx.dispatch.signal.conv
 # isort: on

pytensor/link/mlx/dispatch/core.py

Lines changed: 63 additions & 32 deletions
@@ -2,24 +2,33 @@
 pytensor/link/mlx/dispatch/basic.py
 -----------------------------------
 
-Firstcut MLX translations for the most common tensor Ops.
+First-cut MLX translations for the most common tensor Ops.
 
 The structure intentionally follows pytensor's JAX dispatcher so that
 once these kernels stabilise they can be optimised further (e.g. fusing
-elementwise graphs, adding inplace updates, RNG thinning, etc.).
+element-wise graphs, adding in-place updates, RNG thinning, etc.).
 """
+
 from __future__ import annotations
 
 import warnings
-import numpy as np
 
-import mlx.core as mx # MLX
-from pytensor.link.mlx.dispatch.basic import mlx_funcify # MLX
+import mlx.core as mx  # MLX
+import numpy as np
 
+from pytensor.link.mlx.dispatch.basic import mlx_funcify  # MLX
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import (
-    Join, Split, ExtractDiag, Eye, MakeVector,
-    ScalarFromTensor, TensorFromScalar, Tri,
+    Alloc,
+    AllocEmpty,
+    ExtractDiag,
+    Eye,
+    Join,
+    MakeVector,
+    ScalarFromTensor,
+    Split,
+    TensorFromScalar,
+    Tri,
     get_scalar_constant_value,
 )
 from pytensor.tensor.exceptions import NotScalarConstantError
@@ -28,25 +37,25 @@
 # ------------------------------------------------------------------
 # Join
 # ------------------------------------------------------------------
-@mlx_funcify.register(Join) # MLX
+@mlx_funcify.register(Join)  # MLX
 def mlx_funcify_Join(op, **kwargs):
     def join(axis, *tensors):
         view = op.view
         if (view != -1) and all(
-            tensors[i].shape[axis] == 0 # MLX
+            tensors[i].shape[axis] == 0  # MLX
            for i in list(range(view)) + list(range(view + 1, len(tensors)))
         ):
             return tensors[view]
 
-        return mx.concatenate(tensors, axis=axis) # MLX
+        return mx.concatenate(tensors, axis=axis)  # MLX
 
     return join
 
 
 # ------------------------------------------------------------------
 # Split
 # ------------------------------------------------------------------
-@mlx_funcify.register(Split) # MLX
+@mlx_funcify.register(Split)  # MLX
 def mlx_funcify_Split(op: Split, node, **kwargs):
     _, axis_sym, splits_sym = node.inputs
 
@@ -60,8 +69,10 @@ def mlx_funcify_Split(op: Split, node, **kwargs):
 
     try:
         constant_splits = np.array(
-            [get_scalar_constant_value(splits_sym[i])
-             for i in range(get_vector_length(splits_sym))]
+            [
+                get_scalar_constant_value(splits_sym[i])
+                for i in range(get_vector_length(splits_sym))
+            ]
         )
     except (ValueError, NotScalarConstantError):
         constant_splits = None
@@ -78,97 +89,117 @@ def split(x, axis, splits):
             splits = constant_splits
             cumsum_splits = np.cumsum(splits[:-1])
         else:
-            # dynamic ‑– keep in graph
-            splits_arr = mx.array(splits) # MLX
-            cumsum_splits = mx.cumsum(splits_arr[:-1]).tolist()  # python list for mx.split
+            # dynamic - keep in graph
+            splits_arr = mx.array(splits)  # MLX
+            cumsum_splits = mx.cumsum(
+                splits_arr[:-1]
+            ).tolist()  # python list for mx.split
 
         if len(splits) != op.len_splits:
             raise ValueError("Length of 'splits' is not equal to n_splits")
         if np.sum(np.asarray(splits)) != x.shape[axis]:
-            raise ValueError("Split sizes do not sum to the input length on the chosen axis.")
+            raise ValueError(
+                "Split sizes do not sum to the input length on the chosen axis."
+            )
         if np.any(np.asarray(splits) < 0):
             raise ValueError("Split sizes cannot be negative.")
 
-        return mx.split(x, cumsum_splits, axis=axis) # MLX
+        return mx.split(x, cumsum_splits, axis=axis)  # MLX
 
     return split
 
 
 # ------------------------------------------------------------------
 # ExtractDiag
 # ------------------------------------------------------------------
-@mlx_funcify.register(ExtractDiag) # MLX
+@mlx_funcify.register(ExtractDiag)  # MLX
 def mlx_funcify_ExtractDiag(op, **kwargs):
     offset, axis1, axis2 = op.offset, op.axis1, op.axis2
 
     def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2):
-        return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) # MLX
+        return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)  # MLX
 
     return extract_diag
 
 
 # ------------------------------------------------------------------
 # Eye
 # ------------------------------------------------------------------
-@mlx_funcify.register(Eye) # MLX
+@mlx_funcify.register(Eye)  # MLX
 def mlx_funcify_Eye(op, **kwargs):
     dtype = op.dtype
 
     def eye(N, M, k):
-        return mx.eye(int(N), int(M), int(k), dtype=dtype) # MLX
+        return mx.eye(int(N), int(M), int(k), dtype=dtype)  # MLX
 
     return eye
 
 
 # ------------------------------------------------------------------
 # MakeVector
 # ------------------------------------------------------------------
-@mlx_funcify.register(MakeVector) # MLX
+@mlx_funcify.register(MakeVector)  # MLX
 def mlx_funcify_MakeVector(op, **kwargs):
     def makevector(*x):
-        return mx.array(x, dtype=op.dtype) # MLX
+        return mx.array(x, dtype=op.dtype)  # MLX
 
     return makevector
 
 
 # ------------------------------------------------------------------
 # TensorFromScalar (identity for MLX)
 # ------------------------------------------------------------------
-@mlx_funcify.register(TensorFromScalar) # MLX
+@mlx_funcify.register(TensorFromScalar)  # MLX
 def mlx_funcify_TensorFromScalar(op, **kwargs):
     def tensor_from_scalar(x):
-        return x # already an MLX array / scalar
+        return x  # already an MLX array / scalar
 
     return tensor_from_scalar
 
 
 # ------------------------------------------------------------------
 # ScalarFromTensor
 # ------------------------------------------------------------------
-@mlx_funcify.register(ScalarFromTensor) # MLX
+@mlx_funcify.register(ScalarFromTensor)  # MLX
 def mlx_funcify_ScalarFromTensor(op, **kwargs):
     def scalar_from_tensor(x):
-        return mx.array(x).reshape(-1)[0] # MLX
+        return mx.array(x).reshape(-1)[0]  # MLX
 
     return scalar_from_tensor
 
 
 # ------------------------------------------------------------------
 # Tri
 # ------------------------------------------------------------------
-@mlx_funcify.register(Tri) # MLX
+@mlx_funcify.register(Tri)  # MLX
 def mlx_funcify_Tri(op, node, **kwargs):
     # node.inputs -> N, M, k
     const_args = [getattr(inp, "data", None) for inp in node.inputs]
 
     def tri(*args):
-        # Replace args with compiletime constants when available
+        # Replace args with compile-time constants when available
         args = [
             arg if const_a is None else const_a
             for arg, const_a in zip(args, const_args, strict=True)
         ]
-        return mx.tri(*args, dtype=op.dtype) # MLX
+        return mx.tri(*args, dtype=op.dtype)  # MLX
 
     return tri
 
-## Change the code to use the mlx functions
+
+@mlx_funcify.register(AllocEmpty)
+def mlx_funcify_AllocEmpty(op, **kwargs):
+    def allocempty(*shape):
+        return mx.zeros(shape, dtype=op.dtype)
+
+    return allocempty
+
+
+@mlx_funcify.register(Alloc)
+def mlx_funcify_Alloc(op, node, **kwargs):
    def alloc(x, *shape):
        res = mx.broadcast_to(x, shape)
        Alloc._check_runtime_broadcast(node, mx.array(x), res.shape)
        return res

    return alloc
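
A minimal usage sketch (not part of this commit) of how the new Alloc translation could be exercised end to end, assuming the MLX backend in this branch is selectable through a mode named "MLX" (the actual mode/linker name may differ):

    import pytensor
    import pytensor.tensor as pt

    x = pt.scalar("x")
    out = pt.alloc(x, 3, 4)  # lowered via mlx_funcify_Alloc -> mx.broadcast_to
    fn = pytensor.function([x], out, mode="MLX")  # "MLX" mode name is an assumption
    print(fn(2.0))  # expected: a (3, 4) MLX array filled with 2.0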

pytensor/link/mlx/dispatch/shape.py

Lines changed: 9 additions & 1 deletion
@@ -1,5 +1,5 @@
 from pytensor.link.mlx.dispatch.basic import mlx_funcify
-from pytensor.tensor.shape import SpecifyShape
+from pytensor.tensor.shape import Shape_i, SpecifyShape
 
 
 @mlx_funcify.register(SpecifyShape)
@@ -14,3 +14,11 @@ def specifyshape(x, *shape):
         return x
 
     return specifyshape
+
+
+@mlx_funcify.register(Shape_i)
+def mlx_funcify_Shape_i(op, node, **kwargs):
+    def shape_i(x, i):
+        return x.shape[op.i]
+
+    return shape_i
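
For illustration only (not from the commit), the Shape_i translation simply indexes the static shape tuple of the incoming MLX array with the axis stored on the op:

    import mlx.core as mx

    x = mx.zeros((2, 5, 7))
    axis = 1              # stand-in for op.i on a Shape_i node
    print(x.shape[axis])  # 5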

pytensor/link/mlx/dispatch/signal/__init__.py

Whitespace-only changes.
pytensor/link/mlx/dispatch/signal/conv.py (new file; name inferred from the import added in dispatch/__init__.py)

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+from pytensor.link.mlx.dispatch import mlx_funcify
+from pytensor.tensor.signal.conv import Conv1d
+
+import mlx.core as mx
+
+
+@mlx_funcify.register(Conv1d)
+def mlx_funcify_Conv1d(op, node, **kwargs):
+    mode = op.mode
+
+    def conv1d(data, kernel):
+        return mx.convolve(data, kernel, mode=mode)
+
+    return conv1d
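
The Conv1d translation defers to mlx.core.convolve, which follows numpy.convolve semantics for 1-D inputs. A rough sanity check of that assumption (not a test shipped with the commit):

    import mlx.core as mx
    import numpy as np

    data = [1.0, 2.0, 3.0]
    kernel = [0.0, 1.0, 0.5]

    out = mx.convolve(mx.array(data), mx.array(kernel), mode="full")
    ref = np.convolve(data, kernel, mode="full")
    print(np.allclose(np.array(out), ref))  # expected: True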
