@@ -646,12 +646,17 @@ class Repeat(Op):
646646
647647 __props__ = ("axis" ,)
648648
649- def __init__ (self , axis = None ):
649+ def __init__ (self , axis : int | None = None ):
650+ if axis is not None :
651+ if not isinstance (axis , int ) or axis < 0 :
652+ raise ValueError (
653+ f"Repeat only accepts positive integer axis or None, got { axis } "
654+ )
650655 self .axis = axis
651656
652657 def make_node (self , x , repeats ):
653658 x = ptb .as_tensor_variable (x )
654- repeats = ptb .as_tensor_variable (repeats )
659+ repeats = ptb .as_tensor_variable (repeats , dtype = "int64" )
655660
656661 if repeats .dtype not in integer_dtypes :
657662 raise TypeError ("repeats.dtype must be an integer." )
@@ -687,58 +692,64 @@ def make_node(self, x, repeats):
687692 out_shape = list (x .type .shape )
688693 out_shape [self .axis ] = None
689694
690- out_type = TensorType (
691- x .dtype , shape = tuple (1 if s == 1 else None for s in out_shape )
692- )
693-
695+ out_type = TensorType (x .dtype , shape = out_shape )
694696 return Apply (self , [x , repeats ], [out_type ()])
695697
696698 def perform (self , node , inputs , output_storage ):
697- x = inputs [0 ]
698- repeats = inputs [1 ]
699- z = output_storage [0 ]
700- z [0 ] = np .repeat (x , repeats = repeats , axis = self .axis )
699+ [x , repeats ] = inputs
700+ output_storage [0 ][0 ] = np .repeat (x , repeats = repeats , axis = self .axis )
701701
702702 def connection_pattern (self , node ):
703703 return [[True ], [False ]]
704704
705705 def grad (self , inputs , gout ):
706706 (x , repeats ) = inputs
707707 (gz ,) = gout
708+ axis = self .axis
708709 if repeats .ndim == 0 :
709- if self .axis is None :
710- axis = x .ndim
711- else :
712- if self .axis >= 0 :
713- axis = self .axis + 1
714- else :
715- axis = self .axis + x .ndim + 1
716-
717- shape = [x .shape [k ] for k in range (x .ndim )]
718- shape .insert (axis , repeats )
710+ # When axis is a scalar (same number of reps for all elements),
711+ # We can split the repetitions into their own axis with reshape and sum them back
712+ # to the original element location
713+ sum_axis = x .ndim if axis is None else axis + 1
714+ shape = list (x .shape )
715+ shape .insert (sum_axis , repeats )
716+ gx = gz .reshape (shape ).sum (axis = sum_axis )
719717
720- return [
721- gz .reshape (shape , ndim = x .ndim + 1 ).sum (axis = axis ),
722- DisconnectedType ()(),
723- ]
724718 elif repeats .ndim == 1 :
725- # For this implementation, we would need to specify the length
726- # of repeats in order to split gz in the right way to sum
727- # the good part.
728- raise NotImplementedError ()
719+ # To sum the gradients that belong to the same repeated x,
720+ # We create a repeated eye and dot product it with the gradient.
721+ axis_size = x .size if axis is None else x .shape [axis ]
722+ repeated_eye = repeat (
723+ ptb .eye (axis_size ), repeats , axis = 0
724+ ) # A sparse repeat would be neat
725+
726+ if axis is None :
727+ gx = gz @ repeated_eye
728+ # Undo the ravelling when axis=None
729+ gx = gx .reshape (x .shape )
730+ else :
731+ # Place gradient axis at end for dot product
732+ gx = ptb .moveaxis (gz , axis , - 1 )
733+ gx = gx @ repeated_eye
734+ # Place gradient back into the correct axis
735+ gx = ptb .moveaxis (gx , - 1 , axis )
736+
729737 else :
730738 raise ValueError ()
731739
740+ return [gx , DisconnectedType ()()]
741+
732742 def infer_shape (self , fgraph , node , ins_shapes ):
733743 i0_shapes = ins_shapes [0 ]
734744 repeats = node .inputs [1 ]
735745 out_shape = list (i0_shapes )
746+ axis = self .axis
736747
737748 # uint64 shape are not supported.
738749 dtype = None
739750 if repeats .dtype in ("uint8" , "uint16" , "uint32" ):
740751 dtype = "int64"
741- if self . axis is None :
752+ if axis is None :
742753 if repeats .ndim == 0 :
743754 if len (i0_shapes ) == 0 :
744755 out_shape = [repeats ]
@@ -751,82 +762,115 @@ def infer_shape(self, fgraph, node, ins_shapes):
751762 out_shape = [pt_sum (repeats , dtype = dtype )]
752763 else :
753764 if repeats .ndim == 0 :
754- out_shape [self . axis ] = out_shape [self . axis ] * repeats
765+ out_shape [axis ] = out_shape [axis ] * repeats
755766 else :
756- out_shape [self . axis ] = pt_sum (repeats , dtype = dtype )
767+ out_shape [axis ] = pt_sum (repeats , dtype = dtype )
757768 return [out_shape ]
758769
759770
760- def repeat (x , repeats , axis = None ):
761- """Repeat elements of an array.
771+ def repeat (
772+ a : TensorLike , repeats : TensorLike , axis : int or None = None
773+ ) -> TensorVariable :
774+ """Repeat elements of a tensor.
762775
763- It returns an array which has the same shape as `x`, except along the given
764- `axis`. The `axis` parameter is used to specify the axis along which values
765- are repeated. By default, a flattened version of `x` is used.
776+ See :func:`numpy.repeat` for more information.
766777
767- The number of repetitions for each element is `repeats`. `repeats` is
768- broadcasted to fit the length of the given `axis`.
769778
770779 Parameters
771780 ----------
772- x
773- Input data, tensor variable.
774- repeats
775- int, scalar or tensor variable
781+ a: tensor_like
782+ Input tensor
783+ repeats: tensor_like
784+ The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis.
776785 axis : int, optional
786+ The axis along which to repeat values. By default, use the flattened input array, and return a flat output array.
777787
778- See Also
788+ Returns
789+ -------
790+ repeated_tensor: TensorVariable
791+ Output tensor which as the same shape as a, except along the given axis
792+
793+ Examples
779794 --------
780- tensor.tile
795+
796+ .. testcode::
797+
798+ import pytensor.tensor as pt
799+
800+ a = pt.arange(4).reshape((2, 2))
801+ out = pt.repeat(a, repeats=[2, 3], axis=0)
802+ print(out.eval())
803+
804+ .. testoutput::
805+
806+ [[0 1]
807+ [0 1]
808+ [2 3]
809+ [2 3]
810+ [2 3]]
811+
812+ When axis is None, the array is first flattened and then repeated
813+
814+ .. testcode::
815+
816+ import pytensor.tensor as pt
817+
818+ a = pt.arange(4).reshape((2, 2))
819+ out = pt.repeat(a, repeats=[2, 3, 0, 1], axis=None)
820+ print(out.eval())
821+
822+ .. testoutput::
823+
824+ [0 0 1 1 1 3]
825+
781826
782827 .. versionadded:: 0.6
783828
784829 """
830+ a = ptb .as_tensor_variable (a )
831+
832+ if axis is not None :
833+ axis = normalize_axis_index (axis , a .ndim )
834+
785835 repeats = ptb .as_tensor_variable (repeats , dtype = np .int64 )
786836
787837 if repeats .ndim > 1 :
788838 raise ValueError ("The dimension of repeats should not exceed 1." )
789839
790840 if repeats .ndim == 1 and not repeats .broadcastable [0 ]:
791- return Repeat (axis = axis )(x , repeats )
841+ # We only use the Repeat Op for vector repeats
842+ return Repeat (axis = axis )(a , repeats )
792843 else :
793844 if repeats .ndim == 1 :
794845 repeats = repeats [0 ]
795846
796- if x .dtype == "uint64" :
847+ if a .dtype == "uint64" :
848+ # Multiplying int64 (shape) by uint64 (repeats) yields a float64
849+ # Which is not valid for the `reshape` operation at the end
797850 raise TypeError ("repeat doesn't support dtype uint64" )
798851
799852 if axis is None :
800853 axis = 0
801- x = x .flatten ()
802- else :
803- if axis >= x .ndim :
804- raise ValueError ("Axis should not exceed x.ndim-1." )
805- if axis < 0 :
806- axis = x .ndim + axis
854+ a = a .flatten ()
807855
808- shape = [ x . shape [ i ] for i in range ( x . ndim )]
856+ repeat_shape = list ( a . shape )
809857
810- # shape_ is the shape of the intermediate tensor which has
858+ # alloc_shape is the shape of the intermediate tensor which has
811859 # an additional dimension comparing to x. We use alloc to
812860 # allocate space for this intermediate tensor to replicate x
813861 # along that additional dimension.
814- shape_ = shape [:]
815- shape_ .insert (axis + 1 , repeats )
862+ alloc_shape = repeat_shape [:]
863+ alloc_shape .insert (axis + 1 , repeats )
816864
817- # shape is now the shape of output, where shape[axis] becomes
865+ # repeat_shape is now the shape of output, where shape[axis] becomes
818866 # shape[axis]*repeats.
819- shape [axis ] = shape [axis ] * repeats
820-
821- # dims_ is the dimension of that intermediate tensor.
822- dims_ = list (np .arange (x .ndim ))
823- dims_ .insert (axis + 1 , "x" )
867+ repeat_shape [axis ] = repeat_shape [axis ] * repeats
824868
825869 # After the original tensor is duplicated along the additional
826- # dimension, we reshape it to the expected output shape, and
827- # return the output z.
828- z = ptb . alloc ( x . dimshuffle ( * dims_ ), * shape_ ). reshape ( shape )
829- return z
870+ # dimension, we reshape it to the expected output shape
871+ return ptb . alloc ( ptb . expand_dims ( a , axis + 1 ), * alloc_shape ). reshape (
872+ repeat_shape
873+ )
830874
831875
832876class Bartlett (Op ):
0 commit comments