Skip to content

Commit bacdaf6

Browse files
committed
Fix AdvancedIncSubtensor1 C-compilation with empty indices
1 parent 12213d0 commit bacdaf6

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

pytensor/tensor/subtensor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2561,7 +2561,11 @@ def c_code(self, node, name, input_names, output_names, sub):
25612561
and y_.type.dtype not in complex_dtypes
25622562
):
25632563
# Simple implementation for vector x, y cases
2564-
idx_may_be_neg = not (isinstance(idx_, Constant) and idx_.data.min() >= 0)
2564+
idx_may_be_neg = not (
2565+
# Empty idx needs no negative checks
2566+
idx_.type.shape[0] == 0
2567+
or (isinstance(idx_, Constant) and idx_.data.min() >= 0)
2568+
)
25652569
idx_may_be_invalid = AdvancedSubtensor1._idx_may_be_invalid(x_, idx_)
25662570
shape0 = x_.type.shape[0]
25672571
# This is used to make sure that when we trust the indices to be valid
@@ -2680,7 +2684,7 @@ def c_code(self, node, name, input_names, output_names, sub):
26802684
"""
26812685

26822686
def c_code_cache_version(self):
2683-
return (9,)
2687+
return (10,)
26842688

26852689
def _check_runtime_broadcasting(
26862690
self, node: Apply, x: np.ndarray, y: np.ndarray, idx: np.ndarray

tests/tensor/test_subtensor.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from pytensor.graph.rewriting.utils import is_same_graph
2323
from pytensor.printing import pprint
2424
from pytensor.scalar.basic import as_scalar, int16
25-
from pytensor.tensor import as_tensor, get_vector_length, vectorize
25+
from pytensor.tensor import as_tensor, constant, get_vector_length, vectorize
2626
from pytensor.tensor.blockwise import Blockwise
2727
from pytensor.tensor.elemwise import DimShuffle
2828
from pytensor.tensor.math import exp, isinf, lt, switch
@@ -1730,7 +1730,7 @@ def test_grad_broadcastable_specialization(self):
17301730
)
17311731

17321732

1733-
class TestIncSubtensor1:
1733+
class TestAdvancedIncSubtensor1:
17341734
def setup_method(self):
17351735
self.rng = np.random.default_rng(seed=utt.fetch_seed())
17361736

@@ -1817,6 +1817,16 @@ def test_inc_bcastableidx(self):
18171817
out1val, out2val = f(mval, incval, incval)
18181818
utt.assert_allclose(out1val, out2val)
18191819

1820+
def test_empty_index(self):
1821+
x = fvector()
1822+
idx = constant([], dtype="int64")
1823+
y = idx.astype("float32")
1824+
out = advanced_inc_subtensor1(x, y, idx)
1825+
1826+
test_x = np.array([1, 2, 3], dtype="float32")
1827+
res = out.eval({x: test_x}, mode=Mode(optimizer=None))
1828+
np.testing.assert_array_equal(res, test_x)
1829+
18201830

18211831
class TestAdvancedSubtensor:
18221832
"""Test inc_subtensor and set_subtensor."""

0 commit comments

Comments
 (0)