Skip to content

Commit 0c0863a

Browse files
authored
Merge branch 'main' into copilot/make-files-internal-private
2 parents 344955b + 20a99d1 commit 0c0863a

File tree

7 files changed

+140
-81
lines changed

7 files changed

+140
-81
lines changed

onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
def torch_lib_onnx_functions_from_registry() -> Generator[onnxscript.OnnxFunction, None, None]:
2323
for op in registration.default_registry.values():
24-
for func in (*op.overloads, *op.privates, *op.complex):
24+
for func in (*op.overloads, *op.complex):
2525
if isinstance(func, onnxscript.OnnxFunction):
2626
yield func
2727

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 83 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3151,6 +3151,7 @@ def aten_embedding_bag(
31513151
sparse: bool = False,
31523152
per_sample_weights: Optional[TFloat] = None,
31533153
include_last_offset: bool = False,
3154+
padding_idx: Optional[int] = None,
31543155
) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
31553156
"""embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)"""
31563157

@@ -3247,23 +3248,24 @@ def _aten_embedding_bag_onnx(
32473248

32483249
# Only compute the shapes of the other 3 outputs; their values do not matter
32493250
if mode == 0: # sum
3250-
offset2bag = op.Shape(indices, start=0, end=0) # Generate empty tensor
3251+
offset2bag = op.Cast(op.Shape(indices, start=0, end=0), to=INT64.dtype)
32513252
if op.Equal(include_last_offset, True):
3252-
bag_size = op.Expand(0, op.Shape(offsets))
3253+
bag_size = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype)
3254+
max_indices = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype)
32533255
else:
3254-
bag_size = op.Expand(0, op.Shape(offsets) - 1)
3255-
max_indices = op.Expand(0, op.Shape(bag_size))
3256+
bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
3257+
max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
32563258
elif mode == 1: # mean
3257-
offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
3258-
bag_size = op.Expand(0, op.Shape(offsets) - 1)
3259-
max_indices = op.Expand(0, op.Shape(bag_size))
3259+
offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype)
3260+
bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
3261+
max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
32603262
else: # max
3261-
offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
3262-
bag_size = op.Expand(0, op.Shape(offsets) - 1)
3263+
offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype)
3264+
bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
32633265
# shape = (bag_size.dim[0], weight.dim[1])
32643266
dim_0 = op.Shape(bag_size, start=0, end=1)
32653267
dim_1 = op.Shape(weight, start=1, end=2)
3266-
max_indices = op.Expand(0, op.Concat(dim_0, dim_1, axis=0))
3268+
max_indices = op.Cast(op.Expand(0, op.Concat(dim_0, dim_1, axis=0)), to=INT64.dtype)
32673269

32683270
return result, offset2bag, bag_size, max_indices
32693271

@@ -3285,27 +3287,40 @@ def aten_embedding_bag_padding_idx(
32853287
sparse: bool = False,
32863288
per_sample_weights: Optional[TFloat] = None,
32873289
include_last_offset: bool = False,
3288-
padding_idx: int = -1,
3290+
padding_idx: Optional[int] = None,
32893291
) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
32903292
"""embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)
32913293
32923294
We add default values for the attributes to accommodate _embedding_bag as well:
32933295
_embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1)
32943296
"""
3295-
assert padding_idx is not None, (
3296-
"padding_idx must not be None. This is likely a dispatcher error"
3297-
)
32983297

32993298
if per_sample_weights is None:
33003299
per_sample_weights = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(indices))
33013300
per_sample_weights = op.CastLike(per_sample_weights, weight)
33023301

3303-
# Change padding_idx to positive value, -1 means the last index
3304-
if padding_idx < 0:
3305-
padding_idx = weight.shape[0] + padding_idx
3302+
if padding_idx is not None:
3303+
# Call the existing function for handling padding_idx
3304+
result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx(
3305+
weight,
3306+
indices,
3307+
offsets,
3308+
mode,
3309+
per_sample_weights,
3310+
include_last_offset,
3311+
padding_idx,
3312+
)
33063313

