@@ -2,24 +2,33 @@
 pytensor/link/mlx/dispatch/basic.py
 -----------------------------------
 
-First‑cut MLX translations for the most common tensor Ops.
+First-cut MLX translations for the most common tensor Ops.
 
 The structure intentionally follows pytensor's JAX dispatcher so that
 once these kernels stabilise they can be optimised further (e.g. fusing
-element‑wise graphs, adding in‑place updates, RNG thinning, etc.).
+element-wise graphs, adding in-place updates, RNG thinning, etc.).
 """
+
 from __future__ import annotations
 
 import warnings
-import numpy as np
 
-import mlx.core as mx # MLX
-from pytensor.link.mlx.dispatch.basic import mlx_funcify # MLX
+import mlx.core as mx  # MLX
+import numpy as np
 
+from pytensor.link.mlx.dispatch.basic import mlx_funcify  # MLX
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import (
-    Join, Split, ExtractDiag, Eye, MakeVector,
-    ScalarFromTensor, TensorFromScalar, Tri,
+    Alloc,
+    AllocEmpty,
+    ExtractDiag,
+    Eye,
+    Join,
+    MakeVector,
+    ScalarFromTensor,
+    Split,
+    TensorFromScalar,
+    Tri,
     get_scalar_constant_value,
 )
 from pytensor.tensor.exceptions import NotScalarConstantError
@@ -28,25 +37,25 @@
 # ------------------------------------------------------------------
 # Join
 # ------------------------------------------------------------------
-@mlx_funcify.register(Join) # MLX
+@mlx_funcify.register(Join)  # MLX
 def mlx_funcify_Join(op, **kwargs):
     def join(axis, *tensors):
         view = op.view
         if (view != -1) and all(
-            tensors[i].shape[axis] == 0 # MLX
+            tensors[i].shape[axis] == 0  # MLX
             for i in list(range(view)) + list(range(view + 1, len(tensors)))
         ):
             return tensors[view]
 
-        return mx.concatenate(tensors, axis=axis) # MLX
+        return mx.concatenate(tensors, axis=axis)  # MLX
 
     return join
 
 
 # ------------------------------------------------------------------
 # Split
 # ------------------------------------------------------------------
-@mlx_funcify.register(Split) # MLX
+@mlx_funcify.register(Split)  # MLX
 def mlx_funcify_Split(op: Split, node, **kwargs):
     _, axis_sym, splits_sym = node.inputs
 
@@ -60,8 +69,10 @@ def mlx_funcify_Split(op: Split, node, **kwargs):
 
     try:
         constant_splits = np.array(
-            [get_scalar_constant_value(splits_sym[i])
-             for i in range(get_vector_length(splits_sym))]
+            [
+                get_scalar_constant_value(splits_sym[i])
+                for i in range(get_vector_length(splits_sym))
+            ]
         )
     except (ValueError, NotScalarConstantError):
         constant_splits = None
@@ -78,97 +89,117 @@ def split(x, axis, splits):
             splits = constant_splits
             cumsum_splits = np.cumsum(splits[:-1])
         else:
-            # dynamic – keep in graph
-            splits_arr = mx.array(splits) # MLX
-            cumsum_splits = mx.cumsum(splits_arr[:-1]).tolist() # python list for mx.split
+            # dynamic - keep in graph
+            splits_arr = mx.array(splits)  # MLX
+            cumsum_splits = mx.cumsum(
+                splits_arr[:-1]
+            ).tolist()  # python list for mx.split
 
         if len(splits) != op.len_splits:
             raise ValueError("Length of 'splits' is not equal to n_splits")
         if np.sum(np.asarray(splits)) != x.shape[axis]:
-            raise ValueError("Split sizes do not sum to the input length on the chosen axis.")
+            raise ValueError(
+                "Split sizes do not sum to the input length on the chosen axis."
+            )
         if np.any(np.asarray(splits) < 0):
             raise ValueError("Split sizes cannot be negative.")
 
-        return mx.split(x, cumsum_splits, axis=axis) # MLX
+        return mx.split(x, cumsum_splits, axis=axis)  # MLX
 
     return split
 
 
 # ------------------------------------------------------------------
 # ExtractDiag
 # ------------------------------------------------------------------
-@mlx_funcify.register(ExtractDiag) # MLX
+@mlx_funcify.register(ExtractDiag)  # MLX
 def mlx_funcify_ExtractDiag(op, **kwargs):
     offset, axis1, axis2 = op.offset, op.axis1, op.axis2
 
     def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2):
-        return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) # MLX
+        return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)  # MLX
 
     return extract_diag
 
 
 # ------------------------------------------------------------------
 # Eye
 # ------------------------------------------------------------------
-@mlx_funcify.register(Eye) # MLX
+@mlx_funcify.register(Eye)  # MLX
 def mlx_funcify_Eye(op, **kwargs):
     dtype = op.dtype
 
     def eye(N, M, k):
-        return mx.eye(int(N), int(M), int(k), dtype=dtype) # MLX
+        return mx.eye(int(N), int(M), int(k), dtype=dtype)  # MLX
 
     return eye
 
 
 # ------------------------------------------------------------------
 # MakeVector
 # ------------------------------------------------------------------
-@mlx_funcify.register(MakeVector) # MLX
+@mlx_funcify.register(MakeVector)  # MLX
 def mlx_funcify_MakeVector(op, **kwargs):
     def makevector(*x):
-        return mx.array(x, dtype=op.dtype) # MLX
+        return mx.array(x, dtype=op.dtype)  # MLX
 
     return makevector
 
 
 # ------------------------------------------------------------------
 # TensorFromScalar (identity for MLX)
 # ------------------------------------------------------------------
-@mlx_funcify.register(TensorFromScalar) # MLX
+@mlx_funcify.register(TensorFromScalar)  # MLX
 def mlx_funcify_TensorFromScalar(op, **kwargs):
     def tensor_from_scalar(x):
-        return x # already an MLX array / scalar
+        return x  # already an MLX array / scalar
 
     return tensor_from_scalar
 
 
 # ------------------------------------------------------------------
 # ScalarFromTensor
 # ------------------------------------------------------------------
-@mlx_funcify.register(ScalarFromTensor) # MLX
+@mlx_funcify.register(ScalarFromTensor)  # MLX
 def mlx_funcify_ScalarFromTensor(op, **kwargs):
     def scalar_from_tensor(x):
-        return mx.array(x).reshape(-1)[0] # MLX
+        return mx.array(x).reshape(-1)[0]  # MLX
 
     return scalar_from_tensor
 
 
 # ------------------------------------------------------------------
 # Tri
 # ------------------------------------------------------------------
-@mlx_funcify.register(Tri) # MLX
+@mlx_funcify.register(Tri)  # MLX
 def mlx_funcify_Tri(op, node, **kwargs):
     # node.inputs -> N, M, k
     const_args = [getattr(inp, "data", None) for inp in node.inputs]
 
     def tri(*args):
-        # Replace args with compile‑time constants when available
+        # Replace args with compile-time constants when available
         args = [
            arg if const_a is None else const_a
            for arg, const_a in zip(args, const_args, strict=True)
        ]
-        return mx.tri(*args, dtype=op.dtype) # MLX
+        return mx.tri(*args, dtype=op.dtype)  # MLX
 
     return tri
 
-## Change the code to use the mlx functions
+
+@mlx_funcify.register(AllocEmpty)
+def mlx_funcify_AllocEmpty(op, **kwargs):
+    def allocempty(*shape):
+        return mx.zeros(shape, dtype=op.dtype)
+
+    return allocempty
+
+
+@mlx_funcify.register(Alloc)
+def mlx_funcify_Alloc(op, node, **kwargs):
+    def alloc(x, *shape):
+        res = mx.broadcast_to(x, shape)
+        Alloc._check_runtime_broadcast(node, mx.array(x), res.shape)
+        return res
+
+    return alloc
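
A few standalone sketches of the machinery this diff leans on. The `@mlx_funcify.register(...)` decorators follow pytensor's JAX dispatcher, which is built on `functools.singledispatch`: one generic function keyed on the Op class, where each registered translation returns a closure for the linker to call with runtime values. A minimal sketch of that pattern, using a hypothetical stand-in `Join` class rather than the real pytensor Op:

```python
from functools import singledispatch


class Join:
    """Hypothetical stand-in for pytensor's Join Op (illustration only)."""


@singledispatch
def mlx_funcify(op, **kwargs):
    # Fallback for Ops with no registered MLX translation.
    raise NotImplementedError(f"No MLX translation for: {type(op).__name__}")


@mlx_funcify.register(Join)
def mlx_funcify_Join(op, **kwargs):
    # Translations return a closure; the linker calls it at runtime.
    def join(axis, *tensors):
        return (axis, tensors)  # the real kernel calls mx.concatenate here

    return join


join_fn = mlx_funcify(Join())  # dispatch happens once, at compile time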
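Why `split` converts section sizes into `cumsum_splits`: `mx.split`, like `numpy.split`, expects the positions at which to cut, not the lengths of the sections. A small sketch (the `sizes` value is illustrative):

```python
import numpy as np

import mlx.core as mx

x = mx.arange(10)
sizes = [3, 3, 4]                    # section lengths, as the Split Op receives them
cut_points = np.cumsum(sizes[:-1])   # [3, 6]: cut positions, one fewer than sections
parts = mx.split(x, cut_points.tolist(), axis=0)
assert [p.shape[0] for p in parts] == sizes
```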
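The new `Alloc` kernel pairs `mx.broadcast_to`, which stretches an array to a target shape under NumPy broadcasting rules, with pytensor's own `Alloc._check_runtime_broadcast` guard, which rejects broadcasts the graph did not declare. The MLX side behaves like this:

```python
import mlx.core as mx

row = mx.array([1.0, 2.0, 3.0])       # shape (3,)
tiled = mx.broadcast_to(row, (4, 3))  # adds a leading axis: rows repeated
assert tiled.shape == (4, 3)

col = mx.zeros((4, 1))
grid = mx.broadcast_to(col, (4, 3))   # size-1 axis stretched to length 3
assert grid.shape == (4, 3)
```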