
Commit b438520

Feedback from Ricardo

1 parent: e02398f

File tree

10 files changed: 227 additions, 169 deletions


pytensor/link/mlx/dispatch/basic.py

Lines changed: 20 additions & 3 deletions
@@ -7,6 +7,7 @@
 import numpy as np
 
 from pytensor.compile.ops import DeepCopyOp
+from pytensor.graph import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.link.utils import fgraph_to_python
 from pytensor.raise_op import Assert, CheckAndRaise
@@ -24,7 +25,6 @@ def mlx_typify_tensor(data, dtype=None, **kwargs):
 
 @mlx_typify.register(slice)
 @mlx_typify.register(NoneType)
-@mlx_typify.register(np.number)
 @mlx_typify.register(mx.array)
 def mlx_typify_no_conversion_needed(data, **kwargs):
     return data
@@ -36,6 +36,19 @@ def mlx_typify_python_scalar(data, **kwargs):
     return mx.array(data)
 
 
+@mlx_typify.register(bool)
+@mlx_typify.register(np.bool_)
+def mlx_typify_bool(data, **kwargs):
+    return bool(data)
+
+
+@mlx_typify.register(np.integer)
+@mlx_typify.register(np.floating)
+@mlx_typify.register(np.complexfloating)
+def mlx_typify_numpy_scalar(data, **kwargs):
+    return mx.array(data)
+
+
 @singledispatch
 def mlx_funcify(op, node=None, storage_map=None, **kwargs):
     """Create a MLX compatible function from an PyTensor `Op`."""
@@ -72,9 +85,13 @@ def deepcopyop(x):
 
 @mlx_funcify.register(Assert)
 @mlx_funcify.register(CheckAndRaise)
-def mlx_funcify_CheckAndRaise(op, **kwargs):
+def mlx_funcify_CheckAndRaise(op, node, **kwargs):
+    conds = node.inputs[1:]
+    if any(isinstance(cond, Constant) and not bool(cond.data) for cond in conds):
+        raise op.exc_type(op.msg)
+
     warnings.warn(
-        f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as MLX tracing would remove it.""",
+        f"""Skipping `{type(op).__name__}` Op (assertion: {op.msg}) as MLX tracing would remove it.""",
         stacklevel=2,
     )

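Note on the change above: NumPy scalars no longer pass through the no-conversion path. Booleans (bool and np.bool_) are normalized to Python bool, while integer, floating, and complex NumPy scalars are wrapped in mx.array. The CheckAndRaise dispatch also now raises eagerly at dispatch time when any condition is a Constant that is statically False, instead of only warning. A minimal, self-contained sketch of the new typify behaviour, using a hypothetical typify in place of the real mlx_typify:

from functools import singledispatch

import mlx.core as mx
import numpy as np


@singledispatch
def typify(data, **kwargs):
    # Default: pass data through unchanged (mirrors mlx_typify_no_conversion_needed).
    return data


@typify.register(bool)
@typify.register(np.bool_)
def _typify_bool(data, **kwargs):
    # Normalize every boolean to a plain Python bool.
    return bool(data)


@typify.register(np.integer)
@typify.register(np.floating)
@typify.register(np.complexfloating)
def _typify_numpy_scalar(data, **kwargs):
    # Numeric NumPy scalars become MLX arrays.
    return mx.array(data)


assert typify(np.bool_(True)) is True
assert isinstance(typify(np.float32(1.5)), mx.array)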
pytensor/link/mlx/dispatch/blockwise.py

Lines changed: 22 additions & 81 deletions
@@ -2,105 +2,46 @@
 
 from pytensor.link.mlx.dispatch import mlx_funcify
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.signal.conv import Convolve1d as Conv1d
-
-
-def blockwise_conv1d(op, node, **kwargs):
-    """
-    Custom implementation of Blockwise.conv1d for MLX.
-    """
-
-    def batched_conv1d(
-        x: mx.array,
-        kernels: mx.array,
-        mode: str = op.core_op.mode,
-        stride: int = 1,
-        dilation: int = 1,
-    ) -> mx.array:
-        """
-        Apply B separate 1D convolutions (full or valid) to B sequences in parallel.
-
-        Parameters
-        ----------
-        x : array of shape (B, T)
-            B sequences of length T.
-        kernels : array of shape (B, K)
-            B kernels of length K.
-        mode : {"valid", "full"}
-            "valid" → no padding, output length = T - K + 1
-            "full" → zero-pad so output length = T + K - 1
-        stride : int, convolution stride (default=1)
-        dilation : int, convolution dilation (default=1)
-
-        Returns
-        -------
-        out : array of shape (B, L)
-            where L =
-              - T - K + 1 if mode="valid"
-              - T + K - 1 if mode="full"
-        """
-        # --- 1) shape checks ---
-        B, T = x.shape
-        Bk, K = kernels.shape
-        if B != Bk:
-            raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}")
-
-        # --- 2) flip kernels for convolution ---
-        kernels_flipped = kernels[:, ::-1]  # shape (B, K)
-
-        # --- 3) decide padding ---
-        if mode == "valid":
-            pad = 0
-        elif mode == "full":
-            pad = (K - 1) * dilation
-        else:
-            raise ValueError(f"Unsupported mode {mode!r}: choose 'valid' or 'full'")
-
-        # --- 4) reshape into MLX conv1d form ---
-        # input: (N=1, H=T, C_in=B)
-        x_in = x.T[None, :, :]
-
-        # weight: (C_out=B, H_f=K, C_in=1)
-        w = kernels_flipped[:, :, None]
-
-        # --- 5) run grouped conv1d ---
-        y = mx.conv1d(x_in, w, stride=stride, padding=pad, dilation=dilation, groups=B)
-        # y shape: (1, H_out, B)
-
-        # --- 6) return shape (B, H_out) ---
-        return y[0].T
-
-    return batched_conv1d
 
 
 @mlx_funcify.register(Blockwise)
 def funcify_Blockwise(op: Blockwise, node, **kwargs):
-    # 1) If it's a Conv1d Blockwise, use the custom implementation
-    if isinstance(op.core_op, Conv1d):
-        return blockwise_conv1d(op, node, **kwargs)
-
     # 2) Otherwise, get the core python function for this Blockwise
     core_node = op._create_dummy_core_node(node.inputs)
     core_f = mlx_funcify(op.core_op, core_node)
 
     # 3) Determine how many inputs correspond to batch dimensions
     n_batch = op.batch_ndim(node)
 
-    # 4) Build in_axes: map only the first n_batch args, keep the rest static
-    in_axes = tuple(0 if i < n_batch else None for i in range(len(node.inputs)))
+    # 4) Handle case where no vectorization is needed
+    if n_batch == 0:
+
+        def blockwise_fun(*inputs):
+            return core_f(*inputs)
+
+        return blockwise_fun
+
+    # 5) Vectorize using mx.vmap over any batched inputs
+    in_axes = []
+    for inp, sig in zip(node.inputs, op.inputs_sig):
+        batch_ndim = inp.type.ndim - len(sig)
+        if batch_ndim == 0:
+            in_axes.append(None)
+            continue
+
+        batch_bcast = inp.type.broadcastable[:batch_ndim]
+        # If all batch dims are broadcastable (size 1), treat input as static
+        in_axes.append(0 if not all(batch_bcast) else None)
+
+    if not any(axis == 0 for axis in in_axes):
 
-    # 5) Handle case where no vectorization is needed
-    if n_batch == 0 or all(axis is None for axis in in_axes):
-        # No batch dimensions, just return the core function
         def blockwise_fun(*inputs):
             return core_f(*inputs)
 
         return blockwise_fun
 
-    # 6) Vectorize (vmap) with in_axes
-    blockwise_f = mx.vmap(core_f, in_axes=in_axes)
+    blockwise_f = mx.vmap(core_f, in_axes=tuple(in_axes))
 
-    # 7) Return the mapped function
     def blockwise_fun(*inputs):
         return blockwise_f(*inputs)

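Note on the change above: the special-cased Conv1d path is gone, and Blockwise ops now go through mx.vmap uniformly. in_axes is computed per input: an input is mapped over axis 0 only if it has at least one non-broadcastable batch dimension, and inputs whose batch dimensions are all broadcastable (or absent) stay static with in_axes=None. A short illustration of that mapping strategy with a hypothetical core function (core_matvec is not part of this commit):

import mlx.core as mx


def core_matvec(a, x):
    # Core computation with gufunc-style signature (m,n),(n)->(m).
    return a @ x


# The batch of matrices is mapped over axis 0; the shared vector stays static.
batched = mx.vmap(core_matvec, in_axes=(0, None))

a = mx.ones((4, 3, 2))  # batch of four (3, 2) matrices
x = mx.ones((2,))       # one unbatched vector
print(batched(a, x).shape)  # (4, 3)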
pytensor/link/mlx/dispatch/core.py

Lines changed: 71 additions & 27 deletions
@@ -32,6 +32,17 @@
 from pytensor.tensor.exceptions import NotScalarConstantError
 
 
+MLX_DYNAMIC_SHAPE_ERROR = (
+    "MLX compilation limitation: Alloc operations with dynamic shapes "
+    "cannot be used inside compiled functions. This is because MLX "
+    "compilation forbids evaluating arrays to extract shape values. "
+    "\n\nWorkarounds:"
+    "\n1. Avoid using Alloc with dynamic shapes in compiled contexts"
+    "\n2. Use static shapes when possible"
+    "\n3. Move Alloc operations outside compiled functions"
+)
+
+
 @mlx_funcify.register(Join)
 def mlx_funcify_Join(op, **kwargs):
     def join(axis, *tensors):
@@ -247,33 +258,66 @@ def allocempty(*shape):
 
 @mlx_funcify.register(Alloc)
 def mlx_funcify_Alloc(op, node, **kwargs):
+    node_inputs = getattr(node, "inputs", None)
+    static_dims = (
+        _extract_static_dims(node_inputs[1:])
+        if node_inputs and len(node_inputs) > 1
+        else None
+    )
+
     def alloc(x, *shape):
-        try:
-            # Convert shape elements to Python ints for MLX compatibility
-            # MLX requires shape dimensions to be Python integers, not MLX arrays
-            shape_ints = tuple(
-                int(s.item()) if hasattr(s, "item") else int(s) for s in shape
-            )
-            return mx.broadcast_to(x, shape_ints)
-        except ValueError as e:
-            if (
-                "[eval] Attempting to eval an array during function transformations"
-                in str(e)
-            ):
-                # This is the MLX compilation limitation - provide helpful error
-                raise ValueError(
-                    "MLX compilation limitation: Alloc operations with dynamic shapes "
-                    "cannot be used inside compiled functions. This is because MLX "
-                    "compilation forbids evaluating arrays to extract shape values. "
-                    # Just a note! TODO: remove this once we have a better solution
-                    "\n\nWorkarounds:"
-                    "\n1. Avoid using Alloc with dynamic shapes in compiled contexts"
-                    "\n2. Use static shapes when possible"
-                    "\n3. Move Alloc operations outside compiled functions"
-                    "\n\nOriginal error: " + str(e)
-                ) from e
-            else:
-                # Re-raise other ValueError exceptions
-                raise
+        resolved_shape = (
+            _resolve_shape(static_dims, shape)
+            if static_dims is not None
+            else tuple(_coerce_to_int(dim) for dim in shape)
+        )
+        result = mx.broadcast_to(x, resolved_shape)
+        if node_inputs is not None:
+            value_for_check = x if hasattr(x, "shape") else np.asarray(x)
+            Alloc._check_runtime_broadcast(node, value_for_check, resolved_shape)
+        return result
 
     return alloc
+
+
+def _extract_static_dims(shape_inputs):
+    static_dims = []
+    for dim in shape_inputs:
+        try:
+            static_dims.append(int(get_scalar_constant_value(dim)))
+        except NotScalarConstantError:
+            static_dims.append(None)
+    return tuple(static_dims)
+
+
+def _resolve_shape(static_dims, runtime_shape):
+    if len(static_dims) != len(runtime_shape):
+        raise ValueError("Alloc received unexpected number of shape dimensions")
+
+    resolved = []
+    for const_dim, dim in zip(static_dims, runtime_shape, strict=True):
+        resolved.append(const_dim if const_dim is not None else _coerce_to_int(dim))
+
+    return tuple(resolved)
+
+
+def _coerce_to_int(value):
+    if isinstance(value, np.integer | int):
+        return int(value)
+    try:
+        if hasattr(value, "item"):
+            return int(value.item())
+        return int(value)
+    except (ValueError, TypeError) as exc:
+        _rethrow_dynamic_shape_error(exc)
+        raise
+    raise TypeError(
+        "MLX Alloc expects integer shape components; got value of type "
+        f"{type(value).__name__}."
    )
+
+
+def _rethrow_dynamic_shape_error(exc):
+    msg = str(exc)
+    if "[eval] Attempting to eval an array during function transformations" in msg:
+        raise ValueError(f"{MLX_DYNAMIC_SHAPE_ERROR}\n\nOriginal error: {msg}") from exc

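Note on the change above: instead of catching MLX's eval error after the fact, the new code resolves constant shape dimensions once at dispatch time with get_scalar_constant_value, so mx.broadcast_to usually receives Python ints without evaluating traced arrays; only genuinely dynamic dimensions still require an eval, which mx.compile forbids. A toy illustration of that underlying MLX constraint (the function names and shapes here are hypothetical):

import mlx.core as mx


def alloc_static(x):
    # The target shape is a compile-time constant, so this traces cleanly.
    return mx.broadcast_to(x, (3, 4))


compiled = mx.compile(alloc_static)
print(compiled(mx.array(1.0)).shape)  # (3, 4)

# By contrast, deriving a shape from a traced array inside a compiled function,
# e.g. int(n.item()), triggers "[eval] Attempting to eval an array during
# function transformations", the exact failure MLX_DYNAMIC_SHAPE_ERROR documents.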
pytensor/link/mlx/dispatch/math.py

Lines changed: 61 additions & 1 deletion
@@ -38,7 +38,7 @@
 )
 from pytensor.scalar.math import Erfc, Erfcx, Sigmoid, Softplus
 from pytensor.tensor.elemwise import Elemwise
-from pytensor.tensor.math import Dot
+from pytensor.tensor.math import Argmax, Dot, Max
 
 
 @mlx_funcify.register(Dot)
@@ -49,6 +49,66 @@ def dot(x, y):
     return dot
 
 
+@mlx_funcify.register(Max)
+def mlx_funcify_Max(op, node=None, **kwargs):
+    def max_fn(x):
+        axes = op.axis
+        if axes is None:
+            reduce_axes = None
+        else:
+            reduce_axes = tuple(int(ax) for ax in axes)
+
+        keepdims = getattr(op, "keepdims", False)
+
+        return mx.max(x, axis=reduce_axes, keepdims=keepdims)
+
+    return max_fn
+
+
+@mlx_funcify.register(Argmax)
+def mlx_funcify_Argmax(op, node=None, **kwargs):
+    axis = op.axis
+
+    def argmax_fn(x):
+        if axis is None:
+            axes = tuple(range(x.ndim))
+        else:
+            axes = tuple(int(ax) for ax in axis)
+
+        keep_axes = [i for i in range(x.ndim) if i not in axes]
+        transposed_x = mx.transpose(x, tuple(keep_axes + list(axes)))
+
+        kept_shape = transposed_x.shape[: len(keep_axes)]
+        reduced_shape = transposed_x.shape[len(keep_axes) :]
+
+        flat_size = 1
+        for dim in reduced_shape:
+            flat_size *= int(dim)
+        reshaped_x = transposed_x.reshape((*kept_shape, flat_size))
+
+        max_idx = mx.argmax(reshaped_x, axis=-1)
+
+        result = max_idx.astype(mx.int64)
+
+        if getattr(op, "keepdims", False):
+            reshape_shape = []
+            keep_iter = iter(kept_shape)
+            axis_iter = iter(sorted(axes))
+            next_axis = next(axis_iter, None)
+            for dim_idx in range(x.ndim):
+                if next_axis is not None and dim_idx == next_axis:
+                    reshape_shape.append(1)
+                    next_axis = next(axis_iter, None)
+                else:
+                    reshape_shape.append(int(next(keep_iter)))
+
+            return result.reshape(tuple(reshape_shape))
+
+        return result
+
+    return argmax_fn
+
+
 # Second-level dispatch for scalar operations in Elemwise
 @singledispatch
 def mlx_funcify_Elemwise_scalar_op(scalar_op):

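Note on the change above: mx.argmax reduces over a single axis, so mlx_funcify_Argmax handles multi-axis reduction by moving the reduced axes to the back, flattening them into one trailing axis, and taking argmax over it; the result is a flat index into the reduced block, as with a flattened NumPy argmax. A condensed sketch of the same trick (the helper name and inputs are illustrative, assuming non-negative axes):

import mlx.core as mx


def argmax_over_axes(x, axes):
    # Keep non-reduced axes in front, move reduced axes to the back.
    keep = [i for i in range(x.ndim) if i not in axes]
    moved = mx.transpose(x, tuple(keep + list(axes)))

    # Flatten the reduced axes into a single trailing axis.
    flat_size = 1
    for ax in axes:
        flat_size *= x.shape[ax]
    reshaped = moved.reshape((*moved.shape[: len(keep)], flat_size))

    # Argmax over the flattened block yields flat indices.
    return mx.argmax(reshaped, axis=-1)


x = mx.arange(24).reshape(2, 3, 4)
print(argmax_over_axes(x, (1, 2)))  # flat index 11 within each (3, 4) block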