1+ """
2+ pytensor/link/mlx/dispatch/basic.py
3+ -----------------------------------
4+
5+ First‑cut MLX translations for the most common tensor Ops.
6+
7+ The structure intentionally follows pytensor's JAX dispatcher so that
8+ once these kernels stabilise they can be optimised further (e.g. fusing
9+ element‑wise graphs, adding in‑place updates, RNG thinning, etc.).
10+ """
11+ from __future__ import annotations
12+
13+ import warnings
14+ import numpy as np
15+
16+ import mlx .core as mx # MLX
17+ from pytensor .link .mlx .dispatch .basic import mlx_funcify # MLX
18+
19+ from pytensor .tensor import get_vector_length
20+ from pytensor .tensor .basic import (
21+ Join , Split , ExtractDiag , Eye , MakeVector ,
22+ ScalarFromTensor , TensorFromScalar , Tri ,
23+ get_scalar_constant_value ,
24+ )
25+ from pytensor .tensor .exceptions import NotScalarConstantError
26+
27+
28+ # ------------------------------------------------------------------
29+ # Join
30+ # ------------------------------------------------------------------
@mlx_funcify.register(Join)  # MLX
def mlx_funcify_Join(op, **kwargs):
    """Return an MLX implementation of ``Join`` (concatenation along an axis).

    Parameters
    ----------
    op
        The ``Join`` Op instance; its ``view`` attribute selects an input
        that may be returned as-is when all other inputs are empty.

    Returns
    -------
    callable
        ``join(axis, *tensors) -> mx.array``
    """

    def join(axis, *tensors):
        # ``axis`` arrives as a runtime scalar (often a 0-d array). Coerce it
        # to a Python int once: it is used to index Python shape tuples below
        # and as the ``axis=`` keyword of mx.concatenate.
        axis = int(axis)

        view = op.view
        # When the op may alias one of its inputs (view != -1) and every
        # *other* input is empty along the join axis, the result is exactly
        # that input, so return it without concatenating.
        if view != -1 and all(
            t.shape[axis] == 0 for i, t in enumerate(tensors) if i != view
        ):
            return tensors[view]

        return mx.concatenate(tensors, axis=axis)  # MLX

    return join
44+
45+
46+ # ------------------------------------------------------------------
47+ # Split
48+ # ------------------------------------------------------------------
@mlx_funcify.register(Split)  # MLX
def mlx_funcify_Split(op: Split, node, **kwargs):
    """Return an MLX implementation of ``Split``.

    Axis and split sizes are resolved to compile-time constants here, at
    dispatch time, whenever possible; the returned callable then avoids
    tracing extra ops for them.
    """
    _, axis_sym, splits_sym = node.inputs

    constant_axis = None
    try:
        constant_axis = get_scalar_constant_value(axis_sym)
    except NotScalarConstantError:
        warnings.warn(
            "Split node does not have a constant axis. MLX implementation may fail."
        )

    constant_splits = None
    try:
        n_entries = get_vector_length(splits_sym)
        constant_splits = np.array(
            [get_scalar_constant_value(splits_sym[i]) for i in range(n_entries)]
        )
    except (ValueError, NotScalarConstantError):
        warnings.warn(
            "Split node does not have constant split positions. MLX implementation may fail."
        )

    def split(x, axis, splits):
        # Prefer the constants resolved at dispatch time over the runtime
        # arguments (avoids tracing extra ops).
        if constant_axis is not None:
            axis = int(constant_axis)

        if constant_splits is not None:
            splits = constant_splits
            cumsum_splits = np.cumsum(splits[:-1])
        else:
            # Dynamic split sizes: keep the cumsum in the MLX graph, then
            # materialize a Python list, which is what mx.split expects.
            cumsum_splits = mx.cumsum(mx.array(splits)[:-1]).tolist()  # MLX

        if len(splits) != op.len_splits:
            raise ValueError("Length of 'splits' is not equal to n_splits")
        sizes = np.asarray(splits)
        if np.sum(sizes) != x.shape[axis]:
            raise ValueError("Split sizes do not sum to the input length on the chosen axis.")
        if np.any(sizes < 0):
            raise ValueError("Split sizes cannot be negative.")

        return mx.split(x, cumsum_splits, axis=axis)  # MLX

    return split
