Skip to content

Commit 9d5a825

Browse files
committed
Add set and add TensorVariable methods for set_subtensor and inc_subtensor operations
1 parent a6f3f2d commit 9d5a825

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

pytensor/tensor/variable.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -815,6 +815,32 @@ def compress(self, a, axis=None):
815815
"""Return selected slices only."""
816816
return at.extra_ops.compress(self, a, axis=axis)
817817

818+
def set(self, y, **kwargs):
819+
"""Set values to y, where y is the output of an index operation.
820+
821+
Equivalent to set_subtensor(self, y). See docstrings for kwargs.
822+
823+
Examples
824+
--------
825+
826+
>>> x = matrix()
827+
>>> out = x[0].set(5)
828+
"""
829+
return at.subtensor.set_subtensor(self, y, **kwargs)
830+
831+
def add(self, y, **kwargs):
832+
"""Add values to y, where y is the output of an index operation.
833+
834+
Equivalent to inc_subtensor(self, y). See docstrings for kwargs
835+
836+
Examples
837+
--------
838+
839+
>>> x = matrix()
840+
>>> out = x[0].add(5)
841+
"""
842+
return at.inc_subtensor(self, y, **kwargs)
843+
818844

819845
class TensorVariable(
820846
_tensor_py_operators, Variable[_TensorTypeType, OptionalApplyType]

tests/tensor/test_variable.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@
1414
from pytensor.tensor.elemwise import DimShuffle
1515
from pytensor.tensor.math import dot, eq, matmul
1616
from pytensor.tensor.shape import Shape
17-
from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor
17+
from pytensor.tensor.subtensor import (
18+
AdvancedSubtensor,
19+
Subtensor,
20+
inc_subtensor,
21+
set_subtensor,
22+
)
1823
from pytensor.tensor.type import (
1924
TensorType,
2025
cscalar,
@@ -428,6 +433,20 @@ def test_take(self):
428433
# Test equivalent advanced indexing
429434
assert_array_equal(X[:, indices].eval({X: x}), x[:, indices])
430435

436+
def test_set_add(self):
437+
x = matrix("x")
438+
idx = [0]
439+
y = 5
440+
441+
assert equal_computations([x[idx].set(y)], [set_subtensor(x[idx], y)])
442+
assert equal_computations([x[idx].add(y)], [inc_subtensor(x[idx], y)])
443+
444+
msg = "must be the result of a subtensor operation"
445+
with pytest.raises(TypeError, match=msg):
446+
x.set(y)
447+
with pytest.raises(TypeError, match=msg):
448+
x.add(y)
449+
431450

432451
def test_deprecated_import():
433452
with pytest.warns(

0 commit comments

Comments
 (0)