Skip to content

Commit 9f7d6dc

Browse files
Copilotjustinchuby
andcommitted
Fix code formatting and pass all linters
Co-authored-by: justinchuby <[email protected]>
1 parent 3ee33f5 commit 9f7d6dc

File tree

1 file changed

+70
-50
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+70
-50
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 70 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -7285,138 +7285,158 @@ def aten_repeat_interleave(
72857285
repeats: TensorType, output_size: Optional[int] = None
72867286
) -> TensorType:
72877287
"""repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor"""
7288-
7288+
72897289
# Convert repeats to int64 for ONNX compatibility
72907290
repeats_int64 = op.Cast(repeats, to=INT64.dtype)
7291-
7292-
# Create indices [0, 1, 2, ..., len(repeats)-1]
7293-
num_elements = op.Shape(repeats_int64, start=0, end=1)
7294-
indices = op.Range(op.Constant(value_ints=[0]), num_elements, op.Constant(value_ints=[1]))
7295-
7291+
72967292
# Get cumulative sum of repeats to find the boundaries
72977293
cumsum = op.CumSum(repeats_int64, axis=0)
72987294
total_size = op.Gather(cumsum, op.Constant(value_ints=[-1]), axis=0)
7299-
7295+
73007296
# Create output tensor indices
7301-
output_range = op.Range(op.Constant(value_ints=[0]), total_size, op.Constant(value_ints=[1]))
7302-
7297+
output_range = op.Range(
7298+
op.Constant(value_ints=[0]), total_size, op.Constant(value_ints=[1])
7299+
)
7300+
73037301
# Find which original index each output position corresponds to
73047302
# We need to find the first cumsum position > each output position
73057303
# This is equivalent to a searchsorted operation
7306-
7304+
73077305
# Expand dimensions for broadcasting
73087306
cumsum_expanded = op.Unsqueeze(cumsum, [0]) # Shape: [1, len(repeats)]
73097307
output_range_expanded = op.Unsqueeze(output_range, [1]) # Shape: [total_size, 1]
7310-
7311-
# Find positions where output_range < cumsum
7308+
7309+
# Find positions where output_range < cumsum
73127310
mask = op.Less(output_range_expanded, cumsum_expanded) # Shape: [total_size, len(repeats)]
7313-
7311+
73147312
# For each row, find the first True position (argmax will do this since True=1, False=0)
73157313
result_indices = op.ArgMax(op.Cast(mask, to=INT64.dtype), axis=1, keepdims=False)
7316-
7314+
73177315
return result_indices
73187316

73197317

73207318
@torch_op("aten::repeat_interleave.self_Tensor", trace_only=True)
73217319
def aten_repeat_interleave_self_tensor(
7322-
self: TensorType, repeats: TensorType, dim: Optional[int] = None, output_size: Optional[int] = None
7320+
self: TensorType,
7321+
repeats: TensorType,
7322+
dim: Optional[int] = None,
7323+
output_size: Optional[int] = None,
73237324
) -> TensorType:
73247325
"""repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor"""
7325-
7326+
73267327
if dim is None:
73277328
# Flatten the tensor first, then repeat elements
73287329
self_flat = op.Reshape(self, [-1])
7329-
7330+
73307331
# Convert repeats to int64 for ONNX compatibility
73317332
repeats_int64 = op.Cast(repeats, to=INT64.dtype)
7332-
7333+
73337334
# Get cumulative sum of repeats to find the boundaries
73347335
cumsum = op.CumSum(repeats_int64, axis=0)
73357336
total_size = op.Gather(cumsum, op.Constant(value_ints=[-1]), axis=0)
7336-
7337+
73377338
# Create output tensor indices
7338-
output_range = op.Range(op.Constant(value_ints=[0]), total_size, op.Constant(value_ints=[1]))
7339-
7339+
output_range = op.Range(
7340+
op.Constant(value_ints=[0]), total_size, op.Constant(value_ints=[1])
7341+
)
7342+
73407343
# Find which original index each output position corresponds to
73417344
cumsum_expanded = op.Unsqueeze(cumsum, [0]) # Shape: [1, len(repeats)]
73427345
output_range_expanded = op.Unsqueeze(output_range, [1]) # Shape: [total_size, 1]
7343-
7344-
# Find positions where output_range < cumsum
7345-
mask = op.Less(output_range_expanded, cumsum_expanded) # Shape: [total_size, len(repeats)]
7346-
7346+
7347+
# Find positions where output_range < cumsum
7348+
mask = op.Less(
7349+
output_range_expanded, cumsum_expanded
7350+
) # Shape: [total_size, len(repeats)]
7351+
73477352
# For each row, find the first True position
73487353
indices = op.ArgMax(op.Cast(mask, to=INT64.dtype), axis=1, keepdims=False)
7349-
7354+
73507355
# Gather elements from the flattened tensor
73517356
result = op.Gather(self_flat, indices, axis=0)
73527357
return result
7353-
7358+
73547359
else:
73557360
# Repeat along specific dimension
73567361
# Convert repeats to int64 for ONNX compatibility
73577362
repeats_int64 = op.Cast(repeats, to=INT64.dtype)
7358-
7363+
73597364
# Get cumulative sum of repeats to find the boundaries
73607365
cumsum = op.CumSum(repeats_int64, axis=0)
73617366
total_size = op.Gather(cumsum, op.Constant(value_ints=[-1]), axis=0)
7362-
7367+
73637368
# Create output tensor indices for the specified dimension
7364-
output_range = op.Range(op.Constant(value_ints=[0]), total_size, op.Constant(value_ints=[1]))
7365-
7369+
output_range = op.Range(
7370+
op.Constant(value_ints=[0]), total_size, op.Constant(value_ints=[1])
7371+
)
7372+
73667373
# Find which original index each output position corresponds to
73677374
cumsum_expanded = op.Unsqueeze(cumsum, [0]) # Shape: [1, len(repeats)]
73687375
output_range_expanded = op.Unsqueeze(output_range, [1]) # Shape: [total_size, 1]
7369-
7370-
# Find positions where output_range < cumsum
7371-
mask = op.Less(output_range_expanded, cumsum_expanded) # Shape: [total_size, len(repeats)]
7372-
7376+
7377+
# Find positions where output_range < cumsum
7378+
mask = op.Less(
7379+
output_range_expanded, cumsum_expanded
7380+
) # Shape: [total_size, len(repeats)]
7381+
73737382
# For each row, find the first True position
73747383
indices = op.ArgMax(op.Cast(mask, to=INT64.dtype), axis=1, keepdims=False)
7375-
7384+
73767385
# Gather elements along the specified dimension
73777386
result = op.Gather(self, indices, axis=dim)
73787387
return result
73797388