3307-
result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx(
3308-
weight, indices, offsets, mode, per_sample_weights, include_last_offset, padding_idx
3314+
return result, offset2bag, bag_size, max_indices
3315+
3316+
# When padding_idx is None, use the standard embedding_bag implementation
3317+
result, offset2bag, bag_size, max_indices = _aten_embedding_bag_onnx(
3318+
weight,
3319+
indices,
3320+
offsets,
3321+
mode,
3322+
per_sample_weights,
3323+
include_last_offset,
33093324
)
33103325

33113326
return result, offset2bag, bag_size, max_indices
@@ -3322,6 +3337,12 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
33223337
padding_idx: int,
33233338
) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
33243339
neg_1 = op.Constant(value_ints=[-1])
3340+
3341+
num_embeddings = op.Shape(weight, start=0, end=1) # Get number of rows in weight
3342+
num_embeddings_scalar = op.Squeeze(num_embeddings)
3343+
if padding_idx < 0:
3344+
padding_idx = padding_idx + num_embeddings_scalar
3345+
33253346
# Get weight out according to indices,
33263347
# e.g. indices=[3,1,4,5,3] means get weight[[3,1,4,5,3]]
33273348
indices_weight = op.Gather(weight, indices)
@@ -3357,7 +3378,10 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
33573378
cond_2 = j < end_pos
33583379
while cond_2:
33593380
index = op.Gather(indices, j)
3360-
if not op.Equal(index, padding_idx):
3381+
normalized_index = index
3382+
if index < 0:
3383+
normalized_index = index + num_embeddings_scalar
3384+
if not op.Equal(normalized_index, padding_idx):
33613385
# Something like the 'append' operation
33623386
curr_offsets = op.Concat(curr_offsets, op.Reshape(j, neg_1), axis=0)
33633387
j = j + 1
@@ -3386,23 +3410,24 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
33863410
result = op.CastLike(result, weight)
33873411

33883412
if mode == 0: # sum
3389-
offset2bag = op.Expand(0, op.Shape(indices))
3413+
offset2bag = op.Cast(op.Expand(0, op.Shape(indices)), to=INT64.dtype)
33903414
if op.Equal(include_last_offset, True):
3391-
bag_size = op.Expand(0, op.Shape(offsets))
3415+
bag_size = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype)
3416+
max_indices = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype)
33923417
else:
3393-
bag_size = op.Expand(0, op.Shape(offsets) - 1)
3394-
max_indices = op.Expand(0, op.Shape(bag_size))
3418+
bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
3419+
max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
33953420
elif mode == 1: # mean
3396-
offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
3397-
bag_size = op.Expand(0, op.Shape(offsets) - 1)
3398-
max_indices = op.Expand(0, op.Shape(bag_size))
3421+
offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype)
3422+
bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
3423+
max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
33993424
else: # mode == 2, max
3400-
offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
3401-
bag_size = op.Expand(0, op.Shape(offsets) - 1)
3425+
offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype)
3426+
bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
34023427
# shape = (bag_size.dim[0], weight.dim[1])
34033428
dim_0 = op.Shape(bag_size, start=0, end=1)
34043429
dim_1 = op.Shape(weight, start=1, end=2)
3405-
max_indices = op.Expand(0, op.Concat(dim_0, dim_1, axis=0))
3430+
max_indices = op.Cast(op.Expand(0, op.Concat(dim_0, dim_1, axis=0)), to=INT64.dtype)
34063431

34073432
return result, offset2bag, bag_size, max_indices
34083433

