Skip to content

Commit 3645c7e

Browse files
committed
Typecheck 25% of exir directory
1 parent c2aa614 commit 3645c7e

File tree

12 files changed

+85
-23
lines changed

12 files changed

+85
-23
lines changed

.lintrunner.toml

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,56 @@ include_patterns = [
315315
# 'examples/**/*.py',
316316
'examples/openvino/**/*.py',
317317
# 'exir/**/*.py',
318+
# Phase 1: Start with simplest exir files (Batch 1)
319+
'exir/version.py',
320+
'exir/scalar_type.py',
321+
'exir/error.py',
322+
'exir/_warnings.py',
323+
'exir/types.py',
324+
# Phase 1: Batch 2 - More utility files
325+
'exir/dynamic_shape.py',
326+
'exir/memory.py',
327+
'exir/dim_order_utils.py',
328+
'exir/wrap.py',
329+
# Phase 1: Batch 3 - dialects subdirectory (5 files)
330+
'exir/dialects/__init__.py',
331+
'exir/dialects/_ops.py',
332+
'exir/dialects/backend/_ops.py',
333+
'exir/dialects/edge/dtype/supported.py',
334+
'exir/dialects/edge/dtype/utils.py',
335+
# Phase 1: Batch 3+ - operator utility
336+
'exir/operator/util.py',
337+
# Phase 1: Batch 4 - More subdirectories (6 files)
338+
'exir/program/__init__.py',
339+
'exir/program/_fake_program.py',
340+
'exir/emit/__init__.py',
341+
'exir/capture/__init__.py',
342+
'exir/capture/_config.py',
343+
'exir/verification/dev_html.py',
344+
# Phase 1: Batch 5 - Fixed problematic files (3 files)
345+
'exir/operator/manip.py',
346+
'exir/dialects/edge/dtype/runner.py',
347+
'exir/serde/schema_check.py',
348+
# Phase 1: Batch 6 - Final root-level fixes (3 files)
349+
'exir/common.py',
350+
'exir/sym_util.py',
351+
'exir/graph_module.py',
352+
# Phase 1: Batch 7 - Clean files + fixed files (7 files)
353+
'exir/schema.py',
354+
'exir/print_program.py',
355+
'exir/pass_manager.py',
356+
'exir/graph.py',
357+
'exir/control_flow.py',
358+
'exir/delegate.py',
359+
'exir/backend/partitioner.py',
360+
# Phase 1: Batch 8 - Clean files to reach 25% coverage (7 files)
361+
'exir/__init__.py',
362+
'exir/capture/_unlift.py',
363+
'exir/serde/__init__.py',
364+
'exir/serde/union.py',
365+
'exir/serde/schema.py',
366+
'exir/_serialize/__init__.py',
367+
'exir/_serialize/padding.py',
318368
# 'extension/**/*.py',
319369
'kernels/**/*.py',
320370
'profiler/**/*.py',

.mypy.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,6 @@ ignore_missing_imports = True
100100

101101
[mypy-torchao.*]
102102
follow_untyped_imports = True
103+
104+
[mypy-sympy.*]
105+
ignore_missing_imports = True

exir/backend/backend_details.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@
77
from abc import ABC, abstractmethod
88
from dataclasses import dataclass
99

10-
from typing import Dict, List, Optional, Tuple, Union
10+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1111

1212
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
1313

1414
from executorch.exir.backend.compile_spec_schema import CompileSpec
1515
from torch.export.exported_program import ExportedProgram
1616

1717

18-
def enforcedmethod(func):
19-
func.__enforcedmethod__ = True
18+
def enforcedmethod(func: Callable[..., Any]) -> Callable[..., Any]:
19+
func.__enforcedmethod__ = True # type: ignore[attr-defined]
2020
return func
2121

2222

exir/backend/partitioner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class Partitioner(ABC):
5959
def __init__(
6060
self,
6161
spec: Mapping[Union[str, int, float, bool], object] = MappingProxyType({}),
62-
):
62+
) -> None:
6363
self._spec = spec
6464

6565
def __call__(self, exported_program: ExportedProgram) -> PartitionResult:
@@ -69,7 +69,7 @@ def __call__(self, exported_program: ExportedProgram) -> PartitionResult:
6969
def spec(self) -> Mapping[Union[str, int, float, bool], object]:
7070
return self._spec
7171

