diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py
index 81348b57be..ee9e183d16 100644
--- a/pytensor/link/numba/dispatch/subtensor.py
+++ b/pytensor/link/numba/dispatch/subtensor.py
@@ -130,15 +130,6 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
         if isinstance(idx.type, TensorType)
     ]
 
-    def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
-        # Check that x is not broadcasted to y based on broadcastable info
-        if len(x_bcast) < len(to_bcast):
-            return True
-        for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True):
-            if x_bcast_dim and not to_bcast_dim:
-                return True
-        return False
-
     # Special implementation for consecutive integer vector indices
     if (
         not basic_idxs
@@ -151,17 +142,6 @@ def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
         )
         # Must be consecutive
         and not op.non_consecutive_adv_indexing(node)
-        # y in set/inc_subtensor cannot be broadcasted
-        and (
-            y is None
-            or not broadcasted_to(
-                y.type.broadcastable,
-                (
-                    x.type.broadcastable[: adv_idxs[0]["axis"]]
-                    + x.type.broadcastable[adv_idxs[-1]["axis"] :]
-                ),
-            )
-        )
     ):
         return numba_funcify_multiple_integer_vector_indexing(op, node, **kwargs)
 
@@ -191,14 +171,24 @@ def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
     return numba_funcify_default_subtensor(op, node, **kwargs)
 
 
+def _broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
+    # Check whether x would be broadcasted to `to_bcast`, based on broadcastable info
+    if len(x_bcast) < len(to_bcast):
+        return True
+    for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True):
+        if x_bcast_dim and not to_bcast_dim:
+            return True
+    return False
+
+
 def numba_funcify_multiple_integer_vector_indexing(
     op: AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs
 ):
     # Special-case implementation for multiple consecutive vector integer indices (and set/incsubtensor)
     if isinstance(op, AdvancedSubtensor):
-        y, idxs = None, node.inputs[1:]
+        idxs = node.inputs[1:]
     else:
-        y, *idxs = node.inputs[1:]
+        idxs = node.inputs[2:]
 
     first_axis = next(
         i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
@@ -211,6 +201,10 @@ def numba_funcify_multiple_integer_vector_indexing(
         )
     except StopIteration:
         after_last_axis = len(idxs)
+    last_axis = after_last_axis - 1
+
+    vector_indices = idxs[first_axis:after_last_axis]
+    assert all(v.type.broadcastable == (False,) for v in vector_indices)
 
     if isinstance(op, AdvancedSubtensor):
 
@@ -231,43 +225,59 @@ def advanced_subtensor_multiple_vector(x, *idxs):
 
         return advanced_subtensor_multiple_vector
 
-    elif op.set_instead_of_inc:
+    else:
         inplace = op.inplace
 
-        @numba_njit
-        def advanced_set_subtensor_multiple_vector(x, y, *idxs):
-            vec_idxs = idxs[first_axis:after_last_axis]
-            x_shape = x.shape
+        # Check if y must be broadcasted to the dims of x that survive indexing
+        # (the last integer vector index axis stands in for the advanced dim)
+        x, y = node.inputs[:2]
+        indexed_bcast_dims = (
+            *x.type.broadcastable[:first_axis],
+            *x.type.broadcastable[last_axis:],
+        )
+        y_is_broadcasted = _broadcasted_to(y.type.broadcastable, indexed_bcast_dims)
 
-            if inplace:
-                out = x
-            else:
-                out = x.copy()
+        if op.set_instead_of_inc:
 
-            for outer in np.ndindex(x_shape[:first_axis]):
-                for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
-                    out[(*outer, *scalar_idxs)] = y[(*outer, i)]
-            return out
+            @numba_njit
+            def advanced_set_subtensor_multiple_vector(x, y, *idxs):
+                vec_idxs = idxs[first_axis:after_last_axis]
+                x_shape = x.shape
 
-        return advanced_set_subtensor_multiple_vector
+                if inplace:
+                    out = x
+                else:
+                    out = x.copy()
 
-    else:
-        inplace = op.inplace
+                if y_is_broadcasted:
+                    y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
 
-        @numba_njit
-        def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
-            vec_idxs = idxs[first_axis:after_last_axis]
-            x_shape = x.shape
+                for outer in np.ndindex(x_shape[:first_axis]):
+                    for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+                        out[(*outer, *scalar_idxs)] = y[(*outer, i)]
+                return out
+
+            return advanced_set_subtensor_multiple_vector
+
+        else:
+
+            @numba_njit
+            def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
+                vec_idxs = idxs[first_axis:after_last_axis]
+                x_shape = x.shape
+
+                if inplace:
+                    out = x
+                else:
+                    out = x.copy()
 
-            if inplace:
-                out = x
-            else:
-                out = x.copy()
+                if y_is_broadcasted:
+                    y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
 
-            for outer in np.ndindex(x_shape[:first_axis]):
-                for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
-                    out[(*outer, *scalar_idxs)] += y[(*outer, i)]
-            return out
+                for outer in np.ndindex(x_shape[:first_axis]):
+                    for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+                        out[(*outer, *scalar_idxs)] += y[(*outer, i)]
+                return out
 
         return advanced_inc_subtensor_multiple_vector
 
diff --git a/tests/link/numba/test_subtensor.py b/tests/link/numba/test_subtensor.py
index 675afdc996..c9578657f2 100644
--- a/tests/link/numba/test_subtensor.py
+++ b/tests/link/numba/test_subtensor.py
@@ -392,8 +392,8 @@ def test_AdvancedIncSubtensor1(x, y, indices):
             np.array(-99),  # Broadcasted value
             ([1, 2], [2, 3]),  # 2 vector indices
             False,
-            True,
-            True,
+            False,
+            False,
         ),
         (
            np.arange(3 * 4 * 5).reshape((3, 4, 5)),
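
[Reviewer sketch, not part of the patch: the flipped test flags above cover
set/inc_subtensor with two integer vector indices and a broadcasted `y`,
which the specialized numba path now handles instead of guarding against up
front. Below is a minimal pure-NumPy model of the set case; `first_axis = 0`
and `last_axis = 1` are read off this example rather than taken from the
code, and the explicit `np.broadcast_to` mirrors what the new kernels do
before their assignment loop.]

    import numpy as np

    x = np.arange(3 * 4 * 5).reshape((3, 4, 5))
    rows, cols = np.asarray([1, 2]), np.asarray([2, 3])
    y = np.array(-99)  # 0-d value, broadcast over the indexed shape (2, 5)

    # Reference semantics: NumPy broadcasts y implicitly
    expected = x.copy()
    expected[rows, cols] = y

    # Kernel-style approach: materialize the broadcast to
    # x.shape[:first_axis] + x.shape[last_axis:] == (4, 5), then assign per index
    out = x.copy()
    y_b = np.broadcast_to(y, x.shape[:0] + x.shape[1:])
    for i, (r, c) in enumerate(zip(rows, cols)):
        out[r, c] = y_b[i]

    assert np.array_equal(out, expected)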