@@ -4382,7 +4407,6 @@ def aten_grid_sampler(
43824407
padding_mode_options = ("zeros", "border", "reflection")
43834408
padding_mode_str = padding_mode_options[padding_mode]
43844409

4385-
# Only one onnx Op so don't put into private function
43864410
return op.GridSample(
43874411
input,
43884412
grid,
@@ -4408,7 +4432,6 @@ def aten_grid_sampler_2d(
44084432
padding_mode_options = ("zeros", "border", "reflection")
44094433
padding_mode_str = padding_mode_options[padding_mode]
44104434

4411-
# Only one onnx Op so don't put into private function
44124435
return op.GridSample(
44134436
input,
44144437
grid,
@@ -4698,7 +4721,7 @@ def _aten_index_onnx(
46984721
if _has_none_in_middle(indices):
46994722
# If there is None in the middle, Advanced Indexing cannot decide where to put
47004723
# the new dimensions. So it places them in the front, like GatherND does.
4701-
return op.Identity(self)
4724+
return self
47024725

47034726
# When the indices are consecutive, Advanced Indexing will place the new dimensions
47044727
# (aka. the broadcasted shape) in the middle, replacing the original [x1, ..., xk] axes.
@@ -4744,7 +4767,9 @@ def _aten_index_onnx(
47444767

47454768

47464769
@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
4747-
def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorType:
4770+
def aten_index(
4771+
self: TensorType, indices: Sequence[Optional[Union[INT64, BOOL]]]
4772+
) -> TensorType:
47484773
"""index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
47494774
47504775
NOTE: Understanding `aten::index`
@@ -4764,17 +4789,19 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorTy
47644789
47654790
None in `indices` are like fillers for dimensions that cannot be removed in the process.
47664791
"""
4792+
# Handle Boolean indexing first
4793+
if any(index is not None and index.dtype == ir.DataType.BOOL for index in indices):
4794+
return _aten_index_bool(self, indices)
47674795

47684796
index_ranks = [len(index.shape) for index in indices if index is not None]
47694797

47704798
return _aten_index_onnx(self, indices, index_ranks)
47714799

47724800

4773-
@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
4774-
def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType: # pylint: disable=inconsistent-return-statements
4801+
def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType:
47754802
index_ranks = [len(index.shape) for index in indices if index is not None]
47764803

4777-
if index_ranks[0] == 1:
4804+
if all(rank == 1 for rank in index_ranks):
47784805
# Every non-None index is 1-D (rank 1).
47794806
new_indices = [
47804807
op.Transpose(op.NonZero(index), perm=[1, 0]) if index is not None else None
@@ -4784,6 +4811,7 @@ def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> Tens
47844811
op.Squeeze(index, axes=[1]) if index is not None else None for index in new_indices
47854812
]
47864813
return _aten_index_onnx(self, new_indices, index_ranks)
4814+
47874815
else:
47884816
input_rank = len(self.shape)
47894817
# Prepare perm for transposing self tensor.
@@ -4800,15 +4828,19 @@ def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> Tens
48004828
if index is None:
48014829
self = op.Transpose(self, perm=trans_perm)
48024830
count_of_none += 1
4803-
else:
4804-
new_indices = op.Transpose(op.NonZero(index), perm=[1, 0])
4805-
result = op.GatherND(self, new_indices, batch_dims=0)
4806-
finla_rank = input_rank - (len(index.shape) - 1)
4807-
trans_perm = list(range(finla_rank))
4808-
trans_perm = trans_perm[-1:] + trans_perm[:-1]
4809-
for _ in range(count_of_none):
4810-
result = op.Transpose(result, perm=trans_perm)
4811-
return result
4831+
continue
4832+
4833+
new_indices = op.Transpose(op.NonZero(index), perm=[1, 0])
4834+
result = op.GatherND(self, new_indices, batch_dims=0)
4835+
final_rank = input_rank - (len(index.shape) - 1)
4836+
trans_perm = list(range(final_rank))
4837+
trans_perm = trans_perm[-1:] + trans_perm[:-1]
4838+
for _ in range(count_of_none):
4839+
result = op.Transpose(result, perm=trans_perm)
4840+
# FIXME(justinchuby): Even though this logic passes the tests, it still looks strange:
4841+
# why does it return early here instead of continuing to process the remaining indices?
4842+
# I think the assumption here is that there can be only one Boolean index in the indices list?
4843+
return result
48124844

48134845

48144846
def aten_index_add(
@@ -4830,7 +4862,7 @@ def aten_index_copy(
48304862
@torch_op(("aten::index_put", "aten::_unsafe_index_put"), trace_only=True)
48314863
def aten_index_put(
48324864
self: TReal,
4833-
indices: Sequence[INT64],
4865+
indices: Sequence[Optional[Union[INT64, BOOL]]],
48344866
values: TReal,
48354867
accumulate: bool = False,
48364868
) -> TReal:
@@ -4839,6 +4871,9 @@ def aten_index_put(
48394871
See implementation of `torch.onnx.symbolic_opset11.index_put
48404872
<https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
48414873
"""
4874+
if any(index is not None and index.dtype == BOOL.dtype for index in indices):
4875+
return _aten_index_put_bool(self, indices, values, accumulate)
4876+
48424877
# Ensure the number of indices matches the tensor rank by appending trailing Nones.
48434878
self_rank = len(self.shape)
48444879
if len(indices) < self_rank:
@@ -4971,8 +5006,7 @@ def same_shape(other_shape: ir.Shape) -> bool:
49715006
return result
49725007

49735008

4974-
@torch_op("aten::index_put", trace_only=True)
4975-
def aten_index_put_bool(
5009+
def _aten_index_put_bool(
49765010
self: TReal,
49775011
indices: Sequence[BOOL],
49785012
values: TReal,

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,6 @@ def aten_col2im(
328328
else: # assert len(padding) == 4, already [w, x, y, z]
329329
pads = padding
330330

331-
# Only one ONNX op here so didn't write a private function
332331
return op.Col2Im(
333332
self,
334333
output_size,

onnxscript/function_libs/torch_lib/registration.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import re
8+
import warnings
89
from typing import Any, Callable, Generator, Optional
910

1011
import onnxscript
@@ -22,14 +23,12 @@ class OverloadedFunction:
2223
Attributes:
2324
name: Name of the op. E.g. "aten::add".
2425
overloads: Overloads function.
25-
privates: Private functions not exposed to users.
2626
complex: Support complex functions.
2727
"""
2828

2929
def __init__(self, name: str):
3030
self.name = name
3131
self.overloads: list[Any] = []
32-
self.privates: list[Any] = []
3332
self.complex: list[Any] = []
3433

3534

@@ -39,17 +38,26 @@ class Registry:
3938
def __init__(self):
4039
self._registry: dict[str, OverloadedFunction] = {}
4140

42-
def register(
43-
self, func: Any, name: str, *, private: bool = False, complex: bool = False
44-
) -> None:
41+
def register(self, func: Any, name: str, *, complex: bool = False) -> None:
4542
"""Register a function."""
46-
47-
if private:
48-
self._registry.setdefault(name, OverloadedFunction(name)).privates.append(func)
49-
elif complex:
50-
self._registry.setdefault(name, OverloadedFunction(name)).complex.append(func)
43+
overloaded_function = self._registry.setdefault(name, OverloadedFunction(name))
44+
45+
if complex:
46+
if overloaded_function.complex:
47+
warnings.warn(
48+
f"Complex overload for '{name}' already registered: {overloaded_function.complex}.",
49+
stacklevel=3,
50+
)
51+
return
52+
overloaded_function.complex.append(func)
5153
else:
52-
self._registry.setdefault(name, OverloadedFunction(name)).overloads.append(func)
54+
if overloaded_function.overloads:
55+
warnings.warn(
56+
f"Real overload for '{name}' already registered: {overloaded_function.overloads}.",
57+
stacklevel=3,
58+
)
59+
return
60+
overloaded_function.overloads.append(func)
5361

5462
def __getitem__(self, name):
5563
return self._registry[name]
@@ -131,7 +139,10 @@ def wrapper(
131139

132140
assert registry is not None
133141
for name_ in _check_and_normalize_names(name):
134-
registry.register(processed_func, name_, private=private, complex=complex)
142+
if private:
143+
# TODO: Remove the private tag once all functions are no longer private.
144+
continue
145+
registry.register(processed_func, name_, complex=complex)
135146
return processed_func
136147

137148
return wrapper

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ onnx = ["py.typed"]
4343

4444
[tool.pytest.ini_options]
4545
addopts = "-rsfEX --tb=short --color=yes"
46+
norecursedirs = [
47+
# Skip test collection because pytest will try to import the modules twice,
48+
# causing the torchlib registry to complain that functions are redefined.
49+
"onnxscript/function_libs/torch_lib/ops",
50+
]
4651

4752
[tool.mypy]
4853
# TODO disallow_incomplete_defs = true
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
onnx-weekly==1.21.0.dev20251103
1+
onnx-weekly==1.21.0.dev20251215

0 commit comments

Comments
 (0)