@@ -7285,138 +7285,158 @@ def aten_repeat_interleave(
72857285 repeats : TensorType , output_size : Optional [int ] = None
72867286) -> TensorType :
72877287 """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor"""
7288-
7288+
72897289 # Convert repeats to int64 for ONNX compatibility
72907290 repeats_int64 = op .Cast (repeats , to = INT64 .dtype )
7291-
7292- # Create indices [0, 1, 2, ..., len(repeats)-1]
7293- num_elements = op .Shape (repeats_int64 , start = 0 , end = 1 )
7294- indices = op .Range (op .Constant (value_ints = [0 ]), num_elements , op .Constant (value_ints = [1 ]))
7295-
7291+
72967292 # Get cumulative sum of repeats to find the boundaries
72977293 cumsum = op .CumSum (repeats_int64 , axis = 0 )
72987294 total_size = op .Gather (cumsum , op .Constant (value_ints = [- 1 ]), axis = 0 )
7299-
7295+
73007296 # Create output tensor indices
7301- output_range = op .Range (op .Constant (value_ints = [0 ]), total_size , op .Constant (value_ints = [1 ]))
7302-
7297+ output_range = op .Range (
7298+ op .Constant (value_ints = [0 ]), total_size , op .Constant (value_ints = [1 ])
7299+ )
7300+
73037301 # Find which original index each output position corresponds to
73047302 # We need to find the first cumsum position > each output position
73057303 # This is equivalent to a searchsorted operation
7306-
7304+
73077305 # Expand dimensions for broadcasting
73087306 cumsum_expanded = op .Unsqueeze (cumsum , [0 ]) # Shape: [1, len(repeats)]
73097307 output_range_expanded = op .Unsqueeze (output_range , [1 ]) # Shape: [total_size, 1]
7310-
7311- # Find positions where output_range < cumsum
7308+
7309+ # Find positions where output_range < cumsum
73127310 mask = op .Less (output_range_expanded , cumsum_expanded ) # Shape: [total_size, len(repeats)]
7313-
7311+
73147312 # For each row, find the first True position (argmax will do this since True=1, False=0)
73157313 result_indices = op .ArgMax (op .Cast (mask , to = INT64 .dtype ), axis = 1 , keepdims = False )
7316-
7314+
73177315 return result_indices
73187316
73197317
73207318@torch_op ("aten::repeat_interleave.self_Tensor" , trace_only = True )
73217319def aten_repeat_interleave_self_tensor (
7322- self : TensorType , repeats : TensorType , dim : Optional [int ] = None , output_size : Optional [int ] = None
7320+ self : TensorType ,
7321+ repeats : TensorType ,
7322+ dim : Optional [int ] = None ,
7323+ output_size : Optional [int ] = None ,
73237324) -> TensorType :
73247325 """repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor"""
7325-
7326+
73267327 if dim is None :
73277328 # Flatten the tensor first, then repeat elements
73287329 self_flat = op .Reshape (self , [- 1 ])
7329-
7330+
73307331 # Convert repeats to int64 for ONNX compatibility
73317332 repeats_int64 = op .Cast (repeats , to = INT64 .dtype )
7332-
7333+
73337334 # Get cumulative sum of repeats to find the boundaries
73347335 cumsum = op .CumSum (repeats_int64 , axis = 0 )
73357336 total_size = op .Gather (cumsum , op .Constant (value_ints = [- 1 ]), axis = 0 )
7336-
7337+
73377338 # Create output tensor indices
7338- output_range = op .Range (op .Constant (value_ints = [0 ]), total_size , op .Constant (value_ints = [1 ]))
7339-
7339+ output_range = op .Range (
7340+ op .Constant (value_ints = [0 ]), total_size , op .Constant (value_ints = [1 ])
7341+ )
7342+
73407343 # Find which original index each output position corresponds to
73417344 cumsum_expanded = op .Unsqueeze (cumsum , [0 ]) # Shape: [1, len(repeats)]
73427345 output_range_expanded = op .Unsqueeze (output_range , [1 ]) # Shape: [total_size, 1]
7343-
7344- # Find positions where output_range < cumsum
7345- mask = op .Less (output_range_expanded , cumsum_expanded ) # Shape: [total_size, len(repeats)]
7346-
7346+
7347+ # Find positions where output_range < cumsum
7348+ mask = op .Less (
7349+ output_range_expanded , cumsum_expanded
7350+ ) # Shape: [total_size, len(repeats)]
7351+
73477352 # For each row, find the first True position
73487353 indices = op .ArgMax (op .Cast (mask , to = INT64 .dtype ), axis = 1 , keepdims = False )
7349-
7354+
73507355 # Gather elements from the flattened tensor
73517356 result = op .Gather (self_flat , indices , axis = 0 )
73527357 return result
7353-
7358+
73547359 else :
73557360 # Repeat along specific dimension
73567361 # Convert repeats to int64 for ONNX compatibility
73577362 repeats_int64 = op .Cast (repeats , to = INT64 .dtype )
7358-
7363+
73597364 # Get cumulative sum of repeats to find the boundaries
73607365 cumsum = op .CumSum (repeats_int64 , axis = 0 )
73617366 total_size = op .Gather (cumsum , op .Constant (value_ints = [- 1 ]), axis = 0 )
7362-
7367+
73637368 # Create output tensor indices for the specified dimension
7364- output_range = op .Range (op .Constant (value_ints = [0 ]), total_size , op .Constant (value_ints = [1 ]))
7365-
7369+ output_range = op .Range (
7370+ op .Constant (value_ints = [0 ]), total_size , op .Constant (value_ints = [1 ])
7371+ )
7372+
73667373 # Find which original index each output position corresponds to
73677374 cumsum_expanded = op .Unsqueeze (cumsum , [0 ]) # Shape: [1, len(repeats)]
73687375 output_range_expanded = op .Unsqueeze (output_range , [1 ]) # Shape: [total_size, 1]
7369-
7370- # Find positions where output_range < cumsum
7371- mask = op .Less (output_range_expanded , cumsum_expanded ) # Shape: [total_size, len(repeats)]
7372-
7376+
7377+ # Find positions where output_range < cumsum
7378+ mask = op .Less (
7379+ output_range_expanded , cumsum_expanded
7380+ ) # Shape: [total_size, len(repeats)]
7381+
73737382 # For each row, find the first True position
73747383 indices = op .ArgMax (op .Cast (mask , to = INT64 .dtype ), axis = 1 , keepdims = False )
7375-
7384+
73767385 # Gather elements along the specified dimension
73777386 result = op .Gather (self , indices , axis = dim )
73787387 return result
73797388
73807389
73817390@torch_op ("aten::repeat_interleave.self_int" , trace_only = True )
73827391def aten_repeat_interleave_self_int (
7383- self : TensorType , repeats : int , dim : Optional [int ] = None , output_size : Optional [int ] = None
7392+ self : TensorType ,
7393+ repeats : int ,
7394+ dim : Optional [int ] = None ,
7395+ output_size : Optional [int ] = None ,
73847396) -> TensorType :
73857397 """repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor"""
7386-
7398+
73877399 if dim is None :
73887400 # Flatten the tensor first, then repeat each element 'repeats' times
73897401 self_flat = op .Reshape (self , [- 1 ])
73907402 num_elements = op .Shape (self_flat , start = 0 , end = 1 )
7391-
7403+
73927404 # Create indices that repeat each original index 'repeats' times
73937405 # For input [a, b, c] with repeats=2, we want indices [0, 0, 1, 1, 2, 2]
7394- original_indices = op .Range (op .Constant (value_ints = [0 ]), num_elements , op .Constant (value_ints = [1 ]))
7395-
7406+ original_indices = op .Range (
7407+ op .Constant (value_ints = [0 ]), num_elements , op .Constant (value_ints = [1 ])
7408+ )
7409+
73967410 # Repeat each index 'repeats' times
73977411 # We can use Tile with appropriate reshaping
73987412 indices_reshaped = op .Unsqueeze (original_indices , [1 ]) # Shape: [num_elements, 1]
73997413 repeat_pattern = op .Constant (value_ints = [1 , repeats ])
7400- repeated_indices = op .Tile (indices_reshaped , repeat_pattern ) # Shape: [num_elements, repeats]
7414+ repeated_indices = op .Tile (
7415+ indices_reshaped , repeat_pattern
7416+ ) # Shape: [num_elements, repeats]
74017417 final_indices = op .Reshape (repeated_indices , [- 1 ]) # Shape: [num_elements * repeats]
7402-
7418+
74037419 # Gather elements from the flattened tensor
74047420 result = op .Gather (self_flat , final_indices , axis = 0 )
74057421 return result
7406-
7422+
74077423 else :
74087424 # Repeat along specific dimension
7409- dim_size = op .Shape (self , start = dim , end = dim + 1 )
7410-
7425+ dim_size = op .Shape (self , start = dim , end = dim + 1 )
7426+
74117427 # Create indices that repeat each original index 'repeats' times
7412- original_indices = op .Range (op .Constant (value_ints = [0 ]), dim_size , op .Constant (value_ints = [1 ]))
7413-
7428+ original_indices = op .Range (
7429+ op .Constant (value_ints = [0 ]), dim_size , op .Constant (value_ints = [1 ])
7430+ )
7431+
74147432 # Repeat each index 'repeats' times
74157433 indices_reshaped = op .Unsqueeze (original_indices , [1 ]) # Shape: [dim_size, 1]
74167434 repeat_pattern = op .Constant (value_ints = [1 , repeats ])
7417- repeated_indices = op .Tile (indices_reshaped , repeat_pattern ) # Shape: [dim_size, repeats]
7435+ repeated_indices = op .Tile (
7436+ indices_reshaped , repeat_pattern
7437+ ) # Shape: [dim_size, repeats]
74187438 final_indices = op .Reshape (repeated_indices , [- 1 ]) # Shape: [dim_size * repeats]
7419-
7439+
74207440 # Gather elements along the specified dimension
74217441 result = op .Gather (self , final_indices , axis = dim )
74227442 return result
0 commit comments