Commit 0755eef

Implement gradient for vector repetitions
Cleanup documentation
1 parent 4ac1e63 commit 0755eef

2 files changed (+101, -67 lines)

pytensor/tensor/extra_ops.py

Lines changed: 84 additions & 61 deletions
@@ -646,7 +646,12 @@ class Repeat(Op):
 
     __props__ = ("axis",)
 
-    def __init__(self, axis=None):
+    def __init__(self, axis: int | None = None):
+        if axis is not None:
+            if not isinstance(axis, int) or axis < 0:
+                raise ValueError(
+                    f"Repeat only accepts positive integer axis or None, got {axis}"
+                )
         self.axis = axis
 
     def make_node(self, x, repeats):
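
A note on the new validation: the Op itself now refuses negative or non-integer axes, so any normalization has to happen before construction (the `repeat` helper below does this via `normalize_axis_index`). A minimal sketch of the behavior implied by the diff:

    from pytensor.tensor.extra_ops import Repeat

    Repeat(axis=None)  # OK: repeat over the flattened input
    Repeat(axis=1)     # OK: non-negative integer axis
    Repeat(axis=-1)    # raises ValueError; normalize negative axes first
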
@@ -687,48 +692,51 @@ def make_node(self, x, repeats):
         out_shape = list(x.type.shape)
         out_shape[self.axis] = None
 
-        out_type = TensorType(
-            x.dtype, shape=tuple(1 if s == 1 else None for s in out_shape)
-        )
-
+        out_type = TensorType(x.dtype, shape=out_shape)
         return Apply(self, [x, repeats], [out_type()])
 
     def perform(self, node, inputs, output_storage):
-        x = inputs[0]
-        repeats = inputs[1]
-        z = output_storage[0]
-        z[0] = np.repeat(x, repeats=repeats, axis=self.axis)
+        [x, repeats] = inputs
+        output_storage[0][0] = np.repeat(x, repeats=repeats, axis=self.axis)
 
     def connection_pattern(self, node):
         return [[True], [False]]
 
     def grad(self, inputs, gout):
         (x, repeats) = inputs
         (gz,) = gout
+        axis = self.axis
         if repeats.ndim == 0:
+            # When repeats is a scalar (same number of reps for all elements),
+            # we can split the repetitions into their own axis with reshape
+            # and sum them back to the original element location.
+            sum_axis = x.ndim if axis is None else axis + 1
+            shape = list(x.shape)
+            shape.insert(sum_axis, repeats)
+            gx = gz.reshape(shape).sum(axis=sum_axis)
+
+        elif repeats.ndim == 1:
+            # To sum the gradients that belong to the same repeated x,
+            # we create a repeated eye and dot product it with the gradient.
+            axis_size = x.size if self.axis is None else x.shape[self.axis]
+            tiled_eye = repeat(ptb.eye(axis_size), repeats, axis=0)
+
             if self.axis is None:
-                axis = x.ndim
+                gx = gz @ tiled_eye
+                # Undo the raveling that happens when axis=None
+                gx = gx.reshape(x.shape)
             else:
-                if self.axis >= 0:
-                    axis = self.axis + 1
-                else:
-                    axis = self.axis + x.ndim + 1
-
-            shape = [x.shape[k] for k in range(x.ndim)]
-            shape.insert(axis, repeats)
+                # Move the gradient axis to the end for the dot product
+                gx = ptb.moveaxis(gz, self.axis, -1)
+                gx = gx @ tiled_eye
+                # Move the gradient back to the correct axis
+                gx = ptb.moveaxis(gx, -1, self.axis)
 
-            return [
-                gz.reshape(shape, ndim=x.ndim + 1).sum(axis=axis),
-                DisconnectedType()(),
-            ]
-        elif repeats.ndim == 1:
-            # For this implementation, we would need to specify the length
-            # of repeats in order to split gz in the right way to sum
-            # the good part.
-            raise NotImplementedError()
         else:
             raise ValueError()
 
+        return [gx, DisconnectedType()()]
+
     def infer_shape(self, fgraph, node, ins_shapes):
         i0_shapes = ins_shapes[0]
         repeats = node.inputs[1]
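
Both gradient paths reduce to "sum the output gradients that came from the same input element". For scalar repeats the copies are contiguous, so a reshape exposes them on their own axis; for vector repeats the groups have unequal sizes, so a repeated identity matrix performs the grouped sum as a matmul. A NumPy sketch of both tricks (an illustration of the idea, not the PyTensor code itself; all names are illustrative):

    import numpy as np

    x = np.arange(6.0).reshape(2, 3)

    # Scalar repeats, axis=None: each element's r copies are contiguous in the
    # flat output, so reshape to (*x.shape, r) and sum the trailing axis.
    r = 4
    gz = np.ones(x.size * r)                       # upstream gradient
    gx = gz.reshape(*x.shape, r).sum(axis=x.ndim)  # every entry is r here

    # Vector repeats, axis=0: a repeated identity maps each output row back to
    # the input row it was copied from; the matmul sums the grouped gradients.
    repeats = np.array([2, 3])
    gz = np.ones_like(np.repeat(x, repeats, axis=0))             # shape (5, 3)
    tiled_eye = np.repeat(np.eye(x.shape[0]), repeats, axis=0)   # shape (5, 2)
    gx = np.moveaxis(np.moveaxis(gz, 0, -1) @ tiled_eye, -1, 0)  # shape (2, 3)
    assert np.array_equal(gx[:, 0], repeats.astype(float))       # rows got 2 and 3
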
@@ -757,76 +765,91 @@ def infer_shape(self, fgraph, node, ins_shapes):
         return [out_shape]
 
 
-def repeat(x, repeats, axis=None):
-    """Repeat elements of an array.
+def repeat(
+    a: "TensorLike", repeats: "TensorLike", axis: int | None = None
+) -> TensorVariable:
+    """Repeat elements of a tensor.
 
-    It returns an array which has the same shape as `x`, except along the given
-    `axis`. The `axis` parameter is used to specify the axis along which values
-    are repeated. By default, a flattened version of `x` is used.
+    See `numpy.repeat` for more information.
 
-    The number of repetitions for each element is `repeats`. `repeats` is
-    broadcasted to fit the length of the given `axis`.
 
     Parameters
     ----------
-    x
-        Input data, tensor variable.
-    repeats
-        int, scalar or tensor variable
+    a: tensor_like
+        Input tensor.
+    repeats: tensor_like
+        The number of repetitions for each element. `repeats` is broadcasted
+        to fit the shape of the given axis.
    axis : int, optional
+        The axis along which to repeat values. By default, use the flattened
+        input array, and return a flat output array.
 
-    See Also
+    Returns
+    -------
+    repeated_tensor: TensorVariable
+        Output tensor which has the same shape as `a`, except along the given axis.
+
+    Examples
     --------
-    tensor.tile
+
+    .. testcode::
+
+        import pytensor.tensor as pt
+
+        a = pt.arange(4).reshape((2, 2))
+        out = pt.repeat(a, repeats=[2, 3], axis=0)
+        print(out.eval())
+
+    .. testoutput::
+
+        [[0 1]
+         [0 1]
+         [2 3]
+         [2 3]
+         [2 3]]
+
 
     .. versionadded:: 0.6
 
     """
+    a = ptb.as_tensor_variable(a)
+
+    if axis is not None:
+        axis = normalize_axis_index(axis, a.ndim)
+
     repeats = ptb.as_tensor_variable(repeats, dtype=np.int64)
 
     if repeats.ndim > 1:
         raise ValueError("The dimension of repeats should not exceed 1.")
 
     if repeats.ndim == 1 and not repeats.broadcastable[0]:
-        return Repeat(axis=axis)(x, repeats)
+        # We only use the Repeat Op for vector repeats
+        return Repeat(axis=axis)(a, repeats)
     else:
         if repeats.ndim == 1:
             repeats = repeats[0]
 
-        if x.dtype == "uint64":
+        if a.dtype == "uint64":
             raise TypeError("repeat doesn't support dtype uint64")
 
         if axis is None:
             axis = 0
-            x = x.flatten()
-        else:
-            if axis >= x.ndim:
-                raise ValueError("Axis should not exceed x.ndim-1.")
-            if axis < 0:
-                axis = x.ndim + axis
+            a = a.flatten()
 
-        shape = [x.shape[i] for i in range(x.ndim)]
+        repeat_shape = list(a.shape)
 
-        # shape_ is the shape of the intermediate tensor which has
+        # alloc_shape is the shape of the intermediate tensor which has
         # an additional dimension comparing to x. We use alloc to
         # allocate space for this intermediate tensor to replicate x
         # along that additional dimension.
-        shape_ = shape[:]
-        shape_.insert(axis + 1, repeats)
+        alloc_shape = repeat_shape[:]
+        alloc_shape.insert(axis + 1, repeats)
 
-        # shape is now the shape of output, where shape[axis] becomes
+        # repeat_shape is now the shape of output, where shape[axis] becomes
         # shape[axis]*repeats.
-        shape[axis] = shape[axis] * repeats
-
-        # dims_ is the dimension of that intermediate tensor.
-        dims_ = list(np.arange(x.ndim))
-        dims_.insert(axis + 1, "x")
+        repeat_shape[axis] = repeat_shape[axis] * repeats
 
         # After the original tensor is duplicated along the additional
-        # dimension, we reshape it to the expected output shape, and
-        # return the output z.
-        z = ptb.alloc(x.dimshuffle(*dims_), *shape_).reshape(shape)
-        return z
+        # dimension, we reshape it to the expected output shape.
+        return ptb.alloc(ptb.expand_dims(a, axis + 1), *alloc_shape).reshape(
+            repeat_shape
+        )
 
 
 class Bartlett(Op):
tests/tensor/test_extra_ops.py

Lines changed: 17 additions & 6 deletions
@@ -635,12 +635,23 @@ def test_infer_shape(self, ndim, dtype):
             self.op_class,
         )
 
-    @pytest.mark.parametrize("ndim", range(3))
-    def test_grad(self, ndim):
-        a = np.random.random((10,) * ndim).astype(config.floatX)
-
-        for axis in self._possible_axis(ndim):
-            utt.verify_grad(lambda x: Repeat(axis=axis)(x, 3), [a])
+    @pytest.mark.parametrize("x_ndim", [2, 3], ids=lambda x: f"x_ndim={x}")
+    @pytest.mark.parametrize("repeats_ndim", [0, 1], ids=lambda r: f"repeats_ndim={r}")
+    @pytest.mark.parametrize("axis", [None, 0, 1], ids=lambda a: f"axis={a}")
+    def test_grad(self, x_ndim, repeats_ndim, axis):
+        rng = np.random.default_rng(
+            [653, x_ndim, 2 if axis is None else axis, repeats_ndim]
+        )
+        x_test = rng.normal(size=np.arange(3, 3 + x_ndim))
+        if repeats_ndim == 0:
+            repeats_size = ()
+        else:
+            repeats_size = (x_test.shape[axis] if axis is not None else x_test.size,)
+        repeats = rng.integers(1, 6, size=repeats_size)
+        utt.verify_grad(
+            lambda x: Repeat(axis=axis)(x, repeats),
+            [x_test],
+        )
 
     def test_broadcastable(self):
         x = TensorType(config.floatX, shape=(None, 1, None))()
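
The seed list makes every parametrized case reproducible in isolation. One case (x_ndim=2, axis=0, vector repeats) unrolled as a standalone snippet, assuming the repo's usual `utt` test-utility import:

    import numpy as np
    from pytensor.tensor.extra_ops import Repeat
    from tests import unittest_tools as utt  # assumed import path for `utt`

    rng = np.random.default_rng([653, 2, 0, 1])  # [653, x_ndim, axis, repeats_ndim]
    x_test = rng.normal(size=(3, 4))             # np.arange(3, 5) -> shape (3, 4)
    repeats = rng.integers(1, 6, size=(3,))      # one repeat count per row (axis=0)
    utt.verify_grad(lambda x: Repeat(axis=0)(x, repeats), [x_test])
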
