
Commit d47de98

cetagostini authored and jessegrabowski committed
Working with simple model
1 parent 8300fd4 commit d47de98


5 files changed: +329 −9 lines


pytensor/link/mlx/dispatch/core.py

Lines changed: 64 additions & 5 deletions
@@ -115,12 +115,25 @@ def eye(*args):
     return eye
 
 
-def convert_dtype_to_mlx(dtype_str):
+def convert_dtype_to_mlx(dtype_str, auto_cast_unsupported=True):
     """Convert PyTensor dtype strings to MLX dtype objects.
 
     MLX expects dtype objects rather than string literals for type conversion.
     This function maps common dtype strings to their MLX equivalents.
+
+    Parameters
+    ----------
+    dtype_str : str or MLX dtype
+        The dtype to convert
+    auto_cast_unsupported : bool
+        If True, automatically cast unsupported dtypes to supported ones with warnings
+
+    Returns
+    -------
+    MLX dtype object
     """
+    import warnings
+
     if isinstance(dtype_str, str):
         if dtype_str == "bool":
             return mx.bool_
@@ -145,13 +158,35 @@ def convert_dtype_to_mlx(dtype_str):
         elif dtype_str == "float32":
             return mx.float32
         elif dtype_str == "float64":
-            return mx.float64
+            if auto_cast_unsupported:
+                warnings.warn(
+                    "MLX does not support float64 on GPU. Automatically casting to float32. "
+                    "This may result in reduced precision. To avoid this warning, "
+                    "explicitly use float32 in your code or set floatX='float32' in PyTensor config.",
+                    UserWarning,
+                    stacklevel=3,
+                )
+                return mx.float32
+            else:
+                return mx.float64
         elif dtype_str == "bfloat16":
             return mx.bfloat16
         elif dtype_str == "complex64":
             return mx.complex64
         elif dtype_str == "complex128":
-            return mx.complex128
+            if auto_cast_unsupported:
+                warnings.warn(
+                    "MLX does not support complex128. Automatically casting to complex64. "
+                    "This may result in reduced precision. To avoid this warning, "
+                    "explicitly use complex64 in your code.",
+                    UserWarning,
+                    stacklevel=3,
+                )
+                return mx.complex64
+            else:
+                # MLX has no complex128 dtype, so even with auto-casting disabled
+                # the best we can do is fall back to complex64 (without a warning)
+                return mx.complex64
     # Return as is if it's already an MLX dtype or not a recognized string
     return dtype_str
 
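For reference, a minimal sketch of the new auto-casting behavior. The import path simply mirrors the file shown above (pytensor/link/mlx/dispatch/core.py); whether the symbol is re-exported elsewhere is not part of this commit.

import warnings

import mlx.core as mx

from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx

# Default: float64 is downcast to float32 and a UserWarning is emitted
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert convert_dtype_to_mlx("float64") == mx.float32
    assert any("float64" in str(w.message) for w in caught)

# Opting out of auto-casting returns mx.float64, which may fail later on the GPU
assert convert_dtype_to_mlx("float64", auto_cast_unsupported=False) == mx.float64

# Supported dtypes pass through unchanged
assert convert_dtype_to_mlx("float32") == mx.float32
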
@@ -212,7 +247,31 @@ def allocempty(*shape):
 @mlx_funcify.register(Alloc)
 def mlx_funcify_Alloc(op, node, **kwargs):
     def alloc(x, *shape):
-        res = mx.broadcast_to(x, shape)
-        return res
+        try:
+            # Convert shape elements to Python ints for MLX compatibility
+            # MLX requires shape dimensions to be Python integers, not MLX arrays
+            shape_ints = tuple(
+                int(s.item()) if hasattr(s, "item") else int(s) for s in shape
+            )
+            return mx.broadcast_to(x, shape_ints)
+        except ValueError as e:
+            if (
+                "[eval] Attempting to eval an array during function transformations"
+                in str(e)
+            ):
+                # This is the MLX compilation limitation - provide helpful error
+                raise ValueError(
+                    "MLX compilation limitation: Alloc operations with dynamic shapes "
+                    "cannot be used inside compiled functions. This is because MLX "
+                    "compilation forbids evaluating arrays to extract shape values. "
+                    "\n\nWorkarounds:"
+                    "\n1. Avoid using Alloc with dynamic shapes in compiled contexts"
+                    "\n2. Use static shapes when possible"
+                    "\n3. Move Alloc operations outside compiled functions"
+                    "\n\nOriginal error: " + str(e)
+                ) from e
+            else:
+                # Re-raise other ValueError exceptions
+                raise
 
     return alloc
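
As a standalone illustration of why the wrapper converts shape entries to Python ints (plain MLX, no PyTensor involved): mx.broadcast_to expects integer dimensions, while shape values coming from the graph may arrive as zero-dimensional MLX arrays.

import mlx.core as mx

x = mx.array([1.0, 2.0, 3.0])
shape = (mx.array(2), 3)  # shape entries may be MLX scalars or plain Python ints

# Same conversion used in alloc() above
shape_ints = tuple(int(s.item()) if hasattr(s, "item") else int(s) for s in shape)

print(mx.broadcast_to(x, shape_ints).shape)  # prints the broadcast shape, (2, 3)

Inside mx.compile, that int(s.item()) call forces an eval of a traced array, which is exactly the limitation the ValueError message above points at.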

pytensor/link/mlx/dispatch/elemwise.py

Lines changed: 20 additions & 1 deletion
@@ -149,6 +149,25 @@ def softplus(x):
 def mlx_funcify_Cast(op, **kwargs):
     def cast(x):
         dtype = convert_dtype_to_mlx(op.scalar_op.o_type.dtype)
-        return x.astype(dtype)
+        try:
+            return x.astype(dtype)
+        except ValueError as e:
+            if "is not supported on the GPU" in str(e):
+                # MLX GPU limitation - try auto-casting with warning
+                import warnings
+
+                warnings.warn(
+                    f"MLX GPU limitation: {e}. Attempting automatic fallback casting.",
+                    UserWarning,
+                    stacklevel=2,
+                )
+                # Get the auto-cast version
+                fallback_dtype = convert_dtype_to_mlx(
+                    op.scalar_op.o_type.dtype, auto_cast_unsupported=True
+                )
+                return x.astype(fallback_dtype)
+            else:
+                # Re-raise other ValueError exceptions
+                raise
 
     return cast

pytensor/link/mlx/dispatch/math.py

Lines changed: 20 additions & 1 deletion
@@ -303,7 +303,26 @@ def minimum(x, y):
 def _(scalar_op):
     def cast(x):
         dtype = convert_dtype_to_mlx(scalar_op.o_type.dtype)
-        return x.astype(dtype)
+        try:
+            return x.astype(dtype)
+        except ValueError as e:
+            if "is not supported on the GPU" in str(e):
+                # MLX GPU limitation - try auto-casting with warning
+                import warnings
+
+                warnings.warn(
+                    f"MLX GPU limitation: {e}. Attempting automatic fallback casting.",
+                    UserWarning,
+                    stacklevel=2,
+                )
+                # Get the auto-cast version
+                fallback_dtype = convert_dtype_to_mlx(
+                    scalar_op.o_type.dtype, auto_cast_unsupported=True
+                )
+                return x.astype(fallback_dtype)
+            else:
+                # Re-raise other ValueError exceptions
+                raise
 
     return cast
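
The guard added here is identical to the one in elemwise.py above; both retry the cast with an auto-cast dtype when MLX rejects the requested dtype on the GPU. A minimal sketch of what that fallback path computes (the import path again mirrors the file layout shown in this commit):

import mlx.core as mx

from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx

x = mx.zeros((3,), dtype=mx.float32)

# What the except-branch computes when a float64 cast is rejected on the GPU:
# convert_dtype_to_mlx warns and hands back a supported dtype instead.
fallback_dtype = convert_dtype_to_mlx("float64", auto_cast_unsupported=True)
print(x.astype(fallback_dtype).dtype)  # mlx.core.float32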

pytensor/link/mlx/linker.py

Lines changed: 9 additions & 1 deletion
@@ -4,9 +4,10 @@
 class MLXLinker(JITLinker):
     """A `Linker` that JIT-compiles NumPy-based operations using Apple's MLX."""
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, use_compile=True, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.gen_functors = []
+        self.use_compile = use_compile
 
     def fgraph_convert(self, fgraph, **kwargs):
         """Convert a PyTensor FunctionGraph to an MLX-compatible function.
@@ -33,6 +34,13 @@ def jit_compile(self, fn):
 
         from pytensor.link.mlx.dispatch import mlx_typify
 
+        if not self.use_compile:
+            # Skip compilation and just return the function with MLX typification
+            def fn_no_compile(*inputs):
+                return fn(*(mlx_typify(inp) for inp in inputs))
+
+            return fn_no_compile
+
         inner_fn = mx.compile(fn)
 
         def fn(*inputs, inner_fn=inner_fn):
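
A sketch of how the new use_compile flag could be exercised end to end. The Mode/linker wiring below is an assumption based on PyTensor's generic linker API (pytensor.compile.mode.Mode); this commit only adds the flag itself.

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import Mode
from pytensor.link.mlx.linker import MLXLinker

x = pt.vector("x", dtype="float32")
y = pt.exp(x) + 1.0

# use_compile=False skips mx.compile and only applies mlx_typify to the inputs,
# trading some speed for compatibility with ops that cannot run under MLX compilation
mlx_mode = Mode(linker=MLXLinker(use_compile=False))
f = pytensor.function([x], y, mode=mlx_mode)

print(f([0.0, 1.0, 2.0]))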
