Skip to content

Commit 3d46d08

Browse files
committed
Typecheck 42% of exir directory
1 parent 3645c7e commit 3d46d08

File tree

15 files changed

+128
-99
lines changed

15 files changed

+128
-99
lines changed

.lintrunner.toml

Lines changed: 71 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -314,57 +314,37 @@ include_patterns = [
314314
'docs/**/*.py',
315315
# 'examples/**/*.py',
316316
'examples/openvino/**/*.py',
317-
# 'exir/**/*.py',
318-
# Phase 1: Start with simplest exir files (Batch 1)
319-
'exir/version.py',
320-
'exir/scalar_type.py',
321-
'exir/error.py',
317+
# exir directory - organized by subdirectories with blacklist exclusions
318+
'exir/_serialize/**/*.py',
319+
'exir/backend/**/*.py',
320+
'exir/capture/**/*.py',
321+
'exir/dialects/**/*.py',
322+
'exir/emit/**/*.py',
323+
'exir/operator/**/*.py',
324+
'exir/passes/**/*.py',
325+
'exir/program/**/*.py',
326+
'exir/serde/**/*.py',
327+
'exir/verification/**/*.py',
328+
# exir root level files (sorted alphabetically)
329+
'exir/__init__.py',
322330
'exir/_warnings.py',
323-
'exir/types.py',
324-
# Phase 1: Batch 2 - More utility files
325-
'exir/dynamic_shape.py',
326-
'exir/memory.py',
327-
'exir/dim_order_utils.py',
328-
'exir/wrap.py',
329-
# Phase 1: Batch 3 - dialects subdirectory (5 files)
330-
'exir/dialects/__init__.py',
331-
'exir/dialects/_ops.py',
332-
'exir/dialects/backend/_ops.py',
333-
'exir/dialects/edge/dtype/supported.py',
334-
'exir/dialects/edge/dtype/utils.py',
335-
# Phase 1: Batch 3+ - operator utility
336-
'exir/operator/util.py',
337-
# Phase 1: Batch 4 - More subdirectories (6 files)
338-
'exir/program/__init__.py',
339-
'exir/program/_fake_program.py',
340-
'exir/emit/__init__.py',
341-
'exir/capture/__init__.py',
342-
'exir/capture/_config.py',
343-
'exir/verification/dev_html.py',
344-
# Phase 1: Batch 5 - Fixed problematic files (3 files)
345-
'exir/operator/manip.py',
346-
'exir/dialects/edge/dtype/runner.py',
347-
'exir/serde/schema_check.py',
348-
# Phase 1: Batch 6 - Final root-level fixes (3 files)
349331
'exir/common.py',
350-
'exir/sym_util.py',
351-
'exir/graph_module.py',
352-
# Phase 1: Batch 7 - Clean files + fixed files (7 files)
353-
'exir/schema.py',
354-
'exir/print_program.py',
355-
'exir/pass_manager.py',
356-
'exir/graph.py',
357332
'exir/control_flow.py',
358333
'exir/delegate.py',
359-
'exir/backend/partitioner.py',
360-
# Phase 1: Batch 8 - Clean files to reach 25% coverage (7 files)
361-
'exir/__init__.py',
362-
'exir/capture/_unlift.py',
363-
'exir/serde/__init__.py',
364-
'exir/serde/union.py',
365-
'exir/serde/schema.py',
366-
'exir/_serialize/__init__.py',
367-
'exir/_serialize/padding.py',
334+
'exir/dim_order_utils.py',
335+
'exir/dynamic_shape.py',
336+
'exir/error.py',
337+
'exir/graph.py',
338+
'exir/graph_module.py',
339+
'exir/memory.py',
340+
'exir/pass_manager.py',
341+
'exir/print_program.py',
342+
'exir/scalar_type.py',
343+
'exir/schema.py',
344+
'exir/sym_util.py',
345+
'exir/types.py',
346+
'exir/version.py',
347+
'exir/wrap.py',
368348
# 'extension/**/*.py',
369349
'kernels/**/*.py',
370350
'profiler/**/*.py',
@@ -380,6 +360,50 @@ exclude_patterns = [
380360
'scripts/check_binary_dependencies.py',
381361
'profiler/test/test_profiler_e2e.py',
382362
'backends/arm/test/**',
363+
# exir exclusions - files with mypy errors
364+
'exir/backend/test/**',
365+
'exir/backend/utils.py',
366+
'exir/capture/_capture.py',
367+
'exir/dialects/backend/test/**',
368+
'exir/dialects/edge/arg/model.py',
369+
'exir/dialects/edge/op/test/**',
370+
'exir/dialects/edge/spec/**',
371+
'exir/dialects/edge/test/**',
372+
'exir/dialects/test/**',
373+
'exir/emit/_emitter.py',
374+
'exir/emit/test/**',
375+
'exir/lowered_backend_module.py',
376+
'exir/memory_planning.py',
377+
'exir/operator/convert.py',
378+
'exir/operator/test/**',
379+
'exir/pass_base.py',
380+
'exir/passes/__init__.py',
381+
'exir/passes/_quant_patterns_and_replacements.py',
382+
'exir/passes/const_prop_pass.py',
383+
'exir/passes/constant_prop_pass.py',
384+
'exir/passes/dynamic_shape_prop_pass.py',
385+
'exir/passes/executorch_prim_ops_registry.py',
386+
'exir/passes/memory_planning_pass.py',
387+
'exir/passes/prune_empty_tensors_pass.py',
388+
'exir/passes/quant_fusion_pass.py',
389+
'exir/passes/quantize_io_pass.py',
390+
'exir/passes/remove_mixed_type_operators.py',
391+
'exir/passes/remove_noop_pass.py',
392+
'exir/passes/replace_view_copy_with_view_pass.py',
393+
'exir/passes/spec_prop_pass.py',
394+
'exir/passes/sym_shape_eval_pass.py',
395+
'exir/passes/sym_to_tensor_pass.py',
396+
'exir/passes/weights_to_outputs_pass.py',
397+
'exir/program/_program.py',
398+
'exir/program/test/**',
399+
'exir/serde/export_serialize.py',
400+
'exir/serde/serialize.py',
401+
'exir/tensor.py',
402+
'exir/tests/**',
403+
'exir/tracer.py',
404+
'exir/verification/arg_validator.py',
405+
'exir/verification/test/**',
406+
'exir/_serialize/test/**',
383407
]
384408
command = [
385409
'python3',

.mypy.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,6 @@ follow_untyped_imports = True
103103

104104
[mypy-sympy.*]
105105
ignore_missing_imports = True
106+
107+
[mypy-executorch.exir.verification.bindings]
108+
ignore_missing_imports = True

exir/_serialize/_dataclass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,5 +141,5 @@ class Example
141141
if isinstance(T, enum.EnumMeta):
142142
data[key] = T[value]
143143
else:
144-
data[key] = T(value)
145-
return cls(**data)
144+
data[key] = T(value) # type: ignore[operator]
145+
return cls(**data) # type: ignore[operator]

exir/_serialize/_flatbuffer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,10 @@ def _run_flatc(args: Sequence[str]) -> None:
193193
subprocess.run([flatc_path] + list(args), check=True)
194194
else:
195195
# Expect the `flatc` tool to be on the system path or set as an env var.
196-
flatc_path = os.getenv("FLATC_EXECUTABLE")
197-
if not flatc_path:
198-
flatc_path = "flatc"
199-
subprocess.run([flatc_path] + list(args), check=True)
196+
flatc_executable = os.getenv("FLATC_EXECUTABLE")
197+
if not flatc_executable:
198+
flatc_executable = "flatc"
199+
subprocess.run([flatc_executable] + list(args), check=True)
200200

201201

202202
def _flatc_compile(output_dir: str, schema_path: str, json_path: str) -> None:

exir/_serialize/_named_data_store.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ def _add_named_data_to_map(
121121
if self.data_hash_to_buffer_idx.get(hashed, -1) != buffer_idx:
122122
raise ValueError(
123123
f"Duplicate key {key} with different data. "
124-
f"Existing data: {self.buffers[buffer_idx].buffer}. "
125-
f"New data: {data}."
124+
f"Existing data: {self.buffers[buffer_idx].buffer!r}. "
125+
f"New data: {data!r}." # type: ignore[str-bytes-safe]
126126
)
127127
self.buffers[buffer_idx].alignment = math.lcm(
128128
self.buffers[buffer_idx].alignment, alignment

exir/_serialize/_serialize.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# pyre-strict
88

9-
from typing import Dict, Optional, Set, Tuple
9+
from typing import Dict, List, Optional, Set, Tuple
1010

1111
from executorch.exir._serialize import _serialize_pte_binary
1212

@@ -102,10 +102,10 @@ def serialize_for_executorch(
102102
)
103103

104104
for tag in all_external_tags:
105-
buffers = []
105+
buffers: List[bytes] = []
106106
fqn_to_tensor_entry: Dict[str, TensorEntry] = {}
107107
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`.
108-
fqn_to_index = emitter_output.external_constant_map.get(tag, {})
108+
fqn_to_index = emitter_output.external_constant_map.get(tag, {}) # type: ignore[union-attr]
109109
# Create a TensorEntry for each external tensor.
110110
for fqn, index in fqn_to_index.items():
111111
assert fqn in fqn_to_tensor_layout
@@ -118,13 +118,13 @@ def serialize_for_executorch(
118118
# Extract external data.
119119
key_to_data: Dict[str, DataEntry] = {}
120120
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`.
121-
key_to_buffer_index = named_data.external_data.get(tag, {})
121+
key_to_buffer_index = named_data.external_data.get(tag, {}) # type: ignore[union-attr]
122122
for key, index in key_to_buffer_index.items():
123123
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`.
124124
key_to_data[key] = DataEntry(
125-
len(buffers), named_data.buffers[index].alignment
125+
len(buffers), named_data.buffers[index].alignment # type: ignore[union-attr]
126126
)
127-
buffers.append(named_data.buffers[index].buffer)
127+
buffers.append(named_data.buffers[index].buffer) # type: ignore[union-attr]
128128

129129
# Serialize into PTD file.
130130
ptd_files[tag] = data_serializer.serialize(

exir/backend/backend_api.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def to_backend(
123123
compile_specs=compile_specs,
124124
named_data_store_output=preprocess_result.data_store_output,
125125
)
126-
lowered_module.meta = {
126+
lowered_module.meta = { # type: ignore[assignment]
127127
"debug_handle_map": preprocess_result.debug_handle_map
128128
}
129129
return lowered_module
@@ -311,7 +311,7 @@ def _partition_and_lower_one_graph_module(
311311
is_submodule,
312312
)
313313

314-
lowered_submodule = to_backend(
314+
lowered_submodule = to_backend( # type: ignore[call-arg]
315315
delegation_spec.backend_id,
316316
submodule_program,
317317
delegation_spec.compile_specs,
@@ -449,7 +449,7 @@ def _create_partitions_in_graph_module(
449449
owning_program: ExportedProgram,
450450
is_submodule: bool,
451451
) -> Dict[str, List[torch.fx.Node]]:
452-
backend_id_to_submodule_name = {}
452+
backend_id_to_submodule_name: Dict[str, List[str]] = {}
453453
for tag, delegation_spec in partition_result.partition_tags.items():
454454
# Create partition with nodes containing this tag. There should only be
455455
# one contained submodule per tag
@@ -517,10 +517,12 @@ def _create_partitions_in_graph_module(
517517
# in future edits to the graph. As a result, we just keep track of the node's name
518518
# and at the end we search for this node in our final graph module
519519
backend_id_to_submodule_name[delegation_spec.backend_id].append(
520-
call_module_node.target
520+
call_module_node.target # type: ignore[arg-type]
521521
)
522522

523-
created_submodule_nodes = {key: [] for key in backend_id_to_submodule_name.keys()}
523+
created_submodule_nodes: Dict[str, List[torch.fx.Node]] = {
524+
key: [] for key in backend_id_to_submodule_name.keys()
525+
}
524526
for backend_id, submodule_name in backend_id_to_submodule_name.items():
525527
for node in tagged_graph_module.graph.nodes:
526528
if node.op == "call_module" and node.target in submodule_name:
@@ -615,7 +617,7 @@ def lower_all_submodules_to_backend(
615617
compile_specs=compile_spec,
616618
named_data_store_output=preprocess_result.data_store_output,
617619
)
618-
lowered_module.meta = {
620+
lowered_module.meta = { # type: ignore[assignment]
619621
"debug_handle_map": preprocess_result.debug_handle_map,
620622
}
621623
is_submodule = call_submodule_node.meta["is_submodule"]
@@ -698,7 +700,7 @@ def to_backend(
698700
method_to_partitioner = method_edge_program_partitioners.method_to_partitioner
699701

700702
partitioned_and_lowered_exported_programs = {}
701-
backend_id_to_method_submodules_map = {}
703+
backend_id_to_method_submodules_map: Dict[str, Dict[str, List[torch.fx.Node]]] = {}
702704
method_to_tagged_exported_program = {}
703705

704706
for method_name, partitioner_instance in method_to_partitioner.items():

exir/backend/canonical_partitioners/config_partitioner.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,9 @@ class PartitionerConfig(ABC):
5252
the specified backend.
5353
"""
5454

55-
@classmethod
56-
@property
55+
@property # type: ignore[misc]
5756
@abstractmethod
58-
def target_name(cls) -> str:
57+
def target_name(self) -> str:
5958
"""
6059
Target name for this partitioner config. When the Config-Based Partitioner
6160
encounters a node with a matching target name, it uses this config's methods to
@@ -138,7 +137,7 @@ def filter_fn(node: torch.fx.Node) -> bool:
138137
"""
139138
if node.op != "call_function":
140139
return False
141-
target_name = format_target_name(node.target.__name__) # pyre-ignore
140+
target_name = format_target_name(node.target.__name__) # type: ignore[union-attr]
142141

143142
if target_name in self.target_partitioner_configs:
144143
config = self.target_partitioner_configs[target_name]

exir/backend/canonical_partitioners/duplicate_constant_node_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ def _get_attribute_or_constants(
2828
if maybe_param is not None:
2929
constant_or_attribute = maybe_param
3030
elif maybe_buffer is not None:
31-
constant_or_attribute = maybe_buffer
31+
constant_or_attribute = maybe_buffer # type: ignore[assignment]
3232
elif maybe_lifted_tensor is not None:
33-
constant_or_attribute = maybe_lifted_tensor
33+
constant_or_attribute = maybe_lifted_tensor # type: ignore[assignment]
3434
return constant_or_attribute
3535

3636

exir/capture/_capture.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,10 @@ def _capture_legacy_do_not_use(f, args) -> ExirExportedProgram:
122122
outputs=[],
123123
# pyre-fixme[6]: For 3rd argument expected `TreeSpec` but got
124124
# `Union[Tensor, Module]`.
125-
in_spec=in_spec,
125+
in_spec=in_spec, # type: ignore[arg-type]
126126
# pyre-fixme[6]: For 4th argument expected `TreeSpec` but got
127127
# `Union[Tensor, Module]`.
128-
out_spec=out_spec,
128+
out_spec=out_spec, # type: ignore[arg-type]
129129
),
130130
)
131131
],
@@ -207,7 +207,7 @@ def capture( # noqa: C901
207207
if isinstance(f, MethodType) and isinstance(f.__self__, torch.nn.Module):
208208
with patch_forward(f.__self__, f):
209209
ep = export(
210-
cast(torch.nn.Module, f.__self__),
210+
f.__self__, # type: ignore[redundant-cast]
211211
args,
212212
dynamic_shapes=dynamic_shapes,
213213
strict=True,
@@ -272,7 +272,7 @@ def graph_with_interpreter(*args):
272272
graph_with_interpreter,
273273
remove="mutations_and_views",
274274
)
275-
assert isinstance(functionalized_callable, Callable)
275+
assert callable(functionalized_callable) # type: ignore[arg-type]
276276

277277
if config.enable_dynamic_shape:
278278
fake_tensor_mode = FakeTensorMode(
@@ -357,7 +357,7 @@ def convert_to_fake(x):
357357
in_spec=in_spec,
358358
# pyre-fixme[6]: For 4th argument expected `TreeSpec` but got
359359
# `Union[None, TreeSpec, Tensor, Module]`.
360-
out_spec=out_spec,
360+
out_spec=out_spec, # type: ignore[arg-type]
361361
),
362362
)
363363
],

0 commit comments

Comments
 (0)