Skip to content

Commit ff56b7f

Browse files
committed
Fix python implementation of IncSubtensor
1 parent 892a8f0 commit ff56b7f

File tree

1 file changed

+19
-30
lines changed

1 file changed

+19
-30
lines changed

pytensor/tensor/subtensor.py

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1756,41 +1756,30 @@ def make_node(self, x, y, *inputs):
17561756
def decl_view(self):
17571757
return "PyArrayObject * zview = NULL;"
17581758

1759-
def perform(self, node, inputs, out_):
1760-
(out,) = out_
1761-
x, y = inputs[:2]
1762-
indices = list(reversed(inputs[2:]))
1763-
1764-
def _convert(entry):
1765-
if isinstance(entry, Type):
1766-
return indices.pop()
1767-
elif isinstance(entry, slice):
1768-
return slice(
1769-
_convert(entry.start), _convert(entry.stop), _convert(entry.step)
1759+
def perform(self, node, inputs, output_storage):
1760+
x, y, *flat_indices = inputs
1761+
1762+
flat_indices_iterator = iter(flat_indices)
1763+
indices = tuple(
1764+
(
1765+
next(flat_indices_iterator)
1766+
if isinstance(entry, Type)
1767+
else slice(
1768+
None if entry.start is None else next(flat_indices_iterator),
1769+
None if entry.stop is None else next(flat_indices_iterator),
1770+
None if entry.step is None else next(flat_indices_iterator),
17701771
)
1771-
else:
1772-
return entry
1772+
)
1773+
for entry in self.idx_list
1774+
)
17731775

1774-
cdata = tuple(map(_convert, self.idx_list))
1775-
if len(cdata) == 1:
1776-
cdata = cdata[0]
17771776
if not self.inplace:
17781777
x = x.copy()
1779-
sub_x = x.__getitem__(cdata)
1780-
if sub_x.shape:
1781-
# we've sliced out an N-D tensor with N > 0
1782-
if not self.set_instead_of_inc:
1783-
sub_x += y
1784-
else:
1785-
# sub_x += -sub_x + y
1786-
x.__setitem__(cdata, y)
1778+
if self.set_instead_of_inc:
1779+
x[indices] = y
17871780
else:
1788-
# scalar case
1789-
if not self.set_instead_of_inc:
1790-
x.__setitem__(cdata, sub_x + y)
1791-
else:
1792-
x.__setitem__(cdata, y)
1793-
out[0] = x
1781+
x[indices] += y
1782+
output_storage[0][0] = x
17941783

17951784
def c_code(self, node, name, inputs, outputs, sub):
17961785
# This method delegates much of the work to helper

0 commit comments

Comments
 (0)