Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 60 additions & 50 deletions pytensor/link/numba/dispatch/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,6 @@
if isinstance(idx.type, TensorType)
]

def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
# Check that x is not broadcasted to y based on broadcastable info
if len(x_bcast) < len(to_bcast):
return True
for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True):
if x_bcast_dim and not to_bcast_dim:
return True
return False

# Special implementation for consecutive integer vector indices
if (
not basic_idxs
Expand All @@ -151,17 +142,6 @@
)
# Must be consecutive
and not op.non_consecutive_adv_indexing(node)
# y in set/inc_subtensor cannot be broadcasted
and (
y is None
or not broadcasted_to(
y.type.broadcastable,
(
x.type.broadcastable[: adv_idxs[0]["axis"]]
+ x.type.broadcastable[adv_idxs[-1]["axis"] :]
),
)
)
):
return numba_funcify_multiple_integer_vector_indexing(op, node, **kwargs)

Expand Down Expand Up @@ -191,14 +171,24 @@
return numba_funcify_default_subtensor(op, node, **kwargs)


def _broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
# Check that x is not broadcasted to y based on broadcastable info
if len(x_bcast) < len(to_bcast):
return True
for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True):
if x_bcast_dim and not to_bcast_dim:
return True

Check warning on line 180 in pytensor/link/numba/dispatch/subtensor.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/subtensor.py#L180

Added line #L180 was not covered by tests
return False


def numba_funcify_multiple_integer_vector_indexing(
op: AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs
):
# Special-case implementation for multiple consecutive vector integer indices (and set/incsubtensor)
if isinstance(op, AdvancedSubtensor):
y, idxs = None, node.inputs[1:]
idxs = node.inputs[1:]
else:
y, *idxs = node.inputs[1:]
idxs = node.inputs[2:]

first_axis = next(
i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
Expand All @@ -211,6 +201,10 @@
)
except StopIteration:
after_last_axis = len(idxs)
last_axis = after_last_axis - 1

vector_indices = idxs[first_axis:after_last_axis]
assert all(v.type.broadcastable == (False,) for v in vector_indices)

if isinstance(op, AdvancedSubtensor):

Expand All @@ -231,43 +225,59 @@

return advanced_subtensor_multiple_vector

elif op.set_instead_of_inc:
else:
inplace = op.inplace

@numba_njit
def advanced_set_subtensor_multiple_vector(x, y, *idxs):
vec_idxs = idxs[first_axis:after_last_axis]
x_shape = x.shape
# Check if y must be broadcasted
# Includes the last integer vector index,
x, y = node.inputs[:2]
indexed_bcast_dims = (
*x.type.broadcastable[:first_axis],
*x.type.broadcastable[last_axis:],
)
y_is_broadcasted = _broadcasted_to(y.type.broadcastable, indexed_bcast_dims)

if inplace:
out = x
else:
out = x.copy()
if op.set_instead_of_inc:

for outer in np.ndindex(x_shape[:first_axis]):
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
out[(*outer, *scalar_idxs)] = y[(*outer, i)]
return out
@numba_njit
def advanced_set_subtensor_multiple_vector(x, y, *idxs):
vec_idxs = idxs[first_axis:after_last_axis]
x_shape = x.shape

return advanced_set_subtensor_multiple_vector
if inplace:
out = x
else:
out = x.copy()

else:
inplace = op.inplace
if y_is_broadcasted:
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])

@numba_njit
def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
vec_idxs = idxs[first_axis:after_last_axis]
x_shape = x.shape
for outer in np.ndindex(x_shape[:first_axis]):
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
out[(*outer, *scalar_idxs)] = y[(*outer, i)]
return out

return advanced_set_subtensor_multiple_vector

else:

@numba_njit
def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
vec_idxs = idxs[first_axis:after_last_axis]
x_shape = x.shape

if inplace:
out = x
else:
out = x.copy()

if inplace:
out = x
else:
out = x.copy()
if y_is_broadcasted:
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])

for outer in np.ndindex(x_shape[:first_axis]):
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
out[(*outer, *scalar_idxs)] += y[(*outer, i)]
return out
for outer in np.ndindex(x_shape[:first_axis]):
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
out[(*outer, *scalar_idxs)] += y[(*outer, i)]
return out

return advanced_inc_subtensor_multiple_vector

Expand Down
4 changes: 2 additions & 2 deletions tests/link/numba/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,8 @@ def test_AdvancedIncSubtensor1(x, y, indices):
np.array(-99), # Broadcasted value
([1, 2], [2, 3]), # 2 vector indices
False,
True,
True,
False,
False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we concoct a new test case that still requires object on both inc and set? Or is the coverage now that good?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are 3 such tests already:

(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
rng.poisson(size=(2, 4)),
([1, 2], slice(None), [3, 4]), # Non-consecutive vector indices
False,
True,
True,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
rng.poisson(size=(2, 2)),
(
slice(1, None),
[1, 2],
[3, 4],
), # Mixed double vector index and basic index
False,
True,
True,
),

np.arange(3 * 5).reshape((3, 5)),
rng.poisson(size=(1, 2, 2)), # Same as before, but Y broadcasts
(slice(1, 3), [[1, 2], [2, 3]]),
False,
True,
True,
),

Either when array indexes are non-consecutive, or mixed with basic indices

),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
Expand Down