Skip to content

Commit 5ccf3ab

Browse files
authored
[torchlib] Consolidate all overloads and prevent new ones from being created (#2621)
This PR implements #2580 by combining all overloads in torchlib and removing the ability to register new ones. It is done in a backward-compatible (BC) fashion and should work with released versions of PyTorch. From now on, all logic for a single aten OpOverload should be implemented by a single torchlib function to ensure a 1-to-1 mapping. --------- Signed-off-by: Justin Chu <[email protected]>
1 parent 726be2b commit 5ccf3ab

File tree

6 files changed

+59
-49
lines changed

6 files changed

+59
-49
lines changed

onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
def torch_lib_onnx_functions_from_registry() -> Generator[onnxscript.OnnxFunction, None, None]:
2323
for op in registration.default_registry.values():
24-
for func in (*op.overloads, *op.privates, *op.complex):
24+
for func in (*op.overloads, *op.complex):
2525
if isinstance(func, onnxscript.OnnxFunction):
2626
yield func
2727

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4382,7 +4382,6 @@ def aten_grid_sampler(
43824382
padding_mode_options = ("zeros", "border", "reflection")
43834383
padding_mode_str = padding_mode_options[padding_mode]
43844384

4385-
# Only one onnx Op so don't put into private function
43864385
return op.GridSample(
43874386
input,
43884387
grid,
@@ -4408,7 +4407,6 @@ def aten_grid_sampler_2d(
44084407
padding_mode_options = ("zeros", "border", "reflection")
44094408
padding_mode_str = padding_mode_options[padding_mode]
44104409

4411-
# Only one onnx Op so don't put into private function
44124410
return op.GridSample(
44134411
input,
44144412
grid,
@@ -4698,7 +4696,7 @@ def _aten_index_onnx(
46984696
if _has_none_in_middle(indices):
46994697
# If there is None in the middle, Advanced Indexing cannot decide where to put
47004698
# the new dimensions. So it places them in the front, like GatherND does.
4701-
return op.Identity(self)
4699+
return self
47024700

47034701
# When the indices are consecutive, Advanced Indexing will place the new dimensions
47044702
# (aka. the broadcasted shape) in the middle, replacing the original [x1, ..., xk] axes.
@@ -4744,7 +4742,9 @@ def _aten_index_onnx(
47444742

47454743

47464744
@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
4747-
def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorType:
4745+
def aten_index(
4746+
self: TensorType, indices: Sequence[Optional[Union[INT64, BOOL]]]
4747+
) -> TensorType:
47484748
"""index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
47494749
47504750
NOTE: Understanding `aten::index`
@@ -4764,17 +4764,19 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorTy
47644764
47654765
None in `indices` are like fillers for dimensions that cannot be removed in the process.
47664766
"""
4767+
# Handle Boolean indexing first
4768+
if any(index is not None and index.dtype == ir.DataType.BOOL for index in indices):
4769+
return _aten_index_bool(self, indices)
47674770

47684771
index_ranks = [len(index.shape) for index in indices if index is not None]
47694772

47704773
return _aten_index_onnx(self, indices, index_ranks)
47714774

47724775

4773-
@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
4774-
def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType: # pylint: disable=inconsistent-return-statements
4776+
def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType:
47754777
index_ranks = [len(index.shape) for index in indices if index is not None]
47764778

4777-
if index_ranks[0] == 1:
4779+
if all(rank == 1 for rank in index_ranks):
47784780
# indices contains scalar only.
47794781
new_indices = [
47804782
op.Transpose(op.NonZero(index), perm=[1, 0]) if index is not None else None
@@ -4784,6 +4786,7 @@ def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> Tens
47844786
op.Squeeze(index, axes=[1]) if index is not None else None for index in new_indices
47854787
]
47864788
return _aten_index_onnx(self, new_indices, index_ranks)
4789+
47874790
else:
47884791
input_rank = len(self.shape)
47894792
# Prepare perm for transposing self tensor.
@@ -4800,15 +4803,19 @@ def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> Tens
48004803
if index is None:
48014804
self = op.Transpose(self, perm=trans_perm)
48024805
count_of_none += 1
4803-
else:
4804-
new_indices = op.Transpose(op.NonZero(index), perm=[1, 0])
4805-
result = op.GatherND(self, new_indices, batch_dims=0)
4806-
finla_rank = input_rank - (len(index.shape) - 1)
4807-
trans_perm = list(range(finla_rank))
4808-
trans_perm = trans_perm[-1:] + trans_perm[:-1]
4809-
for _ in range(count_of_none):
4810-
result = op.Transpose(result, perm=trans_perm)
4811-
return result
4806+
continue
4807+
4808+
new_indices = op.Transpose(op.NonZero(index), perm=[1, 0])
4809+
result = op.GatherND(self, new_indices, batch_dims=0)
4810+
final_rank = input_rank - (len(index.shape) - 1)
4811+
trans_perm = list(range(final_rank))
4812+
trans_perm = trans_perm[-1:] + trans_perm[:-1]
4813+
for _ in range(count_of_none):
4814+
result = op.Transpose(result, perm=trans_perm)
4815+
# FIXME(justinchuby): Even though this logic passes the tests, it still looks strange:
4816+
# why does it return early here instead of continuing to process the remaining indices?
4817+
# I think the assumption here is that there can be only one Boolean index in the indices list?
4818+
return result
48124819

48134820

48144821
def aten_index_add(
@@ -4830,7 +4837,7 @@ def aten_index_copy(
48304837
@torch_op(("aten::index_put", "aten::_unsafe_index_put"), trace_only=True)
48314838
def aten_index_put(
48324839
self: TReal,
4833-
indices: Sequence[INT64],
4840+
indices: Sequence[Optional[Union[INT64, BOOL]]],
48344841
values: TReal,
48354842
accumulate: bool = False,
48364843
) -> TReal:
@@ -4839,6 +4846,9 @@ def aten_index_put(
48394846
See implementation of `torch.onnx.symbolic_opset11.index_put
48404847
<https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
48414848
"""
4849+
if any(index is not None and index.dtype == BOOL.dtype for index in indices):
4850+
return _aten_index_put_bool(self, indices, values, accumulate)
4851+
48424852
# Ensure the number of indices matches the tensor rank by appending trailing Nones.
48434853
self_rank = len(self.shape)
48444854
if len(indices) < self_rank:
@@ -4971,8 +4981,7 @@ def same_shape(other_shape: ir.Shape) -> bool:
49714981
return result
49724982

49734983

4974-
@torch_op("aten::index_put", trace_only=True)
4975-
def aten_index_put_bool(
4984+
def _aten_index_put_bool(
49764985
self: TReal,
49774986
indices: Sequence[BOOL],
49784987
values: TReal,

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,6 @@ def aten_col2im(
328328
else: # assert len(padding) == 4, already [w, x, y, z]
329329
pads = padding
330330

331-
# Only one ONNX op here so didn't write a private function
332331
return op.Col2Im(
333332
self,
334333
output_size,

onnxscript/function_libs/torch_lib/registration.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import re
8+
import warnings
89
from typing import Any, Callable, Generator, Optional
910

1011
import onnxscript
@@ -22,14 +23,12 @@ class OverloadedFunction:
2223
Attributes:
2324
name: Name of the op. E.g. "aten::add".
2425
overloads: Overloads function.
25-
privates: Private functions not exposed to users.
2626
complex: Support complex functions.
2727
"""
2828

2929
def __init__(self, name: str):
3030
self.name = name
3131
self.overloads: list[Any] = []
32-
self.privates: list[Any] = []
3332
self.complex: list[Any] = []
3433

3534

@@ -39,17 +38,26 @@ class Registry:
3938
def __init__(self):
4039
self._registry: dict[str, OverloadedFunction] = {}
4140

42-
def register(
43-
self, func: Any, name: str, *, private: bool = False, complex: bool = False
44-
) -> None:
41+
def register(self, func: Any, name: str, *, complex: bool = False) -> None:
4542
"""Register a function."""
46-
47-
if private:
48-
self._registry.setdefault(name, OverloadedFunction(name)).privates.append(func)
49-
elif complex:
50-
self._registry.setdefault(name, OverloadedFunction(name)).complex.append(func)
43+
overloaded_function = self._registry.setdefault(name, OverloadedFunction(name))
44+
45+
if complex:
46+
if overloaded_function.complex:
47+
warnings.warn(
48+
f"Complex overload for '{name}' already registered: {overloaded_function.complex}.",
49+
stacklevel=3,
50+
)
51+
return
52+
overloaded_function.complex.append(func)
5153
else:
52-
self._registry.setdefault(name, OverloadedFunction(name)).overloads.append(func)
54+
if overloaded_function.overloads:
55+
warnings.warn(
56+
f"Real overload for '{name}' already registered: {overloaded_function.overloads}.",
57+
stacklevel=3,
58+
)
59+
return
60+
overloaded_function.overloads.append(func)
5361

5462
def __getitem__(self, name):
5563
return self._registry[name]
@@ -131,7 +139,10 @@ def wrapper(
131139

132140
assert registry is not None
133141
for name_ in _check_and_normalize_names(name):
134-
registry.register(processed_func, name_, private=private, complex=complex)
142+
if private:
143+
# TODO: Remove the private tag once all functions are no longer private.
144+
continue
145+
registry.register(processed_func, name_, complex=complex)
135146
return processed_func
136147

137148
return wrapper

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ onnx = ["py.typed"]
4343

4444
[tool.pytest.ini_options]
4545
addopts = "-rsfEX --tb=short --color=yes"
46+
norecursedirs = [
47+
# Skip test collection because pytest will try to import the modules twice,
48+
# causing the torchlib registry to complain that functions are redefined.
49+
"onnxscript/function_libs/torch_lib/ops",
50+
]
4651

4752
[tool.mypy]
4853
# TODO disallow_incomplete_defs = true

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -728,23 +728,10 @@ def _where_input_wrangler(
728728
# TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB
729729
# TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB
730730
TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index),
731-
TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index_bool),
732-
TorchLibOpInfo(
733-
"index_put_bool",
734-
core_ops.aten_index_put_bool,
735-
input_wrangler=_index_put_input_wrangler,
736-
).skip(
737-
matcher=lambda sample: sample.args[0][0].dtype != torch.bool,
738-
reason="this Aten overload only supports tensor(bool) as indices",
739-
),
731+
TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index),
740732
TorchLibOpInfo(
741733
"index_put", core_ops.aten_index_put, input_wrangler=_index_put_input_wrangler
742-
)
743-
.skip(
744-
matcher=lambda sample: sample.args[0][0].dtype != torch.int64,
745-
reason="this Aten overload only supports tensor(int) as indices",
746-
)
747-
.xfail(
734+
).skip(
748735
dtypes=(torch.float16,),
749736
matcher=lambda sample: sample.kwargs.get("accumulate") is True,
750737
reason="fixme: ORT only supports float32 when accumulate is True: MLFloat16 data type is not supported with ScatterND when reduction is 'add'",
@@ -1871,7 +1858,6 @@ def _where_input_wrangler(
18711858
ops_test_common.duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate"))
18721859
ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",))
18731860
ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode",))
1874-
ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",))
18751861
ops_test_common.duplicate_opinfo(OPS_DB, "max", ("max_dim",))
18761862
ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",))
18771863
ops_test_common.duplicate_opinfo(OPS_DB, "min", ("min_dim",))

0 commit comments

Comments
 (0)