Skip to content
Open
20 changes: 19 additions & 1 deletion pytensor/link/mlx/dispatch/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,25 @@ def mlx_fn(x, indices, y):
return x

def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list):
indices = indices_from_subtensor(ilist, idx_list)
def get_slice_int(element):
if element is None:
return None
try:
return int(element)
except Exception:
Copy link

Copilot AI Oct 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a bare except Exception is too broad. This should catch specific exceptions like TypeError or ValueError that would occur when trying to convert a non-integer value. The current implementation could mask unexpected errors.

Suggested change
except Exception:
except (TypeError, ValueError):

Copilot uses AI. Check for mistakes.
return element

indices = tuple(
[
slice(
get_slice_int(s.start), get_slice_int(s.stop), get_slice_int(s.step)
)
if isinstance(s, slice)
else s
for s in indices_from_subtensor(ilist, idx_list)
]
)

if len(indices) == 1:
indices = indices[0]

Expand Down
13 changes: 13 additions & 0 deletions tests/link/mlx/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,19 @@ def test_mlx_IncSubtensor_increment():
assert not out_pt.owner.op.set_instead_of_inc
compare_mlx_and_py([], [out_pt], [])

# Increment slice
out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, 2:], st_pt)
compare_mlx_and_py([], [out_pt], [])

out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, -3:], st_pt)
compare_mlx_and_py([], [out_pt], [])

out_pt = pt_subtensor.inc_subtensor(x_pt[::2, ::2, ::2], st_pt)
compare_mlx_and_py([], [out_pt], [])

out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, :], st_pt)
compare_mlx_and_py([], [out_pt], [])


def test_mlx_AdvancedIncSubtensor_set():
"""Test advanced set operations using AdvancedIncSubtensor."""
Expand Down
Loading