@@ -646,7 +646,12 @@ class Repeat(Op):
646646
647647 __props__ = ("axis" ,)
648648
649- def __init__ (self , axis = None ):
649+ def __init__ (self , axis : int | None = None ):
650+ if axis is not None :
651+ if not isinstance (axis , int ) or axis < 0 :
652+ raise ValueError (
653+ f"Repeat only accepts positive integer axis or None, got { axis } "
654+ )
650655 self .axis = axis
651656
652657 def make_node (self , x , repeats ):
@@ -687,48 +692,51 @@ def make_node(self, x, repeats):
687692 out_shape = list (x .type .shape )
688693 out_shape [self .axis ] = None
689694
690- out_type = TensorType (
691- x .dtype , shape = tuple (1 if s == 1 else None for s in out_shape )
692- )
693-
695+ out_type = TensorType (x .dtype , shape = out_shape )
694696 return Apply (self , [x , repeats ], [out_type ()])
695697
696698 def perform (self , node , inputs , output_storage ):
697- x = inputs [0 ]
698- repeats = inputs [1 ]
699- z = output_storage [0 ]
700- z [0 ] = np .repeat (x , repeats = repeats , axis = self .axis )
699+ [x , repeats ] = inputs
700+ output_storage [0 ][0 ] = np .repeat (x , repeats = repeats , axis = self .axis )
701701
702702 def connection_pattern (self , node ):
703703 return [[True ], [False ]]
704704
705705 def grad (self , inputs , gout ):
706706 (x , repeats ) = inputs
707707 (gz ,) = gout
708+ axis = self .axis
708709 if repeats .ndim == 0 :
710+ # When axis is a scalar (same number of reps for all elements),
711+ # We can split the repetitions into their own axis with reshape and sum them back
712+ # to the original element location
713+ sum_axis = x .ndim if axis is None else axis + 1
714+ shape = list (x .shape )
715+ shape .insert (sum_axis , repeats )
716+ gx = gz .reshape (shape ).sum (axis = sum_axis )
717+
718+ elif repeats .ndim == 1 :
719+ # To sum the gradients that belong to the same repeated x,
720+ # We create a repeated eye and dot product it with the gradient.
721+ axis_size = x .size if self .axis is None else x .shape [self .axis ]
722+ tiled_eye = repeat (ptb .eye (axis_size ), repeats , axis = 0 )
723+
709724 if self .axis is None :
710- axis = x .ndim
725+ gx = gz @ tiled_eye
726+ # Undo the ravelling when axis=None
727+ gx = gx .reshape (x .shape )
711728 else :
712- if self .axis >= 0 :
713- axis = self .axis + 1
714- else :
715- axis = self .axis + x .ndim + 1
716-
717- shape = [x .shape [k ] for k in range (x .ndim )]
718- shape .insert (axis , repeats )
729+ # Place gradient axis at end for dot product
730+ gx = ptb .moveaxis (gz , self .axis , - 1 )
731+ gx = gx @ tiled_eye
732+ # Place gradient back into the correct axis
733+ gx = ptb .moveaxis (gx , - 1 , self .axis )
719734
720- return [
721- gz .reshape (shape , ndim = x .ndim + 1 ).sum (axis = axis ),
722- DisconnectedType ()(),
723- ]
724- elif repeats .ndim == 1 :
725- # For this implementation, we would need to specify the length
726- # of repeats in order to split gz in the right way to sum
727- # the good part.
728- raise NotImplementedError ()
729735 else :
730736 raise ValueError ()
731737
738+ return [gx , DisconnectedType ()()]
739+
732740 def infer_shape (self , fgraph , node , ins_shapes ):
733741 i0_shapes = ins_shapes [0 ]
734742 repeats = node .inputs [1 ]
@@ -757,76 +765,91 @@ def infer_shape(self, fgraph, node, ins_shapes):
757765 return [out_shape ]
758766
759767
760- def repeat (x , repeats , axis = None ):
761- """Repeat elements of an array .
768+ def repeat (a : "TensorLike" , repeats : TensorLike , axis : int or None ) -> TensorVariable :
769+ """Repeat elements of a tensor .
762770
763- It returns an array which has the same shape as `x`, except along the given
764- `axis`. The `axis` parameter is used to specify the axis along which values
765- are repeated. By default, a flattened version of `x` is used.
771+ See `numpy.repeat` for more information.
766772
767- The number of repetitions for each element is `repeats`. `repeats` is
768- broadcasted to fit the length of the given `axis`.
769773
770774 Parameters
771775 ----------
772- x
773- Input data, tensor variable.
774- repeats
775- int, scalar or tensor variable
776+ a: tensor_like
777+ Input tensor
778+ repeats: tensor_like
779+ The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis.
776780 axis : int, optional
781+ The axis along which to repeat values. By default, use the flattened input array, and return a flat output array.
777782
778- See Also
783+ Returns
784+ -------
785+ repeated_tensor: TensorVariable
786+ Output tensor which as the same shape as a, except along the given axis
787+
788+ Examples
779789 --------
780- tensor.tile
790+
791+ .. testcode::
792+
793+ import pytensor.tensor as pt
794+
795+ a = pt.arange(4).reshape((2, 2))
796+ out = pt.repeat(a, repeats=[2, 3], axis=0)
797+ print(out.eval())
798+
799+ .. testoutput::
800+
801+ [[0 1]
802+ [0 1]
803+ [2 3]
804+ [2 3]
805+ [2 3]]
806+
781807
782808 .. versionadded:: 0.6
783809
784810 """
811+ a = ptb .as_tensor_variable (a )
812+
813+ if axis is not None :
814+ axis = normalize_axis_index (axis , a .ndim )
815+
785816 repeats = ptb .as_tensor_variable (repeats , dtype = np .int64 )
786817
787818 if repeats .ndim > 1 :
788819 raise ValueError ("The dimension of repeats should not exceed 1." )
789820
790821 if repeats .ndim == 1 and not repeats .broadcastable [0 ]:
791- return Repeat (axis = axis )(x , repeats )
822+ # We only use the Repeat Op for vector repeats
823+ return Repeat (axis = axis )(a , repeats )
792824 else :
793825 if repeats .ndim == 1 :
794826 repeats = repeats [0 ]
795827
796- if x .dtype == "uint64" :
828+ if a .dtype == "uint64" :
797829 raise TypeError ("repeat doesn't support dtype uint64" )
798830
799831 if axis is None :
800832 axis = 0
801- x = x .flatten ()
802- else :
803- if axis >= x .ndim :
804- raise ValueError ("Axis should not exceed x.ndim-1." )
805- if axis < 0 :
806- axis = x .ndim + axis
833+ a = a .flatten ()
807834
808- shape = [ x . shape [ i ] for i in range ( x . ndim )]
835+ repeat_shape = list ( a . shape )
809836
810- # shape_ is the shape of the intermediate tensor which has
837+ # alloc_shape is the shape of the intermediate tensor which has
811838 # an additional dimension comparing to x. We use alloc to
812839 # allocate space for this intermediate tensor to replicate x
813840 # along that additional dimension.
814- shape_ = shape [:]
815- shape_ .insert (axis + 1 , repeats )
841+ alloc_shape = repeat_shape [:]
842+ alloc_shape .insert (axis + 1 , repeats )
816843
817- # shape is now the shape of output, where shape[axis] becomes
844+ # repeat_shape is now the shape of output, where shape[axis] becomes
818845 # shape[axis]*repeats.
819- shape [axis ] = shape [axis ] * repeats
820-
821- # dims_ is the dimension of that intermediate tensor.
822- dims_ = list (np .arange (x .ndim ))
823- dims_ .insert (axis + 1 , "x" )
846+ repeat_shape [axis ] = repeat_shape [axis ] * repeats
824847
825848 # After the original tensor is duplicated along the additional
826- # dimension, we reshape it to the expected output shape, and
827- # return the output z.
828- z = ptb . alloc ( x . dimshuffle ( * dims_ ), * shape_ ). reshape ( shape )
829- return z
849+ # dimension, we reshape it to the expected output shape
850+ return ptb . alloc ( ptb . expand_dims ( a , axis + 1 ), * alloc_shape ). reshape (
851+ repeat_shape
852+ )
830853
831854
832855class Bartlett (Op ):
0 commit comments