95+
96+
97+ # ------------------------------------------------------------------
98+ # ExtractDiag
99+ # ------------------------------------------------------------------
@mlx_funcify.register(ExtractDiag)  # MLX
def mlx_funcify_ExtractDiag(op, **kwargs):
    """Return an MLX implementation of ``ExtractDiag`` via ``mx.diagonal``."""

    # Bind the op's static attributes as defaults so the runtime call carries
    # no extra lookups.
    def extract_diag(x, offset=op.offset, axis1=op.axis1, axis2=op.axis2):
        return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)  # MLX

    return extract_diag
108+
109+
110+ # ------------------------------------------------------------------
111+ # Eye
112+ # ------------------------------------------------------------------
@mlx_funcify.register(Eye)  # MLX
def mlx_funcify_Eye(op, **kwargs):
    """Return an MLX implementation of ``Eye``."""
    dtype = op.dtype

    def eye(N, M, k):
        # The runtime inputs may be 0-d arrays; mx.eye wants Python ints.
        n_rows, n_cols, offset = (int(v) for v in (N, M, k))
        return mx.eye(n_rows, n_cols, offset, dtype=dtype)  # MLX

    return eye
121+
122+
123+ # ------------------------------------------------------------------
124+ # MakeVector
125+ # ------------------------------------------------------------------
@mlx_funcify.register(MakeVector)  # MLX
def mlx_funcify_MakeVector(op, **kwargs):
    """Return an MLX implementation of ``MakeVector`` (scalars -> 1-d array)."""
    dtype = op.dtype

    def makevector(*x):
        return mx.array(x, dtype=dtype)  # MLX

    return makevector
132+
133+
134+ # ------------------------------------------------------------------
135+ # TensorFromScalar (identity for MLX)
136+ # ------------------------------------------------------------------
@mlx_funcify.register(TensorFromScalar)  # MLX
def mlx_funcify_TensorFromScalar(op, **kwargs):
    """Return an MLX implementation of ``TensorFromScalar``.

    MLX draws no distinction between scalars and 0-d arrays, so this is
    the identity function.
    """

    def tensor_from_scalar(x):
        return x

    return tensor_from_scalar
143+
144+
145+ # ------------------------------------------------------------------
146+ # ScalarFromTensor
147+ # ------------------------------------------------------------------
@mlx_funcify.register(ScalarFromTensor)  # MLX
def mlx_funcify_ScalarFromTensor(op, **kwargs):
    """Return an MLX implementation of ``ScalarFromTensor``."""

    def scalar_from_tensor(x):
        # Flatten, then take the first element to obtain the scalar value.
        return mx.reshape(mx.array(x), (-1,))[0]  # MLX

    return scalar_from_tensor
154+
155+
156+ # ------------------------------------------------------------------
157+ # Tri
158+ # ------------------------------------------------------------------
@mlx_funcify.register(Tri)  # MLX
def mlx_funcify_Tri(op, node, **kwargs):
    """Return an MLX implementation of ``Tri``.

    ``node.inputs`` are the N, M, k arguments; constant inputs expose their
    value via ``.data`` (non-constants yield None here).
    """
    const_args = [getattr(inp, "data", None) for inp in node.inputs]

    def tri(*args):
        # Substitute compile-time constants for the runtime arguments
        # whenever they are available.
        resolved = []
        for runtime_arg, const_arg in zip(args, const_args, strict=True):
            resolved.append(runtime_arg if const_arg is None else const_arg)
        return mx.tri(*resolved, dtype=op.dtype)  # MLX

    return tri
173+