|
22 | 22 | from pytensor.graph.rewriting.utils import is_same_graph
|
23 | 23 | from pytensor.printing import pprint
|
24 | 24 | from pytensor.scalar.basic import as_scalar, int16
|
25 |
| -from pytensor.tensor import as_tensor, get_vector_length, vectorize |
| 25 | +from pytensor.tensor import as_tensor, constant, get_vector_length, vectorize |
26 | 26 | from pytensor.tensor.blockwise import Blockwise
|
27 | 27 | from pytensor.tensor.elemwise import DimShuffle
|
28 | 28 | from pytensor.tensor.math import exp, isinf, lt, switch
|
@@ -1730,7 +1730,7 @@ def test_grad_broadcastable_specialization(self):
|
1730 | 1730 | )
|
1731 | 1731 |
|
1732 | 1732 |
|
1733 |
| -class TestIncSubtensor1: |
| 1733 | +class TestAdvancedIncSubtensor1: |
1734 | 1734 | def setup_method(self):
|
1735 | 1735 | self.rng = np.random.default_rng(seed=utt.fetch_seed())
|
1736 | 1736 |
|
@@ -1817,6 +1817,16 @@ def test_inc_bcastableidx(self):
|
1817 | 1817 | out1val, out2val = f(mval, incval, incval)
|
1818 | 1818 | utt.assert_allclose(out1val, out2val)
|
1819 | 1819 |
|
| 1820 | + def test_empty_index(self): |
| 1821 | + x = fvector() |
| 1822 | + idx = constant([], dtype="int64") |
| 1823 | + y = idx.astype("float32") |
| 1824 | + out = advanced_inc_subtensor1(x, y, idx) |
| 1825 | + |
| 1826 | + test_x = np.array([1, 2, 3], dtype="float32") |
| 1827 | + res = out.eval({x: test_x}, mode=Mode(optimizer=None)) |
| 1828 | + np.testing.assert_array_equal(res, test_x) |
| 1829 | + |
1820 | 1830 |
|
1821 | 1831 | class TestAdvancedSubtensor:
|
1822 | 1832 | """Test inc_subtensor and set_subtensor."""
|
|
0 commit comments