72-
@enforcedmethod
72+
@enforcedmethod # type: ignore[misc]
7373
@abstractmethod
7474
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
7575
"""

exir/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ def override_logger(
104104
try:
105105
oldLevel = logging.root.level
106106
logging.root.setLevel(newLevel)
107+
oldFormatters = []
107108
if fmtstr:
108109
newformatter = logging.Formatter(fmtstr, None, "%")
109-
oldFormatters = []
110110
for handler in logging.root.handlers:
111111
oldFormatters.append(handler.formatter)
112112
handler.formatter = newformatter

exir/control_flow.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def _make_submodule(
103103
f"Expect function '{fn.__name__}' to be decorated with tracing_context.",
104104
)
105105
# pyre-ignore
106-
args = fn.__tracing_inputs__
106+
args = fn.__tracing_inputs__ # type: ignore[attr-defined]
107107
# TODO(yidi): we don't want to enable here because we are not gonna use this code path in the future anyways
108108
gm, _ = flattened_dispatch_trace(fn, args, set(), enable_functionalization=False)
109109
output = next(iter(reversed(gm.graph.nodes)))
@@ -122,7 +122,7 @@ def _make_submodule(
122122
output.args = tuple(output.args[0])
123123
gm.recompile()
124124
# pyre-fixme[16]: `GraphModule` has no attribute `__tracing_inputs__`.
125-
gm.__tracing_inputs__ = args
125+
gm.__tracing_inputs__ = args # type: ignore[attr-defined]
126126
return gm
127127

128128

@@ -198,7 +198,7 @@ def wrapper(
198198

199199
return f(*args)
200200

201-
wrapper.__tracing_inputs__ = inputs # pyre-ignore
201+
wrapper.__tracing_inputs__ = inputs # type: ignore[attr-defined]
202202
return wrapper
203203

204204
return decorator

exir/delegate.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,11 @@
4242
LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule"
4343

4444
# pyre-ignore
45-
def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args):
45+
def trace_call_delegate(
46+
proxy_mode: Any, func_overload: Any, lowered_module: Any, *args: Any
47+
) -> Any:
4648
# pyre-ignore
47-
def _unwrap_proxy(e):
49+
def _unwrap_proxy(e: Any) -> Any:
4850
if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)):
4951
return e
5052
return get_proxy_slot(
@@ -151,7 +153,7 @@ def is_lowered_module(obj: Any) -> bool:
151153
def get_lowered_module_name(
152154
root: torch.nn.Module,
153155
# pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
154-
lowered_module: LOWERED_BACKEND_MODULE_TYPE, # noqa
156+
lowered_module: Any, # noqa
155157
) -> str:
156158
"""
157159
Adds the given lowered_module into the given root module and returns the

exir/delegate.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,5 @@ def is_lowered_module(obj: Any) -> bool: ...
1717
def get_lowered_module_name(
1818
root: torch.nn.Module,
1919
# pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
20-
lowered_module: LOWERED_BACKEND_MODULE_TYPE, # noqa
20+
lowered_module: Any, # noqa
2121
) -> str: ...

exir/dialects/edge/dtype/runner.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def _get_types(inputs: Dict[str, List[BaseArg]]) -> List[ArgType]:
3030
@staticmethod
3131
def _get_args_kwargs(
3232
inputs: Dict[str, List[BaseArg]],
33-
dtypes: Tuple[Optional[torch.dtype]],
33+
dtypes: Tuple[Optional[torch.dtype], ...],
3434
mode: ArgMode,
3535
) -> Tuple[List[BaseArg], Dict[str, BaseKwarg]]:
3636
"""Construct args and kwargs for op given dtypes."""
@@ -71,16 +71,20 @@ def run_dtypes(
7171
self,
7272
name: str,
7373
inputs: Dict[str, List[BaseArg]],
74-
dtypes: Tuple[Optional[torch.dtype]],
74+
dtypes: Tuple[Optional[torch.dtype], ...],
7575
argmode: ArgMode = ArgMode.RANDOM,
7676
) -> Tuple[
77-
bool, str, Tuple[Optional[torch.dtype]], List[BaseArg], Dict[str, BaseKwarg]
77+
bool,
78+
str,
79+
Tuple[Optional[torch.dtype], ...],
80+
List[BaseArg],
81+
Dict[str, BaseKwarg],
7882
]:
7983
args, kwargs = DtypeRunner._get_args_kwargs(inputs, dtypes, argmode)
8084
op = get_callable(name)
8185
try:
8286
res = op(*args, **kwargs)
83-
ret_dtypes = ()
87+
ret_dtypes: Tuple[torch.dtype, ...] = ()
8488
if "returns" in inputs:
8589
ret_dtypes = tuple(extract_return_dtype(res, inputs["returns"]))
8690
return (True, name, dtypes + ret_dtypes, args, kwargs)
@@ -112,7 +116,11 @@ def run(
112116
argmode: ArgMode = ArgMode.ONES,
113117
) -> List[
114118
Tuple[
115-
bool, str, Tuple[Optional[torch.dtype]], List[BaseArg], Dict[str, BaseKwarg]
119+
bool,
120+
str,
121+
Tuple[Optional[torch.dtype], ...],
122+
List[BaseArg],
123+
Dict[str, BaseKwarg],
116124
]
117125
]:
118126
results = []

exir/graph_module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def _get_submodule(
4242
assert submod_node.op == "get_attr"
4343
assert isinstance(submod_node.target, str)
4444
submodule = graph_module.get_submodule(submod_node.target)
45-
# pyre-ignore
45+
assert isinstance(submodule, torch.nn.Module)
4646
return submod_node.target, submodule, node
4747

4848

@@ -67,7 +67,7 @@ def get_control_flow_submodules(
6767
if node.target is torch.ops.higher_order.map_impl:
6868
control_flow_submodules.append(_get_submodule(graph_module, node, 0))
6969

70-
return control_flow_submodules
70+
return control_flow_submodules # type: ignore[return-value]
7171

7272

7373
def bfs_trace_with_node_process(

0 commit comments

Comments
 (0)