Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions pytensor/link/mlx/dispatch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ def mlx_funcify_Split(op: Split, node, **kwargs):
except NotScalarConstantError:
constant_axis = None

# Reject symbolic axes at dispatch time so that unsupported configurations
# do not produce Python exceptions from inside MLX-compiled regions.
# This mirrors the existing limitation but avoids the CI abort seen in
# https://github.com/pymc-devs/pytensor/issues/1724.
if constant_axis is None:
raise ValueError("Symbolic axis is not supported in MLX Split implementation.")

try:
constant_splits = np.array(
[
Expand All @@ -59,28 +66,26 @@ def mlx_funcify_Split(op: Split, node, **kwargs):

def split(x, axis, splits):
# Resolve constants for significant performance improvement (14x speedup)
if constant_axis is not None:
axis = int(constant_axis)
else:
raise ValueError(
"Symbolic axis is not supported in MLX Split implementation."
)
axis = int(constant_axis)

if constant_splits is not None:
splits_arr = mx.array(constant_splits)
splits_for_validation = constant_splits
else:
splits_arr = mx.array(splits)
splits_for_validation = splits

cumsum_splits = mx.cumsum(splits_arr[:-1]).tolist()

# Validation checks
if len(splits) != op.len_splits:
splits_np = np.asarray(splits_for_validation)
if len(splits_np) != op.len_splits:
raise ValueError("Length of 'splits' is not equal to n_splits")
if np.sum(np.asarray(splits)) != x.shape[axis]:
if np.sum(splits_np) != x.shape[axis]:
raise ValueError(
"Split sizes do not sum to the input length on the chosen axis."
)
if np.any(np.asarray(splits) < 0):
if np.any(splits_np < 0):
raise ValueError("Split sizes cannot be negative.")

return mx.split(x, cumsum_splits, axis=axis)
Expand Down
9 changes: 9 additions & 0 deletions tests/link/mlx/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,12 @@ def test_split_dynamic_axis_const_splits():
ValueError, match="Symbolic axis is not supported in MLX Split implementation"
):
compare_mlx_and_py([x, axis], outs, [test_input, np.array(1)])


def test_split_invalid_splits_len_mlx():
x = pt.vector("x")
# len_splits=3, but only provide 2 split sizes
splits = [2, 2]
# The core Split op itself raises a ValueError at graph construction time
with pytest.raises(ValueError, match="Number of splits is larger than splits size"):
pt.split(x, splits, 3, axis=0)
Loading