Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions pytensor/link/numba/dispatch/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
if isinstance(idx.type, TensorType)
]

# Special case for consecutive consecutive vector indices
def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
# Check that x is not broadcasted to y based on broadcastable info
if len(x_bcast) < len(to_bcast):
Expand Down Expand Up @@ -176,7 +175,14 @@ def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
or (
isinstance(op, AdvancedIncSubtensor)
and not op.set_instead_of_inc
and not op.ignore_duplicates
and not (
op.ignore_duplicates
# Only vector integer indices can have "duplicates", not scalars or boolean vectors
or all(
adv_idx["ndim"] == 0 or adv_idx["dtype"] == "bool"
for adv_idx in adv_idxs
)
)
)
):
return generate_fallback_impl(op, node, **kwargs)
Expand Down
8 changes: 8 additions & 0 deletions tests/link/numba/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,16 @@ def test_AdvancedIncSubtensor1(x, y, indices):
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
-np.arange(1 * 4 * 5).reshape(1, 4, 5),
(np.array([True, False, False])), # Broadcasted boolean index
False, # It shouldn't matter what we set this to, boolean indices cannot be duplicate
False,
False,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
-np.arange(1 * 4 * 5).reshape(1, 4, 5),
(np.array([True, False, False])), # Broadcasted boolean index
True, # It shouldn't matter what we set this to, boolean indices cannot be duplicate
False,
False,
),
(
Expand Down