Skip to content

Commit 8f344da

Browse files
committed
Fix python implementation of IncSubtensor
1 parent f309c22 commit 8f344da

File tree

1 file changed

+19
-30
lines changed

1 file changed

+19
-30
lines changed

pytensor/tensor/subtensor.py

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1740,41 +1740,30 @@ def make_node(self, x, y, *inputs):
17401740
def decl_view(self):
17411741
return "PyArrayObject * zview = NULL;"
17421742

1743-
def perform(self, node, inputs, out_):
1744-
(out,) = out_
1745-
x, y = inputs[:2]
1746-
indices = list(reversed(inputs[2:]))
1747-
1748-
def _convert(entry):
1749-
if isinstance(entry, Type):
1750-
return indices.pop()
1751-
elif isinstance(entry, slice):
1752-
return slice(
1753-
_convert(entry.start), _convert(entry.stop), _convert(entry.step)
1743+
def perform(self, node, inputs, output_storage):
1744+
x, y, *flat_indices = inputs
1745+
1746+
flat_indices_iterator = iter(flat_indices)
1747+
indices = tuple(
1748+
(
1749+
next(flat_indices_iterator)
1750+
if isinstance(entry, Type)
1751+
else slice(
1752+
None if entry.start is None else next(flat_indices_iterator),
1753+
None if entry.stop is None else next(flat_indices_iterator),
1754+
None if entry.step is None else next(flat_indices_iterator),
17541755
)
1755-
else:
1756-
return entry
1756+
)
1757+
for entry in self.idx_list
1758+
)
17571759

1758-
cdata = tuple(map(_convert, self.idx_list))
1759-
if len(cdata) == 1:
1760-
cdata = cdata[0]
17611760
if not self.inplace:
17621761
x = x.copy()
1763-
sub_x = x.__getitem__(cdata)
1764-
if sub_x.shape:
1765-
# we've sliced out an N-D tensor with N > 0
1766-
if not self.set_instead_of_inc:
1767-
sub_x += y
1768-
else:
1769-
# sub_x += -sub_x + y
1770-
x.__setitem__(cdata, y)
1762+
if self.set_instead_of_inc:
1763+
x[indices] = y
17711764
else:
1772-
# scalar case
1773-
if not self.set_instead_of_inc:
1774-
x.__setitem__(cdata, sub_x + y)
1775-
else:
1776-
x.__setitem__(cdata, y)
1777-
out[0] = x
1765+
x[indices] += y
1766+
output_storage[0][0] = x
17781767

17791768
def c_code(self, node, name, inputs, outputs, sub):
17801769
# This method delegates much of the work to helper

0 commit comments

Comments
 (0)