
Commit 5ba887b

Fix rebase
1 parent 26bc88a commit 5ba887b

File tree: 7 files changed, +153 −69 lines changed


pytensor/link/numba/dispatch/subtensor.py

Lines changed: 1 addition & 31 deletions
@@ -244,37 +244,7 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
     else:
         tensor_inputs = node.inputs[2:]
 
-    adv_idxs = [
-        {
-            "axis": i,
-            "dtype": idx.type.dtype,
-            "bcast": idx.type.broadcastable,
-            "ndim": idx.type.ndim,
-        }
-        for i, idx in enumerate(idxs)
-        if isinstance(idx.type, TensorType)
-    ]
-
     # Reconstruct indexing information from idx_list and tensor inputs
-    # basic_idxs = []
-    # adv_idxs = []
-    # input_idx = 0
-    #
-    # for i, entry in enumerate(op.idx_list):
-    #     if isinstance(entry, slice):
-    #         # Basic slice index
-    #         basic_idxs.append(entry)
-    #     elif isinstance(entry, Type):
-    #         # Advanced tensor index
-    #         if input_idx < len(tensor_inputs):
-    #             idx_input = tensor_inputs[input_idx]
-    #             adv_idxs.append({
-    #                 "axis": i,
-    #                 "dtype": idx_input.type.dtype,
-    #                 "bcast": idx_input.type.broadcastable,
-    #                 "ndim": idx_input.type.ndim,
-    #             })
-    #             input_idx += 1
     basic_idxs = []
     adv_idxs = []
     input_idx = 0
@@ -313,7 +283,7 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
         and len(adv_idxs) >= 1
         and all(adv_idx["dtype"] != "bool" for adv_idx in adv_idxs)
         # Implementation does not support newaxis
-        and not any(isinstance(idx.type, NoneTypeT) for idx in idxs)
+        and not any(isinstance(idx.type, NoneTypeT) for idx in tensor_inputs)
     ):
         return vector_integer_advanced_indexing(op, node, **kwargs)

pytensor/tensor/rewriting/subtensor.py

Lines changed: 57 additions & 21 deletions
@@ -73,7 +73,6 @@
     IncSubtensor,
     Subtensor,
     advanced_inc_subtensor1,
-    advanced_subtensor,
     advanced_subtensor1,
     as_index_constant,
     get_canonical_form_slice,
@@ -83,7 +82,7 @@
     inc_subtensor,
     indices_from_subtensor,
 )
-from pytensor.tensor.type import TensorType
+from pytensor.tensor.type import TensorType, integer_dtypes
 from pytensor.tensor.type_other import NoneTypeT, SliceType
 from pytensor.tensor.variable import TensorConstant, TensorVariable
 
@@ -265,6 +264,7 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
     """
 
     if type(node.op) is not AdvancedIncSubtensor:
+        # Don't apply to subclasses
         return
 
     if node.op.ignore_duplicates:
@@ -1321,7 +1321,9 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
     if isinstance(node.op, IncSubtensor):
         xi = Subtensor(node.op.idx_list)(x, *i)
     elif isinstance(node.op, AdvancedIncSubtensor):
-        xi = advanced_subtensor(x, *i)
+        # Use the same idx_list as the original operation to ensure correct shape
+        op = AdvancedSubtensor(node.op.idx_list)
+        xi = op.make_node(x, *i).outputs[0]
     elif isinstance(node.op, AdvancedIncSubtensor1):
         xi = advanced_subtensor1(x, *i)
     else:
@@ -1771,10 +1773,11 @@ def local_blockwise_inc_subtensor(fgraph, node):
 
 
 @node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
-def bool_idx_to_nonzero(fgraph, node):
-    """Convert boolean indexing into equivalent vector boolean index, supported by our dispatch
+def ravel_multidimensional_bool_idx(fgraph, node):
+    """Convert multidimensional boolean indexing into equivalent vector boolean index, supported by Numba
 
-    x[1:, eye(3, dtype=bool), 1:] -> x[1:, *eye(3).nonzero()]
+    x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()]
+    x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape)
     """
 
     if isinstance(node.op, AdvancedSubtensor):
@@ -1787,26 +1790,53 @@ def bool_idx_to_nonzero(fgraph, node):
     # Reconstruct indices from idx_list and tensor inputs
     idxs = indices_from_subtensor(tensor_inputs, node.op.idx_list)
 
-    bool_pos = {
-        i
+    if any(
+        (
+            (isinstance(idx.type, TensorType) and idx.type.dtype in integer_dtypes)
+            or isinstance(idx.type, NoneTypeT)
+        )
+        for idx in idxs
+    ):
+        # Get out if there are any other advanced indexes or np.newaxis
+        return None
+
+    bool_idxs = [
+        (i, idx)
         for i, idx in enumerate(idxs)
         if (isinstance(idx.type, TensorType) and idx.dtype == "bool")
-    }
+    ]
 
-    if not bool_pos:
+    if len(bool_idxs) != 1:
+        # Get out if there are no or multiple boolean idxs
+        return None
+    [(bool_idx_pos, bool_idx)] = bool_idxs
+    bool_idx_ndim = bool_idx.type.ndim
+    if bool_idx.type.ndim < 2:
+        # No need to do anything if it's a vector or scalar, as it's already supported by Numba
         return None
 
-    new_idxs = []
-    for i, idx in enumerate(idxs):
-        if i in bool_pos:
-            new_idxs.extend(idx.nonzero())
-        else:
-            new_idxs.append(idx)
+    x_shape = x.shape
+    raveled_x = x.reshape(
+        (*x_shape[:bool_idx_pos], -1, *x_shape[bool_idx_pos + bool_idx_ndim :])
+    )
+
+    raveled_bool_idx = bool_idx.ravel()
+    new_idxs = list(idxs)
+    new_idxs[bool_idx_pos] = raveled_bool_idx
 
     if isinstance(node.op, AdvancedSubtensor):
-        new_out = node.op(x, *new_idxs)
+        new_out = raveled_x[tuple(new_idxs)]
     else:
-        new_out = node.op(x, y, *new_idxs)
+        sub = raveled_x[tuple(new_idxs)]
+        new_out = inc_subtensor(
+            sub,
+            y,
+            set_instead_of_inc=node.op.set_instead_of_inc,
+            ignore_duplicates=node.op.ignore_duplicates,
+            inplace=node.op.inplace,
+        )
+        new_out = new_out.reshape(x_shape)
+
     return [copy_stack_trace(node.outputs[0], new_out)]
 
 
@@ -1941,10 +1971,16 @@ def ravel_multidimensional_int_idx(fgraph, node):
 
 
 optdb["specialize"].register(
-    bool_idx_to_nonzero.__name__,
-    bool_idx_to_nonzero,
+    ravel_multidimensional_bool_idx.__name__,
+    ravel_multidimensional_bool_idx,
+    "numba",
+    use_db_name_as_tag=False,  # Not included if only "specialize" is requested
+)
+
+optdb["specialize"].register(
+    ravel_multidimensional_int_idx.__name__,
+    ravel_multidimensional_int_idx,
     "numba",
-    "shape_unsafe",  # It can mask invalid mask sizes
    use_db_name_as_tag=False,  # Not included if only "specialize" is requested
 )
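The equivalences claimed in the new docstring are easy to sanity-check in plain NumPy. A minimal sketch, independent of PyTensor (the arrays and names here are illustrative only):

import numpy as np

x = np.arange(9.0).reshape(3, 3)
mask = np.eye(3, dtype=bool)

# Subtensor case: a 2D boolean mask is equivalent to the raveled 1D mask
# applied to the raveled array.
np.testing.assert_array_equal(x[mask], x.ravel()[mask.ravel()])

# Set case: setting through the mask equals setting on a raveled copy
# and reshaping back to x's original shape.
y = np.array([10.0, 20.0, 30.0])
expected = x.copy()
expected[mask] = y

flat = x.ravel().copy()
flat[mask.ravel()] = y
np.testing.assert_array_equal(expected, flat.reshape(x.shape))

This also suggests why the rewrite only fires for a single boolean index of ndim >= 2 and bails out when integer indices or np.newaxis are present: those would change how the remaining axes line up around the reshaped dimension.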

pytensor/tensor/subtensor.py

Lines changed: 3 additions & 9 deletions
@@ -922,11 +922,12 @@ def __init__(self, idx_list=None):
 
     def _normalize_idx_list_for_hash(self):
         """Normalize idx_list for hash and equality comparison."""
-        if self.idx_list is None:
+        idx_list = getattr(self, "idx_list", None)
+        if idx_list is None:
             return None
 
         msg = []
-        for entry in self.idx_list:
+        for entry in idx_list:
             if isinstance(entry, slice):
                 msg.append((entry.start, entry.stop, entry.step))
             else:
@@ -2812,13 +2813,6 @@ def make_node(self, x, *inputs):
         advanced_indices = []
         adv_group_axis = None
         last_adv_group_axis = None
-        if new_axes:  # not defined?
-            expanded_x_shape_list = list(x.type.shape)
-            for new_axis in new_axes:
-                expanded_x_shape_list.insert(new_axis, 1)
-            expanded_x_shape = tuple(expanded_x_shape_list)
-        else:
-            expanded_x_shape = x.type.shape
         for i, (idx, dim_length) in enumerate(
             zip_longest(explicit_indices, x.type.shape, fillvalue=slice(None))
         ):
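A standalone sketch of the normalization this hunk hardens (names are mine; my reading that the getattr guard protects instances whose idx_list has not been assigned yet, e.g. mid-unpickling, is an assumption). Slices are flattened to (start, stop, step) tuples so the result is uniformly hashable:

def normalize_idx_list_for_hash(obj):
    # Tolerate objects where idx_list was never assigned
    idx_list = getattr(obj, "idx_list", None)
    if idx_list is None:
        return None
    msg = []
    for entry in idx_list:
        if isinstance(entry, slice):
            # Flatten the slice into a hashable (start, stop, step) tuple
            msg.append((entry.start, entry.stop, entry.step))
        else:
            # Non-slice entries pass through unchanged (assumed; the else
            # branch is truncated in the hunk above)
            msg.append(entry)
    return tuple(msg)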

tests/tensor/rewriting/test_basic.py

Lines changed: 57 additions & 0 deletions
@@ -468,6 +468,63 @@ def test_incsubtensor(self):
         assert check_stack_trace(f1, ops_to_check="last")
         assert check_stack_trace(f2, ops_to_check="last")
 
+    def test_advanced_inc_subtensor_shape_inference_bug(self):
+        """
+        Test for bug in local_useless_inc_subtensor_alloc where advanced_subtensor
+        was called instead of using the original op's idx_list, causing incorrect
+        shape inference and AssertionError.
+
+        The bug occurred when advanced_subtensor(x, *i) tried to reconstruct
+        idx_list from inputs, leading to wrong shape for xi. This caused the
+        Assert condition checking shape compatibility to fail at runtime with:
+        AssertionError: `x[i]` and `y` do not have the same shape.
+
+        This test reproduces the bug by using a scenario where the shape
+        comparison would fail if xi has the wrong shape due to incorrect
+        idx_list reconstruction.
+        """
+        # Use vector with matrix indices - this creates AdvancedIncSubtensor.
+        # The key is that when advanced_subtensor tries to reconstruct idx_list,
+        # it may get it wrong, causing xi to have incorrect shape.
+        x = vector("x")
+        y = scalar("y")
+        i = matrix(
+            "i", dtype="int64"
+        )  # 2D indices for 1D array -> AdvancedIncSubtensor
+
+        # Create AdvancedIncSubtensor with Alloc.
+        # When i is (n, m), i.shape is (n, m), so alloc creates shape (n, m),
+        # and x[i] where i is (n, m) creates shape (n, m) as well.
+        # The bug would cause xi to have the wrong shape, making the Assert fail.
+        z = advanced_inc_subtensor(x, pt.alloc(y, *i.shape), i)
+
+        # Compile - this should not raise AssertionError during execution.
+        # With the buggy code (using advanced_subtensor), this raises:
+        # AssertionError: `x[i]` and `y` do not have the same shape.
+        f = function([x, i, y], z, mode=self.mode)
+
+        # Test with actual values
+        x_value = np.random.standard_normal(10).astype(config.floatX)
+        y_value = np.random.standard_normal()
+        i_value = self.rng.integers(0, 10, size=(3, 2))
+
+        # This should execute without AssertionError.
+        # With the buggy code (using advanced_subtensor), this would raise:
+        # AssertionError: `x[i]` and `y` do not have the same shape.
+        result = f(x_value, i_value, y_value)
+
+        # Verify basic properties.
+        # The main point of this test is that it doesn't raise AssertionError;
+        # advanced_inc_subtensor modifies x in place and returns it.
+        assert result.shape == x_value.shape, "Result should have same shape as input"
+        assert not np.array_equal(result, x_value), "Result should be modified"
+
+        # Verify the rewrite was applied (Alloc should be removed)
+        topo = f.maker.fgraph.toposort()
+        assert len([n for n in topo if isinstance(n.op, Alloc)]) == 0, (
+            "Alloc should have been removed by the rewrite"
+        )
+
 
 class TestUselessCheckAndRaise:
     def test_basic(self):

tests/tensor/rewriting/test_elemwise.py

Lines changed: 6 additions & 6 deletions
@@ -1642,9 +1642,9 @@ def test_InplaceElemwiseOptimizer_bug():
     # with config.change_flags(tensor__insert_inplace_optimizer_validate_nb=10):
     rewrite_graph(fgraph, include=("inplace",))
 
-    pytensor.config.tensor__insert_inplace_optimizer_validate_nb = 1
-    with pytest.warns(
-        FutureWarning,
-        match="tensor__insert_inplace_optimizer_validate_nb config is deprecated",
-    ):
-        rewrite_graph(fgraph, include=("inplace",))
+    with config.change_flags(tensor__insert_inplace_optimizer_validate_nb=1):
+        with pytest.warns(
+            FutureWarning,
+            match="tensor__insert_inplace_optimizer_validate_nb config is deprecated",
+        ):
+            rewrite_graph(fgraph, include=("inplace",))
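Using config.change_flags as a context manager restores the flag when the block exits, so the deprecated override no longer leaks into tests that run afterwards — which appears to be the point of this hunk.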

tests/tensor/test_basic.py

Lines changed: 1 addition & 1 deletion
@@ -705,7 +705,7 @@ def test_masked_array_not_implemented(
 
 
 def check_alloc_runtime_broadcast(mode):
-    """Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
+    """Check we emit a clear error when runtime broadcasting would occur according to Numpy rules."""
     floatX = config.floatX
     x_v = vector("x", shape=(None,))

tests/tensor/test_subtensor.py

Lines changed: 28 additions & 1 deletion
@@ -11,7 +11,7 @@
 import pytensor
 import pytensor.scalar as scal
 import pytensor.tensor.basic as ptb
-from pytensor import config, function, shared
+from pytensor import function, shared
 from pytensor.compile import DeepCopyOp
 from pytensor.compile.io import In
 from pytensor.compile.mode import Mode, get_default_mode
@@ -3301,6 +3301,33 @@ def test_slice_at_axis():
     assert x_sliced.type.shape == (3, 1, 5)
 
 
+def test_advanced_inc_subtensor1_failure():
+    # Shapes from the failure log
+    N = 500
+    TotalCols = 7
+    OrderedCols = 5
+    UnorderedCols = 2
+
+    oinds_val = [1, 2, 3, 5, 6]
+    uoinds_val = [0, 4]
+
+    y_ordered = matrix("y_ordered")
+    y_unordered = matrix("y_unordered")
+
+    fodds_init = ptb.empty((N, TotalCols))
+
+    fodds_step1 = set_subtensor(fodds_init[:, uoinds_val], y_unordered)
+    fodds_step2 = set_subtensor(fodds_step1[:, oinds_val], y_ordered)
+
+    f = pytensor.function([y_unordered, y_ordered], fodds_step2)
+    # assert any("AdvancedIncSubtensor1" in str(node) for node in f.maker.fgraph.toposort())
+
+    y_u_data = np.random.randn(N, UnorderedCols).astype(np.float64)
+    y_o_data = np.random.randn(N, OrderedCols).astype(np.float64)
+    res = f(y_u_data, y_o_data)
+    assert res.shape == (N, TotalCols)
+
+
 @pytest.mark.parametrize(
     "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
 )
