@@ -115,12 +115,25 @@ def eye(*args):
     return eye


-def convert_dtype_to_mlx(dtype_str):
+def convert_dtype_to_mlx(dtype_str, auto_cast_unsupported=True):
     """Convert PyTensor dtype strings to MLX dtype objects.

     MLX expects dtype objects rather than string literals for type conversion.
     This function maps common dtype strings to their MLX equivalents.
+
+    Parameters
+    ----------
+    dtype_str : str or MLX dtype
+        The dtype to convert
+    auto_cast_unsupported : bool
+        If True, automatically cast unsupported dtypes to supported ones with warnings
+
+    Returns
+    -------
+    MLX dtype object
     """
+    import warnings
+
     if isinstance(dtype_str, str):
         if dtype_str == "bool":
             return mx.bool_
@@ -145,13 +158,35 @@ def convert_dtype_to_mlx(dtype_str):
         elif dtype_str == "float32":
             return mx.float32
         elif dtype_str == "float64":
-            return mx.float64
+            if auto_cast_unsupported:
+                warnings.warn(
+                    "MLX does not support float64 on GPU. Automatically casting to float32. "
+                    "This may result in reduced precision. To avoid this warning, "
+                    "explicitly use float32 in your code or set floatX='float32' in PyTensor config.",
+                    UserWarning,
+                    stacklevel=3,
+                )
+                return mx.float32
+            else:
+                return mx.float64
         elif dtype_str == "bfloat16":
             return mx.bfloat16
         elif dtype_str == "complex64":
             return mx.complex64
         elif dtype_str == "complex128":
-            return mx.complex128
+            if auto_cast_unsupported:
+                warnings.warn(
+                    "MLX does not support complex128. Automatically casting to complex64. "
+                    "This may result in reduced precision. To avoid this warning, "
+                    "explicitly use complex64 in your code.",
+                    UserWarning,
+                    stacklevel=3,
+                )
+                return mx.complex64
+            else:
+                # MLX has no complex128 dtype, so complex64 is the closest
+                # available fallback even when auto-casting is disabled
+                return mx.complex64
     # Return as is if it's already an MLX dtype or not a recognized string
     return dtype_str

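For reference, a minimal usage sketch of the auto-casting behavior (not part of the diff; it assumes mlx.core is installed and that convert_dtype_to_mlx as defined above is in scope):

import warnings
import mlx.core as mx

# float64 requests are auto-cast to float32 by default, with a UserWarning
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert convert_dtype_to_mlx("float64") == mx.float32
    assert issubclass(caught[-1].category, UserWarning)

# complex128 likewise falls back to complex64
assert convert_dtype_to_mlx("complex128") == mx.complex64

# Supported strings and existing MLX dtype objects pass straight through
assert convert_dtype_to_mlx("bfloat16") == mx.bfloat16
assert convert_dtype_to_mlx(mx.int32) == mx.int32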
@@ -212,7 +247,31 @@ def allocempty(*shape):
 @mlx_funcify.register(Alloc)
 def mlx_funcify_Alloc(op, node, **kwargs):
     def alloc(x, *shape):
-        res = mx.broadcast_to(x, shape)
-        return res
+        try:
+            # Convert shape elements to Python ints for MLX compatibility
+            # MLX requires shape dimensions to be Python integers, not MLX arrays
+            shape_ints = tuple(
+                int(s.item()) if hasattr(s, "item") else int(s) for s in shape
+            )
+            return mx.broadcast_to(x, shape_ints)
+        except ValueError as e:
+            if (
+                "[eval] Attempting to eval an array during function transformations"
+                in str(e)
+            ):
+                # This is the MLX compilation limitation - provide helpful error
+                raise ValueError(
+                    "MLX compilation limitation: Alloc operations with dynamic shapes "
+                    "cannot be used inside compiled functions. This is because MLX "
+                    "compilation forbids evaluating arrays to extract shape values."
+                    "\n\nWorkarounds:"
+                    "\n1. Avoid using Alloc with dynamic shapes in compiled contexts"
+                    "\n2. Use static shapes when possible"
+                    "\n3. Move Alloc operations outside compiled functions"
+                    "\n\nOriginal error: " + str(e)
+                ) from e
+            else:
+                # Re-raise other ValueError exceptions
+                raise

     return alloc
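To illustrate why alloc coerces shape entries to Python ints, here is a small sketch (not part of the diff; it assumes mlx.core is installed and the variable names are illustrative): mx.broadcast_to needs plain integer dimensions, while PyTensor may pass dynamic dimensions as zero-dimensional MLX arrays.

import mlx.core as mx

# A dynamic dimension may arrive as a zero-dimensional MLX array
shape = (mx.array(2), 3)

# Same normalization as in `alloc` above: MLX scalars expose `.item()`,
# plain Python ints go through int() unchanged
shape_ints = tuple(int(s.item()) if hasattr(s, "item") else int(s) for s in shape)

out = mx.broadcast_to(mx.array(1.0), shape_ints)
print(tuple(out.shape))  # (2, 3)

# Inside mx.compile, calling .item() on a traced array raises the
# "[eval] Attempting to eval an array ..." ValueError that the handler
# above converts into the more descriptive error message.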