73807389

73817390
@torch_op("aten::repeat_interleave.self_int", trace_only=True)
73827391
def aten_repeat_interleave_self_int(
7383-
self: TensorType, repeats: int, dim: Optional[int] = None, output_size: Optional[int] = None
7392+
self: TensorType,
7393+
repeats: int,
7394+
dim: Optional[int] = None,
7395+
output_size: Optional[int] = None,
73847396
) -> TensorType:
73857397
"""repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor"""
7386-
7398+
73877399
if dim is None:
73887400
# Flatten the tensor first, then repeat each element 'repeats' times
73897401
self_flat = op.Reshape(self, [-1])
73907402
num_elements = op.Shape(self_flat, start=0, end=1)
7391-
7403+
73927404
# Create indices that repeat each original index 'repeats' times
73937405
# For input [a, b, c] with repeats=2, we want indices [0, 0, 1, 1, 2, 2]
7394-
original_indices = op.Range(op.Constant(value_ints=[0]), num_elements, op.Constant(value_ints=[1]))
7395-
7406+
original_indices = op.Range(
7407+
op.Constant(value_ints=[0]), num_elements, op.Constant(value_ints=[1])
7408+
)
7409+
73967410
# Repeat each index 'repeats' times
73977411
# We can use Tile with appropriate reshaping
73987412
indices_reshaped = op.Unsqueeze(original_indices, [1]) # Shape: [num_elements, 1]
73997413
repeat_pattern = op.Constant(value_ints=[1, repeats])
7400-
repeated_indices = op.Tile(indices_reshaped, repeat_pattern) # Shape: [num_elements, repeats]
7414+
repeated_indices = op.Tile(
7415+
indices_reshaped, repeat_pattern
7416+
) # Shape: [num_elements, repeats]
74017417
final_indices = op.Reshape(repeated_indices, [-1]) # Shape: [num_elements * repeats]
7402-
7418+
74037419
# Gather elements from the flattened tensor
74047420
result = op.Gather(self_flat, final_indices, axis=0)
74057421
return result
7406-
7422+
74077423
else:
74087424
# Repeat along specific dimension
7409-
dim_size = op.Shape(self, start=dim, end=dim+1)
7410-
7425+
dim_size = op.Shape(self, start=dim, end=dim + 1)
7426+
74117427
# Create indices that repeat each original index 'repeats' times
7412-
original_indices = op.Range(op.Constant(value_ints=[0]), dim_size, op.Constant(value_ints=[1]))
7413-
7428+
original_indices = op.Range(
7429+
op.Constant(value_ints=[0]), dim_size, op.Constant(value_ints=[1])
7430+
)
7431+
74147432
# Repeat each index 'repeats' times
74157433
indices_reshaped = op.Unsqueeze(original_indices, [1]) # Shape: [dim_size, 1]
74167434
repeat_pattern = op.Constant(value_ints=[1, repeats])
7417-
repeated_indices = op.Tile(indices_reshaped, repeat_pattern) # Shape: [dim_size, repeats]
7435+
repeated_indices = op.Tile(
7436+
indices_reshaped, repeat_pattern
7437+
) # Shape: [dim_size, repeats]
74187438
final_indices = op.Reshape(repeated_indices, [-1]) # Shape: [dim_size * repeats]
7419-
7439+
74207440
# Gather elements along the specified dimension
74217441
result = op.Gather(self, final_indices, axis=dim)
74227442
return result

0 commit comments

Comments
 (0)