@@ -646,7 +646,12 @@ class Repeat(Op):
646646
647647 __props__ = ("axis" ,)
648648
649- def __init__ (self , axis = None ):
649+ def __init__ (self , axis : int | None = None ):
650+ if axis is not None :
651+ if not isinstance (axis , int ) or axis < 0 :
652+ raise ValueError (
653+ f"Repeat only accepts positive integer axis or None, got { axis } "
654+ )
650655 self .axis = axis
651656
652657 def make_node (self , x , repeats ):
@@ -687,48 +692,53 @@ def make_node(self, x, repeats):
687692 out_shape = list (x .type .shape )
688693 out_shape [self .axis ] = None
689694
690- out_type = TensorType (
691- x .dtype , shape = tuple (1 if s == 1 else None for s in out_shape )
692- )
693-
695+ out_type = TensorType (x .dtype , shape = out_shape )
694696 return Apply (self , [x , repeats ], [out_type ()])
695697
696698 def perform (self , node , inputs , output_storage ):
697- x = inputs [0 ]
698- repeats = inputs [1 ]
699- z = output_storage [0 ]
700- z [0 ] = np .repeat (x , repeats = repeats , axis = self .axis )
699+ [x , repeats ] = inputs
700+ output_storage [0 ][0 ] = np .repeat (x , repeats = repeats , axis = self .axis )
701701
702702 def connection_pattern (self , node ):
703703 return [[True ], [False ]]
704704
705705 def grad (self , inputs , gout ):
706706 (x , repeats ) = inputs
707707 (gz ,) = gout
708+ axis = self .axis
708709 if repeats .ndim == 0 :
710+ # When axis is a scalar (same number of reps for all elements),
711+ # We can split the repetitions into their own axis with reshape and sum them back
712+ # to the original element location
713+ sum_axis = x .ndim if axis is None else axis + 1
714+ shape = list (x .shape )
715+ shape .insert (sum_axis , repeats )
716+ gx = gz .reshape (shape ).sum (axis = sum_axis )
717+
718+ elif repeats .ndim == 1 :
719+ # To sum the gradients that belong to the same repeated x,
720+ # We create a repeated eye and dot product it with the gradient.
721+ axis_size = x .size if self .axis is None else x .shape [self .axis ]
722+ tiled_eye = repeat (
723+ ptb .eye (axis_size ), repeats , axis = 0
724+ ) # A sparse repeat would be neat
725+
709726 if self .axis is None :
710- axis = x .ndim
727+ gx = gz @ tiled_eye
728+ # Undo the ravelling when axis=None
729+ gx = gx .reshape (x .shape )
711730 else :
712- if self .axis >= 0 :
713- axis = self .axis + 1
714- else :
715- axis = self .axis + x .ndim + 1
716-
717- shape = [x .shape [k ] for k in range (x .ndim )]
718- shape .insert (axis , repeats )
731+ # Place gradient axis at end for dot product
732+ gx = ptb .moveaxis (gz , self .axis , - 1 )
733+ gx = gx @ tiled_eye
734+ # Place gradient back into the correct axis
735+ gx = ptb .moveaxis (gx , - 1 , self .axis )
719736
720- return [
721- gz .reshape (shape , ndim = x .ndim + 1 ).sum (axis = axis ),
722- DisconnectedType ()(),
723- ]
724- elif repeats .ndim == 1 :
725- # For this implementation, we would need to specify the length
726- # of repeats in order to split gz in the right way to sum
727- # the good part.
728- raise NotImplementedError ()
729737 else :
730738 raise ValueError ()
731739
740+ return [gx , DisconnectedType ()()]
741+
732742 def infer_shape (self , fgraph , node , ins_shapes ):
733743 i0_shapes = ins_shapes [0 ]
734744 repeats = node .inputs [1 ]
@@ -757,76 +767,91 @@ def infer_shape(self, fgraph, node, ins_shapes):
757767 return [out_shape ]
758768
759769
def repeat(
    a: "TensorLike", repeats: "TensorLike", axis: int | None = None
) -> TensorVariable:
    """Repeat elements of a tensor.

    See `numpy.repeat` for more information.


    Parameters
    ----------
    a: tensor_like
        Input tensor
    repeats: tensor_like
        The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis.
    axis : int, optional
        The axis along which to repeat values. By default, use the flattened input array, and return a flat output array.

    Returns
    -------
    repeated_tensor: TensorVariable
        Output tensor which has the same shape as a, except along the given axis

    Examples
    --------

    .. testcode::

        import pytensor.tensor as pt

        a = pt.arange(4).reshape((2, 2))
        out = pt.repeat(a, repeats=[2, 3], axis=0)
        print(out.eval())

    .. testoutput::

        [[0 1]
         [0 1]
         [2 3]
         [2 3]
         [2 3]]


    .. versionadded:: 0.6

    """
    a = ptb.as_tensor_variable(a)

    if axis is not None:
        # Canonicalize negative axes here; the Repeat Op only accepts
        # non-negative axes (or None).
        axis = normalize_axis_index(axis, a.ndim)

    repeats = ptb.as_tensor_variable(repeats, dtype=np.int64)

    if repeats.ndim > 1:
        raise ValueError("The dimension of repeats should not exceed 1.")

    if repeats.ndim == 1 and not repeats.broadcastable[0]:
        # We only use the Repeat Op for vector repeats
        return Repeat(axis=axis)(a, repeats)
    else:
        # Scalar repeats: express the repetition with broadcasting + reshape,
        # which downstream rewrites can optimize better than the Repeat Op.
        if repeats.ndim == 1:
            repeats = repeats[0]

        if a.dtype == "uint64":
            # uint64 cannot be safely converted for the shape arithmetic below
            raise TypeError("repeat doesn't support dtype uint64")

        if axis is None:
            axis = 0
            a = a.flatten()

        repeat_shape = list(a.shape)

        # alloc_shape is the shape of the intermediate tensor which has
        # an additional dimension comparing to x. We use alloc to
        # allocate space for this intermediate tensor to replicate x
        # along that additional dimension.
        alloc_shape = repeat_shape[:]
        alloc_shape.insert(axis + 1, repeats)

        # repeat_shape is now the shape of output, where shape[axis] becomes
        # shape[axis]*repeats.
        repeat_shape[axis] = repeat_shape[axis] * repeats

        # After the original tensor is duplicated along the additional
        # dimension, we reshape it to the expected output shape
        return ptb.alloc(ptb.expand_dims(a, axis + 1), *alloc_shape).reshape(
            repeat_shape
        )
830855
831856
832857class Bartlett (Op ):
0 commit comments