Skip to content

Commit cb8b8ac

Browse files
committed
Change the semantics of the set and add helpers
1 parent c1bceb9 commit cb8b8ac

File tree

2 files changed

+20
-25
lines changed

2 files changed

+20
-25
lines changed

pytensor/tensor/variable.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -820,35 +820,36 @@ def compress(self, a, axis=None):
820820
"""Return selected slices only."""
821821
return at.extra_ops.compress(self, a, axis=axis)
822822

823-
def set(self, y, **kwargs):
824-
"""Return a copy of a tensor with the indexed values set to y.
823+
def set(self, idx, y, **kwargs):
824+
"""Return a copy of self with the indexed values set to y.
825825
826-
Self must be the output of an indexing operation.
827-
828-
Equivalent to set_subtensor(self, y). See docstrings for kwargs.
826+
Equivalent to set_subtensor(self[idx], y). See docstrings for kwargs.
829827
830828
Examples
831829
--------
832-
833-
>>> x = matrix()
834-
>>> out = x[0].set(5)
830+
>>> import pytensor.tensor as pt
831+
>>>
832+
>>> x = pt.ones((3,))
833+
>>> out = x.set(1, 2)
834+
>>> out.eval() # array([1., 2., 1.])
835835
"""
836-
return at.subtensor.set_subtensor(self, y, **kwargs)
837-
838-
def add(self, y, **kwargs):
839-
"""Return a copy of a tensor with the indexed values incremented by y.
836+
return at.subtensor.set_subtensor(self[idx], y, **kwargs)
840837

841-
Self must be the output of an indexing operation.
838+
def add(self, idx, y, **kwargs):
839+
"""Return a copy of self with the indexed values incremented by y.
842840
843-
Equivalent to inc_subtensor(self, y). See docstrings for kwargs.
841+
Equivalent to inc_subtensor(self[idx], y). See docstrings for kwargs.
844842
845843
Examples
846844
--------
847845
848-
>>> x = matrix()
849-
>>> out = x[0].add(5)
846+
>>> import pytensor.tensor as pt
847+
>>>
848+
>>> x = pt.ones((3,))
849+
>>> out = x.add(1, 2)
850+
>>> out.eval() # array([1., 3., 1.])
850851
"""
851-
return at.inc_subtensor(self, y, **kwargs)
852+
return at.inc_subtensor(self[idx], y, **kwargs)
852853

853854

854855
class TensorVariable(

tests/tensor/test_variable.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -438,14 +438,8 @@ def test_set_add(self):
438438
idx = [0]
439439
y = 5
440440

441-
assert equal_computations([x[idx].set(y)], [set_subtensor(x[idx], y)])
442-
assert equal_computations([x[idx].add(y)], [inc_subtensor(x[idx], y)])
443-
444-
msg = "must be the result of a subtensor operation"
445-
with pytest.raises(TypeError, match=msg):
446-
x.set(y)
447-
with pytest.raises(TypeError, match=msg):
448-
x.add(y)
441+
assert equal_computations([x.set(idx, y)], [set_subtensor(x[idx], y)])
442+
assert equal_computations([x.add(idx, y)], [inc_subtensor(x[idx], y)])
449443

450444
def test_set_item_error(self):
451445
x = matrix("x")

0 commit comments

Comments
 (0)