
Commit 500ab0e

aorenste authored and facebook-github-bot committed
Improve torch.ops typing (pytorch#154555)
Summary:
X-link: pytorch/executorch#11276

Cloned pytorch#153558 from benjaminglass1 and fixed internal typing errors.

Fixes a longstanding issue where direct references to aten operations are seen as untyped by type checkers. This is accomplished by setting attributes on several classes more consistently, so that `__getattr__` can return a single type in all other cases.

Decisions made along the way:

1. `torch.ops.higher_order` is now implemented by a single-purpose class. This was effectively true before, but the class implementing it attempted to be generalized unnecessarily. Fixing this simplified typing for the `_Ops` class.
2. `__getattr__` is only called when all other lookup methods have failed, so several constant special cases in the function could be implemented as class variables.

The remainder of this PR fixes the bugs exposed by the updated typing, as well as the remaining nitpicky typing issues.

Test Plan: CI

Reviewed By: bobrenjc93, mergennachin

Differential Revision: D75497142
1 parent ac86ec0 commit 500ab0e
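
To make the mechanism in the summary concrete, here is a minimal, self-contained sketch of the pattern it describes, using made-up stub classes rather than PyTorch's actual `_Ops`/`_OpNamespace` implementation: constants that used to be special-cased inside `__getattr__` become ordinary class attributes, so the `__getattr__` fallback only ever has to return a single type.

from typing import Any


class _OpPacketStub:
    """Hypothetical stand-in for an op-overload packet."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError(f"stub op: {self.name}")


class _OpNamespaceStub:
    # Former __getattr__ special cases, now declared (and typed) up front.
    name: str = "aten"

    def __getattr__(self, attr: str) -> _OpPacketStub:
        # Only reached when normal attribute lookup fails, so a single
        # return type is enough for type checkers.
        packet = _OpPacketStub(f"{self.name}.{attr}")
        setattr(self, attr, packet)  # cache so later lookups skip __getattr__
        return packet


ops = _OpNamespaceStub()
relu = ops.relu  # inferred as _OpPacketStub instead of Any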

File tree

16 files changed: +180, -126 lines


torch/_C/__init__.pyi.in

Lines changed: 6 additions & 5 deletions
@@ -450,13 +450,13 @@ ResolutionCallback: TypeAlias = Callable[[str], Callable[..., Any]]
 # and torch/csrc/jit/python/init.cpp
 def _maybe_call_torch_function_for_op_packet(
     op_overload_packet: Any,
-    args: Any,
-    kwargs: Any,
+    *args: Any,
+    **kwargs: Any,
 ) -> Any: ...
 def _check_schema_allow_fake_script_object(
     schema: FunctionSchema,
-    args: Any,
-    kwargs: Any,
+    *args: Any,
+    **kwargs: Any,
 ) -> _bool: ...
 def _create_function_from_graph(qualname: str, graph: Graph) -> ScriptFunction: ...
 def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ...
@@ -1630,6 +1630,7 @@ class Generator:
 class _DispatchOperatorHandle:
     def schema(self) -> FunctionSchema: ...
     def debug(self) -> str: ...
+    def redispatch_boxed(self, keyset: DispatchKeySet, *args, **kwargs) -> Any: ...

 class _DispatchModule:
     def reset(self) -> None: ...
@@ -1826,7 +1827,7 @@ class _SetExcludeDispatchKeyGuard:
 # Defined in torch/csrc/utils/schema_info.h

 class _SchemaInfo:
-    def __init__(self, schema: _int) -> None: ...
+    def __init__(self, schema: FunctionSchema) -> None: ...
     @overload
     def is_mutable(self) -> _bool: ...
     @overload
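
The signature change in these stubs is worth spelling out: `args: Any, kwargs: Any` declares two ordinary positional parameters, while `*args: Any, **kwargs: Any` declares variadic ones, which matches how the underlying binding is actually called. A small standalone illustration with hypothetical function names (not part of the real stubs):

from typing import Any


# Old style: the stub promises exactly two extra positional arguments, so a
# call that forwards several values plus keyword arguments is flagged by a
# type checker even though the runtime binding accepts it.
def old_style(op_packet: Any, args: Any, kwargs: Any) -> Any: ...


# New style: variadic parameters match how the function is really invoked.
def new_style(op_packet: Any, *args: Any, **kwargs: Any) -> Any: ...


new_style("aten::add", 1, 2, alpha=3)  # accepted under the variadic signature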

torch/_dispatch/python.py

Lines changed: 10 additions & 2 deletions
@@ -3,12 +3,15 @@
 import unittest.mock
 from collections.abc import Iterator
 from contextlib import contextmanager
+from typing import Callable, TypeVar, Union
+from typing_extensions import ParamSpec

 import torch
 import torch._C
 import torch._ops
 import torch.utils._python_dispatch
 import torch.utils._pytree as pytree
+from torch._C import DispatchKey


 __all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]
@@ -19,6 +22,9 @@

 CROSSREF_FUNCTIONALIZE = False

+_P = ParamSpec("_P")
+_T = TypeVar("_T")
+

 def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
     """
@@ -103,14 +109,16 @@ def _fmt(a: object) -> object:
     return a


-def make_crossref_functionalize(op, final_key):
+def make_crossref_functionalize(
+    op: torch._ops.OpOverload[_P, _T], final_key: DispatchKey
+) -> Union[Callable[_P, _T], DispatchKey]:
     from torch._subclasses.fake_tensor import FakeTensorMode

     # This case is pretty weird, suppress it for now
     if op == torch.ops.aten.lift_fresh.default:
         return final_key

-    def handler(*args, **kwargs):
+    def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T:
         fake_mode = FakeTensorMode()

         def fakeify_defun(t):
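
The `ParamSpec`/`TypeVar` pair introduced here is what lets the inner `handler` stay fully typed. A generic, self-contained sketch of the same wrapper pattern (a made-up logging wrapper, not the actual crossref-functionalize logic):

from typing import Callable, TypeVar
from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_T = TypeVar("_T")


def make_logging_wrapper(op: Callable[_P, _T]) -> Callable[_P, _T]:
    """The returned handler accepts exactly the wrapped callable's parameters
    and returns its return type, so call sites remain fully typed."""

    def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        print(f"calling {getattr(op, '__name__', op)}")
        return op(*args, **kwargs)

    return handler


def add(x: int, y: int) -> int:
    return x + y


typed_add = make_logging_wrapper(add)
result: int = typed_add(1, y=2)  # checker sees the signature (int, int) -> int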

torch/_export/converter.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 import warnings
 from collections.abc import Sequence
 from contextlib import contextmanager
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union

 import torch
 import torch.export._trace
@@ -229,7 +229,7 @@ def get_dtype_as_int(tensor):
 # Those operators will be automatically populated to a instance method
 # of TS2FXGraphConverter with name convert_<namespace>_<opname>().
 # Please check __init__ for method population implementations.
-kind_to_standard_operators = {
+kind_to_standard_operators: dict[str, Callable[..., Any]] = {
     "prim::max": builtins.max,
     "prim::min": builtins.min,
     "prim::TupleIndex": operator.getitem,

torch/_export/passes/functionalize_side_effectful_ops_pass.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
 aten = torch.ops.aten

 _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: dict[OpOverload, OpOverload] = {
-    aten.sym_constrain_range.default: aten._functional_sym_constrain_range,
+    aten.sym_constrain_range.default: aten._functional_sym_constrain_range.default,
     aten._assert_async.msg: aten._functional_assert_async.msg,
 }

torch/_functorch/partitioners.py

Lines changed: 4 additions & 2 deletions
@@ -10,7 +10,7 @@
 import os.path
 from collections import defaultdict
 from dataclasses import dataclass, replace
-from typing import Callable, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union

 import torch
 import torch._inductor.inductor_prims
@@ -2067,7 +2067,9 @@ def get_default_op_list() -> OpTypes:
     default_recomputable_ops += [method_to_operator(m) for m in magic_methods]
     recomputable_ops = OrderedSet(default_recomputable_ops)

-    random_ops = OrderedSet([aten.native_dropout, aten.rand_like, aten.randn_like])
+    random_ops = OrderedSet[Callable[..., Any]](
+        [aten.native_dropout, aten.rand_like, aten.randn_like]
+    )
     compute_intensive_ops = [
         aten.mm,
         aten.convolution,

torch/_higher_order_ops/auto_functionalize.py

Lines changed: 8 additions & 7 deletions
@@ -14,6 +14,7 @@
     _has_gen_schema,
     call_op,
     HopInstance,
+    HopSchema,
     materialize_callable_in_args,
     unique_graph_id,
 )
@@ -835,15 +836,15 @@ def auto_functionalized_v2_dense(
     _only_clone_these_bases = tuple(range(len(_all_bases)))

     if isinstance(_mutable_op, OpOverload):
-        schema = _mutable_op._schema
+        schema: torch._C.FunctionSchema = _mutable_op._schema
     else:
         schema = pytree.tree_unflatten([], kwargs.pop("_op_schema")).schema

-    _mutable_op = (
-        _mutable_op
-        if isinstance(_mutable_op, OpOverload)
-        else HopInstance(_mutable_op, schema)
-    )
+    if isinstance(_mutable_op, OpOverload):
+        _callable_op: Union[HopInstance, OpOverload] = _mutable_op
+    else:
+        assert isinstance(schema, HopSchema)
+        _callable_op = HopInstance(_mutable_op, schema)

     op_kwargs_new, all_bases_new = _generate_new_op_kwargs_from_bases(
         schema,
@@ -853,7 +854,7 @@ def auto_functionalized_v2_dense(
     )

     out = call_op(
-        _mutable_op,
+        _callable_op,
         tuple(),
         op_kwargs_new,
     )

torch/_higher_order_ops/effects.py

Lines changed: 5 additions & 5 deletions
@@ -24,11 +24,11 @@ class _EffectType(Enum):
 OpType = Union[torch._ops.HigherOrderOperator, torch._ops.OpOverload]


-SIDE_EFFECTS: "WeakKeyDictionary[OpType, _EffectType]" = WeakKeyDictionary(
-    {
-        torch.ops.aten._print.default: _EffectType.ORDERED,
-        call_torchbind: _EffectType.ORDERED,
-    }
+SIDE_EFFECTS = WeakKeyDictionary[OpType, _EffectType](
+    [
+        (torch.ops.aten._print.default, _EffectType.ORDERED),
+        (call_torchbind, _EffectType.ORDERED),
+    ]
 )
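
The rewrite above relies on the fact that, since Python 3.9, `weakref.WeakKeyDictionary` can be subscripted at runtime (PEP 585), so the key/value types can be stated on the constructor call itself instead of in a quoted annotation. A minimal sketch with a made-up key class (keys must be weak-referenceable, which is why a plain class instance is used):

from enum import Enum
from weakref import WeakKeyDictionary


class Effect(Enum):
    ORDERED = "Ordered"


class FakeOp:
    """Hypothetical weak-referenceable key, standing in for an operator object."""

    def __init__(self, name: str) -> None:
        self.name = name


print_op = FakeOp("aten::_print")

# The subscripted form is evaluated at runtime, so the key/value types live on
# the constructor call rather than in a string annotation.
SIDE_EFFECTS = WeakKeyDictionary[FakeOp, Effect](
    [(print_op, Effect.ORDERED)]
)

assert SIDE_EFFECTS[print_op] is Effect.ORDERED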

torch/_inductor/decomposition.py

Lines changed: 8 additions & 4 deletions
@@ -6,7 +6,7 @@
 import sys
 import typing
 from typing import Any, Callable, Optional, TypeVar, Union
-from typing_extensions import ParamSpec
+from typing_extensions import ParamSpec, TypeAlias

 import torch
 import torch._decomp as decomp
@@ -51,6 +51,10 @@
 _T = TypeVar("_T")
 _P = ParamSpec("_P")

+_GenericOperator: TypeAlias = Union[
+    torch._ops.OperatorBase, torch._ops.OpOverloadPacket
+]
+
 log = logging.getLogger(__name__)
 aten = torch.ops.aten
 prims = torch.ops.prims
@@ -108,7 +112,7 @@

 # Remove unwanted decompositions included via the core ATen decompositions from
 # the Inductor decomp table.
-decomps_to_exclude = [
+decomps_to_exclude: list[Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket]] = [
     aten._unsafe_index,
     aten._unsafe_masked_index,
     aten._unsafe_masked_index_put_accumulate,
@@ -132,9 +136,9 @@


 def register_decomposition(
-    ops: list[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]],
+    ops: Union[_GenericOperator, list[_GenericOperator]],
 ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
-    for op in [ops] if callable(ops) else ops:  # type: ignore[attr-defined]
+    for op in ops if isinstance(ops, list) else [ops]:
         if op in decompositions:
             log.warning("duplicate decomp: %s", ops)
     return decomp.register_decomposition(ops, decompositions)
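
The `register_decomposition` change swaps a `callable(ops)` test (which needed a type-ignore) for `isinstance(ops, list)`, which type checkers can use to narrow the `Union`. A generic sketch of the same normalize-to-list decorator pattern, with a made-up registry rather than Inductor's real decomposition table:

from typing import Any, Callable, TypeVar, Union
from typing_extensions import ParamSpec, TypeAlias

_T = TypeVar("_T")
_P = ParamSpec("_P")

# Hypothetical operator type standing in for OperatorBase / OpOverloadPacket.
_Op: TypeAlias = Callable[..., Any]

_REGISTRY: dict[_Op, Callable[..., Any]] = {}


def register(
    ops: Union[_Op, list[_Op]],
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    def decorator(fn: Callable[_P, _T]) -> Callable[_P, _T]:
        # isinstance(ops, list) narrows the Union cleanly; callable(ops) would
        # not, because a checker cannot conclude "not callable, therefore list".
        for op in ops if isinstance(ops, list) else [ops]:
            _REGISTRY[op] = fn
        return fn

    return decorator


@register([abs, round])
def fallback(*args: Any, **kwargs: Any) -> Any:
    raise NotImplementedError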

torch/_inductor/fx_passes/mkldnn_fusion.py

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@
 import functools
 import operator
 from functools import reduce
-from typing import Any
+from typing import Any, Callable

 import torch
 from torch._dynamo.utils import counters
@@ -129,7 +129,7 @@ def pack_linear(
     transpose_weight_node = packed_weight_node.args[0]
     if is_lp_weight or mkldnn._is_mkldnn_acl_supported() or V.aot_compilation:
         packed_linear_inputs += (bias, "none", [], "")
-        packed_linear_op = mkldnn._linear_pointwise.default
+        packed_linear_op: Callable[..., Any] = mkldnn._linear_pointwise.default
     else:
         packed_linear_inputs += (transpose_weight_node, bias, batch_size)
         packed_linear_op = torch.ops.mkl._mkl_linear

torch/_inductor/fx_passes/reinplace.py

Lines changed: 12 additions & 4 deletions
@@ -3,10 +3,12 @@
 import logging
 import operator
 from collections import defaultdict
+from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Any, Callable, Union
+from typing import Any, Callable, cast, Union

 import torch
+import torch.fx.node
 from torch._C._dynamo.guards import compute_overlapping_tensors
 from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.utils import ReinplaceCounters, ReInplaceTrigger
@@ -176,7 +178,12 @@ def _decompose_scatter_mutating(

 def scatter_always_uses_mutation(node: torch.fx.Node) -> bool:
     _, _, view_ops = node.args
-    return any(view.target in _ALWAYS_MUTATING_SCATTER_OPS for view in view_ops)  # type: ignore[union-attr]
+    view_ops = cast(Sequence[torch.fx.node.Argument], view_ops)
+    return any(
+        target in _ALWAYS_MUTATING_SCATTER_OPS
+        for view in view_ops
+        if isinstance(target := getattr(view, "target", None), torch._ops.OpOverload)
+    )


 def should_reinplace_scatter(node: torch.fx.Node) -> bool:
@@ -267,6 +274,7 @@ def handle_view_scatter(node: torch.fx.Node):
     assert len(node.args) >= 2
     inp, src = node.args[:2]

+    assert isinstance(node.target, torch._ops.OpOverload)
     scatter_view_op = ViewOp(
         _SCATTER_OP_TO_VIEW[node.target],
         args=node.args[2:],
@@ -331,7 +339,7 @@ def can_fuse():
         handle_view_scatter(node)


-inplaceable_ops = {
+inplaceable_ops: dict[Callable[..., Any], InplaceableOp] = {
     aten.index_put.default: InplaceableOp(aten.index_put_.default, 0),
     aten._unsafe_index_put.default: InplaceableOp(inductor_prims._unsafe_index_put_, 0),
     _generalized_scatter: InplaceableOp(
@@ -343,7 +351,7 @@ def can_fuse():

 try:
     c10d_functional = torch.ops._c10d_functional
-    inplaceable_collective_ops = {
+    inplaceable_collective_ops: dict[Callable[..., Any], InplaceableOp] = {
         c10d_functional.all_reduce.default: InplaceableOp(
             c10d_functional.all_reduce_.default, 0
         ),
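
The `scatter_always_uses_mutation` rewrite combines `cast`, `getattr` with a default, and a walrus assignment inside the comprehension filter so each `target` is narrowed to a concrete type instead of being silenced with a type-ignore. A standalone sketch of that filtering pattern, with a made-up `Node` class in place of `torch.fx.Node`:

from collections.abc import Sequence
from typing import Any, cast


class Node:
    """Hypothetical stand-in for an FX node whose .target is untyped."""

    def __init__(self, target: Any) -> None:
        self.target = target


ALWAYS_MUTATING = {"scatter_", "index_put_"}


def uses_mutation(view_ops: Any) -> bool:
    # Same shape as the rewrite above: cast the untyped argument to a sequence,
    # then let the walrus + isinstance filter narrow each target's type.
    ops = cast(Sequence[Any], view_ops)
    return any(
        target in ALWAYS_MUTATING
        for view in ops
        if isinstance(target := getattr(view, "target", None), str)
    )


assert uses_mutation([Node("scatter_"), Node(42)])
assert not uses_mutation([Node(3.5)])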
