From a0c3b5a47761a16c22b4623aeb18b2e42a4a9a46 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Sun, 16 Nov 2025 19:14:31 +0200 Subject: [PATCH] Solving issue 1724 --- pytensor/link/mlx/dispatch/core.py | 23 ++++++++++++++--------- tests/link/mlx/test_core.py | 9 +++++++++ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/pytensor/link/mlx/dispatch/core.py b/pytensor/link/mlx/dispatch/core.py index be3ed37e3a..1452cdf751 100644 --- a/pytensor/link/mlx/dispatch/core.py +++ b/pytensor/link/mlx/dispatch/core.py @@ -47,6 +47,13 @@ def mlx_funcify_Split(op: Split, node, **kwargs): except NotScalarConstantError: constant_axis = None + # Reject symbolic axes at dispatch time so that unsupported configurations + # do not produce Python exceptions from inside MLX-compiled regions. + # This mirrors the existing limitation but avoids the CI abort seen in + # https://github.com/pymc-devs/pytensor/issues/1724. + if constant_axis is None: + raise ValueError("Symbolic axis is not supported in MLX Split implementation.") + try: constant_splits = np.array( [ @@ -59,28 +66,26 @@ def mlx_funcify_Split(op: Split, node, **kwargs): def split(x, axis, splits): # Resolve constants for significant performance improvement (14x speedup) - if constant_axis is not None: - axis = int(constant_axis) - else: - raise ValueError( - "Symbolic axis is not supported in MLX Split implementation." - ) + axis = int(constant_axis) if constant_splits is not None: splits_arr = mx.array(constant_splits) + splits_for_validation = constant_splits else: splits_arr = mx.array(splits) + splits_for_validation = splits cumsum_splits = mx.cumsum(splits_arr[:-1]).tolist() # Validation checks - if len(splits) != op.len_splits: + splits_np = np.asarray(splits_for_validation) + if len(splits_np) != op.len_splits: raise ValueError("Length of 'splits' is not equal to n_splits") - if np.sum(np.asarray(splits)) != x.shape[axis]: + if np.sum(splits_np) != x.shape[axis]: raise ValueError( "Split sizes do not sum to the input length on the chosen axis." ) - if np.any(np.asarray(splits) < 0): + if np.any(splits_np < 0): raise ValueError("Split sizes cannot be negative.") return mx.split(x, cumsum_splits, axis=axis) diff --git a/tests/link/mlx/test_core.py b/tests/link/mlx/test_core.py index d50d3a9959..7875bfc12c 100644 --- a/tests/link/mlx/test_core.py +++ b/tests/link/mlx/test_core.py @@ -164,3 +164,12 @@ def test_split_dynamic_axis_const_splits(): ValueError, match="Symbolic axis is not supported in MLX Split implementation" ): compare_mlx_and_py([x, axis], outs, [test_input, np.array(1)]) + + +def test_split_invalid_splits_len_mlx(): + x = pt.vector("x") + # len_splits=3, but only provide 2 split sizes + splits = [2, 2] + # The core Split op itself raises a ValueError at graph construction time + with pytest.raises(ValueError, match="Number of splits is larger than splits size"): + pt.split(x, splits, 3, axis=0)