@@ -660,53 +660,72 @@ class Repeat(Op):
 
     __props__ = ("axis",)
 
-    def __init__(self, axis: int | None = None):
-        if axis is not None:
-            if not isinstance(axis, int) or axis < 0:
+    def __init__(self, axis: int):
+        if isinstance(axis, int):
+            if axis < 0:
                 raise ValueError(
-                    f"Repeat only accepts positive integer axis or None, got {axis}"
+                    f"Repeat Op only accepts positive integer axis, got {axis}. "
+                    "Use the helper `pt.repeat` to handle negative axis."
                 )
+        elif axis is None:
+            raise ValueError(
+                "Repeat Op only accepts positive integer axis. "
+                "Use the helper `pt.repeat` to handle axis=None."
+            )
+        else:
+            raise TypeError(
+                f"Invalid type for axis {axis}, expected int, got {type(axis)}"
+            )
+
         self.axis = axis
 
     def make_node(self, x, repeats):
         x = ptb.as_tensor_variable(x)
         repeats = ptb.as_tensor_variable(repeats, dtype="int64")
 
-        if repeats.dtype not in integer_dtypes:
-            raise TypeError("repeats.dtype must be an integer.")
+        if repeats.type.ndim != 1:
+            if repeats.type.ndim == 0:
+                raise ValueError(
+                    f"repeats {repeats} must have 1 dimension, got 0. Use the helper `pt.repeat` to handle scalar repeats."
+                )
+            else:
+                raise ValueError(
+                    f"repeats {repeats} must have 1 dimension, got {repeats.type.ndim}"
+                )
+
+        if repeats.type.dtype not in integer_dtypes:
+            raise TypeError(
+                f"repeats {repeats} dtype must be an integer, got {repeats.type.dtype}."
+            )
 
         # Some dtypes are not supported by numpy's implementation of repeat.
         # Until another one is available, we should fail at graph construction
         # time, not wait for execution.
-        ptr_bitwidth = LOCAL_BITWIDTH
-        if ptr_bitwidth == 64:
-            numpy_unsupported_dtypes = ("uint64",)
-        if ptr_bitwidth == 32:
-            numpy_unsupported_dtypes = ("uint32", "int64", "uint64")
-
-        if repeats.dtype in numpy_unsupported_dtypes:
+        numpy_unsupported_dtypes = (
+            ("uint64",) if LOCAL_BITWIDTH == 64 else ("uint64", "uint32", "int64")
+        )
+        if repeats.type.dtype in numpy_unsupported_dtypes:
             raise TypeError(
-                (
-                    f"dtypes {numpy_unsupported_dtypes!s} are not supported by numpy.repeat "
-                    "for the 'repeats' parameter, "
-                ),
-                repeats.dtype,
+                f"repeats {repeats} dtype {repeats.type.dtype} is not supported by numpy.repeat"
             )
 
-        if self.axis is None:
-            out_shape = [None]
-        else:
+        shape = list(x.type.shape)
+        axis_input_dim_length = shape[self.axis]
+        axis_output_dim_length = None
+
+        if axis_input_dim_length is not None:
+            # If we have a static dim and constant repeats we can infer the length of the output dim
+            # Right now we only support homogeneous constant repeats
             try:
-                const_reps = ptb.get_scalar_constant_value(repeats)
+                const_reps = ptb.get_underlying_scalar_constant_value(repeats)
             except NotScalarConstantError:
-                const_reps = None
-            if const_reps == 1:
-                out_shape = x.type.shape
+                pass
             else:
-                out_shape = list(x.type.shape)
-                out_shape[self.axis] = None
+                axis_output_dim_length = int(const_reps * axis_input_dim_length)
+
+        shape[self.axis] = axis_output_dim_length
 
-        out_type = TensorType(x.dtype, shape=out_shape)
+        out_type = TensorType(x.dtype, shape=shape)
         return Apply(self, [x, repeats], [out_type()])
 
     def perform(self, node, inputs, output_storage):
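
As a side note on what the new `make_node` buys: with a static input dim and constant, homogeneous repeats, the output dim length is now known at graph construction time. A minimal sketch (assuming the usual `pytensor.tensor as pt` alias; the exact printed shapes depend on the installed version's constant handling):

```python
import pytensor.tensor as pt

x = pt.tensor("x", shape=(3, 4))

# Constant homogeneous repeats: the unique value 2 can be extracted,
# so the output dim along axis 0 is inferred as 3 * 2 = 6
y = pt.repeat(x, [2, 2, 2], axis=0)
print(y.type.shape)  # expected (6, 4)

# Symbolic repeats: the output dim length stays unknown
r = pt.tensor("r", shape=(3,), dtype="int64")
z = pt.repeat(x, r, axis=0)
print(z.type.shape)  # expected (None, 4)
```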
@@ -720,36 +739,19 @@ def grad(self, inputs, gout):
         (x, repeats) = inputs
         (gz,) = gout
         axis = self.axis
-        if repeats.ndim == 0:
-            # When axis is a scalar (same number of reps for all elements),
-            # We can split the repetitions into their own axis with reshape and sum them back
-            # to the original element location
-            sum_axis = x.ndim if axis is None else axis + 1
-            shape = list(x.shape)
-            shape.insert(sum_axis, repeats)
-            gx = gz.reshape(shape).sum(axis=sum_axis)
-
-        elif repeats.ndim == 1:
-            # To sum the gradients that belong to the same repeated x,
-            # We create a repeated eye and dot product it with the gradient.
-            axis_size = x.size if axis is None else x.shape[axis]
-            repeated_eye = repeat(
-                ptb.eye(axis_size), repeats, axis=0
-            )  # A sparse repeat would be neat
-
-            if axis is None:
-                gx = gz @ repeated_eye
-                # Undo the ravelling when axis=None
-                gx = gx.reshape(x.shape)
-            else:
-                # Place gradient axis at end for dot product
-                gx = ptb.moveaxis(gz, axis, -1)
-                gx = gx @ repeated_eye
-                # Place gradient back into the correct axis
-                gx = ptb.moveaxis(gx, -1, axis)
 
-        else:
-            raise ValueError()
+        # To sum the gradients that belong to the same repeated x,
+        # We create a repeated eye and dot product it with the gradient.
+        axis_size = x.shape[axis]
+        repeated_eye = repeat(
+            ptb.eye(axis_size), repeats, axis=0
+        )  # A sparse repeat would be neat
+
+        # Place gradient axis at end for dot product
+        gx = ptb.moveaxis(gz, axis, -1)
+        gx = gx @ repeated_eye
+        # Place gradient back into the correct axis
+        gx = ptb.moveaxis(gx, -1, axis)
 
         return [gx, DisconnectedType()()]
 
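
The repeated-eye trick used in `grad` is easy to sanity-check with plain NumPy (an illustrative sketch, not part of the diff): multiplying the upstream gradient by a row-repeated identity matrix sums the gradient entries that originate from the same input element.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
repeats = np.array([2, 1, 3])
gz = np.arange(1.0, 7.0)  # upstream gradient, one entry per repeated output

# Row i of eye(3) appears repeats[i] times, so column i of the product
# accumulates exactly the gz entries that belong to x[i]
repeated_eye = np.repeat(np.eye(x.size), repeats, axis=0)  # shape (6, 3)
gx = gz @ repeated_eye
print(gx)  # [3. 3. 15.] == [1+2, 3, 4+5+6]
```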
@@ -763,22 +765,8 @@ def infer_shape(self, fgraph, node, ins_shapes):
         dtype = None
         if repeats.dtype in ("uint8", "uint16", "uint32"):
             dtype = "int64"
-        if axis is None:
-            if repeats.ndim == 0:
-                if len(i0_shapes) == 0:
-                    out_shape = [repeats]
-                else:
-                    res = 1
-                    for d in i0_shapes:
-                        res = res * d
-                    out_shape = (res * repeats,)
-            else:
-                out_shape = [pt_sum(repeats, dtype=dtype)]
-        else:
-            if repeats.ndim == 0:
-                out_shape[axis] = out_shape[axis] * repeats
-            else:
-                out_shape[axis] = pt_sum(repeats, dtype=dtype)
+
+        out_shape[axis] = pt_sum(repeats, dtype=dtype)
         return [out_shape]
 
 
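With scalar repeats and `axis=None` now normalized away by the helper, the shape rule collapses to one case: the output length along `axis` is the sum of `repeats`. A NumPy check of that rule (illustrative only):

```python
import numpy as np

x = np.arange(12).reshape(3, 4)
repeats = np.array([1, 0, 2])

out = np.repeat(x, repeats, axis=0)
assert out.shape[0] == repeats.sum()  # 1 + 0 + 2 == 3
```
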
@@ -843,48 +831,42 @@ def repeat(
     """
     a = ptb.as_tensor_variable(a)
 
-    if axis is not None:
+    if axis is None:
+        axis = 0
+        a = a.flatten()
+    else:
         axis = normalize_axis_index(axis, a.ndim)
 
     repeats = ptb.as_tensor_variable(repeats, dtype=np.int64)
 
     if repeats.ndim > 1:
         raise ValueError("The dimension of repeats should not exceed 1.")
 
-    if repeats.ndim == 1 and not repeats.broadcastable[0]:
+    if repeats.type.broadcastable == (True,):
+        # This behaves the same as scalar repeat
+        repeats = repeats.squeeze()
+
+    if repeats.ndim == 1:
         # We only use the Repeat Op for vector repeats
         return Repeat(axis=axis)(a, repeats)
     else:
-        if repeats.ndim == 1:
-            repeats = repeats[0]
-
         if a.dtype == "uint64":
             # Multiplying int64 (shape) by uint64 (repeats) yields a float64
             # Which is not valid for the `reshape` operation at the end
             raise TypeError("repeat doesn't support dtype uint64")
 
-        if axis is None:
-            axis = 0
-            a = a.flatten()
-
-        repeat_shape = list(a.shape)
+        # Scalar repeat, we implement this with canonical Ops: broadcast + reshape
+        a_shape = a.shape
 
-        # alloc_shape is the shape of the intermediate tensor which has
-        # an additional dimension comparing to x. We use alloc to
-        # allocate space for this intermediate tensor to replicate x
-        # along that additional dimension.
-        alloc_shape = repeat_shape[:]
-        alloc_shape.insert(axis + 1, repeats)
+        # Replicate a along a new axis (axis+1) repeats times
+        broadcast_shape = list(a_shape)
+        broadcast_shape.insert(axis + 1, repeats)
+        broadcast_a = broadcast_to(ptb.expand_dims(a, axis + 1), broadcast_shape)
 
-        # repeat_shape is now the shape of output, where shape[axis] becomes
-        # shape[axis]*repeats.
+        # Reshape broadcast_a to the final shape, merging axis and axis+1
+        repeat_shape = list(a_shape)
         repeat_shape[axis] = repeat_shape[axis] * repeats
-
-        # After the original tensor is duplicated along the additional
-        # dimension, we reshape it to the expected output shape
-        return ptb.alloc(ptb.expand_dims(a, axis + 1), *alloc_shape).reshape(
-            repeat_shape
-        )
+        return broadcast_a.reshape(repeat_shape)
 
 
 class Bartlett(Op):
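
The rewritten scalar path is the standard expand-broadcast-reshape idiom. The same construction in NumPy (a sketch mirroring the symbolic code above):

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
repeats, axis = 3, 1

# Insert a length-1 axis after `axis` and broadcast it to `repeats`
broadcast_shape = list(a.shape)
broadcast_shape.insert(axis + 1, repeats)  # [2, 2, 3]
broadcast_a = np.broadcast_to(np.expand_dims(a, axis + 1), broadcast_shape)

# Merge axis and axis+1 into one axis of length shape[axis] * repeats
repeat_shape = list(a.shape)
repeat_shape[axis] *= repeats  # [2, 6]
out = broadcast_a.reshape(repeat_shape)

assert np.array_equal(out, np.repeat(a, repeats, axis=axis))
```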