Skip to content

Commit fa9d3b3

Browse files
committed
Typecheck 42% of exir directory
1 parent 3645c7e commit fa9d3b3

File tree

12 files changed

+119
-86
lines changed

12 files changed

+119
-86
lines changed

.lintrunner.toml

Lines changed: 76 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -314,57 +314,37 @@ include_patterns = [
314314
'docs/**/*.py',
315315
# 'examples/**/*.py',
316316
'examples/openvino/**/*.py',
317-
# 'exir/**/*.py',
318-
# Phase 1: Start with simplest exir files (Batch 1)
319-
'exir/version.py',
320-
'exir/scalar_type.py',
321-
'exir/error.py',
317+
# exir directory - organized by subdirectories with blacklist exclusions
318+
'exir/_serialize/**/*.py',
319+
'exir/backend/**/*.py',
320+
'exir/capture/**/*.py',
321+
'exir/dialects/**/*.py',
322+
'exir/emit/**/*.py',
323+
'exir/operator/**/*.py',
324+
'exir/passes/**/*.py',
325+
'exir/program/**/*.py',
326+
'exir/serde/**/*.py',
327+
'exir/verification/**/*.py',
328+
# exir root level files (sorted alphabetically)
329+
'exir/__init__.py',
322330
'exir/_warnings.py',
323-
'exir/types.py',
324-
# Phase 1: Batch 2 - More utility files
325-
'exir/dynamic_shape.py',
326-
'exir/memory.py',
327-
'exir/dim_order_utils.py',
328-
'exir/wrap.py',
329-
# Phase 1: Batch 3 - dialects subdirectory (5 files)
330-
'exir/dialects/__init__.py',
331-
'exir/dialects/_ops.py',
332-
'exir/dialects/backend/_ops.py',
333-
'exir/dialects/edge/dtype/supported.py',
334-
'exir/dialects/edge/dtype/utils.py',
335-
# Phase 1: Batch 3+ - operator utility
336-
'exir/operator/util.py',
337-
# Phase 1: Batch 4 - More subdirectories (6 files)
338-
'exir/program/__init__.py',
339-
'exir/program/_fake_program.py',
340-
'exir/emit/__init__.py',
341-
'exir/capture/__init__.py',
342-
'exir/capture/_config.py',
343-
'exir/verification/dev_html.py',
344-
# Phase 1: Batch 5 - Fixed problematic files (3 files)
345-
'exir/operator/manip.py',
346-
'exir/dialects/edge/dtype/runner.py',
347-
'exir/serde/schema_check.py',
348-
# Phase 1: Batch 6 - Final root-level fixes (3 files)
349331
'exir/common.py',
350-
'exir/sym_util.py',
351-
'exir/graph_module.py',
352-
# Phase 1: Batch 7 - Clean files + fixed files (7 files)
353-
'exir/schema.py',
354-
'exir/print_program.py',
355-
'exir/pass_manager.py',
356-
'exir/graph.py',
357332
'exir/control_flow.py',
358333
'exir/delegate.py',
359-
'exir/backend/partitioner.py',
360-
# Phase 1: Batch 8 - Clean files to reach 25% coverage (7 files)
361-
'exir/__init__.py',
362-
'exir/capture/_unlift.py',
363-
'exir/serde/__init__.py',
364-
'exir/serde/union.py',
365-
'exir/serde/schema.py',
366-
'exir/_serialize/__init__.py',
367-
'exir/_serialize/padding.py',
334+
'exir/dim_order_utils.py',
335+
'exir/dynamic_shape.py',
336+
'exir/error.py',
337+
'exir/graph.py',
338+
'exir/graph_module.py',
339+
'exir/memory.py',
340+
'exir/pass_manager.py',
341+
'exir/print_program.py',
342+
'exir/scalar_type.py',
343+
'exir/schema.py',
344+
'exir/sym_util.py',
345+
'exir/types.py',
346+
'exir/version.py',
347+
'exir/wrap.py',
368348
# 'extension/**/*.py',
369349
'kernels/**/*.py',
370350
'profiler/**/*.py',
@@ -380,6 +360,55 @@ exclude_patterns = [
380360
'scripts/check_binary_dependencies.py',
381361
'profiler/test/test_profiler_e2e.py',
382362
'backends/arm/test/**',
363+
# exir exclusions - files with mypy errors
364+
'exir/backend/backend_api.py',
365+
'exir/backend/canonical_partitioners/config_partitioner.py',
366+
'exir/backend/canonical_partitioners/duplicate_constant_node_pass.py',
367+
'exir/backend/canonical_partitioners/duplicate_dequant_node_pass.py',
368+
'exir/backend/canonical_partitioners/pattern_op_partitioner.py',
369+
'exir/backend/test/**',
370+
'exir/backend/utils.py',
371+
'exir/capture/_capture.py',
372+
'exir/dialects/backend/test/**',
373+
'exir/dialects/edge/arg/model.py',
374+
'exir/dialects/edge/op/test/**',
375+
'exir/dialects/edge/spec/**',
376+
'exir/dialects/edge/test/**',
377+
'exir/dialects/test/**',
378+
'exir/emit/_emitter.py',
379+
'exir/emit/test/**',
380+
'exir/lowered_backend_module.py',
381+
'exir/memory_planning.py',
382+
'exir/operator/convert.py',
383+
'exir/operator/test/**',
384+
'exir/pass_base.py',
385+
'exir/passes/__init__.py',
386+
'exir/passes/_quant_patterns_and_replacements.py',
387+
'exir/passes/const_prop_pass.py',
388+
'exir/passes/constant_prop_pass.py',
389+
'exir/passes/dynamic_shape_prop_pass.py',
390+
'exir/passes/executorch_prim_ops_registry.py',
391+
'exir/passes/memory_planning_pass.py',
392+
'exir/passes/prune_empty_tensors_pass.py',
393+
'exir/passes/quant_fusion_pass.py',
394+
'exir/passes/quantize_io_pass.py',
395+
'exir/passes/remove_mixed_type_operators.py',
396+
'exir/passes/remove_noop_pass.py',
397+
'exir/passes/replace_view_copy_with_view_pass.py',
398+
'exir/passes/spec_prop_pass.py',
399+
'exir/passes/sym_shape_eval_pass.py',
400+
'exir/passes/sym_to_tensor_pass.py',
401+
'exir/passes/weights_to_outputs_pass.py',
402+
'exir/program/_program.py',
403+
'exir/program/test/**',
404+
'exir/serde/export_serialize.py',
405+
'exir/serde/serialize.py',
406+
'exir/tensor.py',
407+
'exir/tests/**',
408+
'exir/tracer.py',
409+
'exir/verification/arg_validator.py',
410+
'exir/verification/test/**',
411+
'exir/_serialize/test/**',
383412
]
384413
command = [
385414
'python3',

.mypy.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,6 @@ follow_untyped_imports = True
103103

104104
[mypy-sympy.*]
105105
ignore_missing_imports = True
106+
107+
[mypy-executorch.exir.verification.bindings]
108+
ignore_missing_imports = True

exir/_serialize/_dataclass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,5 +141,5 @@ class Example
141141
if isinstance(T, enum.EnumMeta):
142142
data[key] = T[value]
143143
else:
144-
data[key] = T(value)
145-
return cls(**data)
144+
data[key] = T(value) # type: ignore[operator]
145+
return cls(**data) # type: ignore[operator]

exir/_serialize/_flatbuffer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,10 @@ def _run_flatc(args: Sequence[str]) -> None:
193193
subprocess.run([flatc_path] + list(args), check=True)
194194
else:
195195
# Expect the `flatc` tool to be on the system path or set as an env var.
196-
flatc_path = os.getenv("FLATC_EXECUTABLE")
197-
if not flatc_path:
198-
flatc_path = "flatc"
199-
subprocess.run([flatc_path] + list(args), check=True)
196+
flatc_executable = os.getenv("FLATC_EXECUTABLE")
197+
if not flatc_executable:
198+
flatc_executable = "flatc"
199+
subprocess.run([flatc_executable] + list(args), check=True)
200200

201201

202202
def _flatc_compile(output_dir: str, schema_path: str, json_path: str) -> None:

exir/_serialize/_named_data_store.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ def _add_named_data_to_map(
121121
if self.data_hash_to_buffer_idx.get(hashed, -1) != buffer_idx:
122122
raise ValueError(
123123
f"Duplicate key {key} with different data. "
124-
f"Existing data: {self.buffers[buffer_idx].buffer}. "
125-
f"New data: {data}."
124+
f"Existing data: {self.buffers[buffer_idx].buffer!r}. "
125+
f"New data: {data!r}." # type: ignore[str-bytes-safe]
126126
)
127127
self.buffers[buffer_idx].alignment = math.lcm(
128128
self.buffers[buffer_idx].alignment, alignment

exir/_serialize/_serialize.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# pyre-strict
88

9-
from typing import Dict, Optional, Set, Tuple
9+
from typing import Dict, List, Optional, Set, Tuple
1010

1111
from executorch.exir._serialize import _serialize_pte_binary
1212

@@ -102,10 +102,10 @@ def serialize_for_executorch(
102102
)
103103

104104
for tag in all_external_tags:
105-
buffers = []
105+
buffers: List[bytes] = []
106106
fqn_to_tensor_entry: Dict[str, TensorEntry] = {}
107107
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`.
108-
fqn_to_index = emitter_output.external_constant_map.get(tag, {})
108+
fqn_to_index = emitter_output.external_constant_map.get(tag, {}) # type: ignore[union-attr]
109109
# Create a TensorEntry for each external tensor.
110110
for fqn, index in fqn_to_index.items():
111111
assert fqn in fqn_to_tensor_layout
@@ -118,13 +118,13 @@ def serialize_for_executorch(
118118
# Extract external data.
119119
key_to_data: Dict[str, DataEntry] = {}
120120
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`.
121-
key_to_buffer_index = named_data.external_data.get(tag, {})
121+
key_to_buffer_index = named_data.external_data.get(tag, {}) # type: ignore[union-attr]
122122
for key, index in key_to_buffer_index.items():
123123
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`.
124124
key_to_data[key] = DataEntry(
125-
len(buffers), named_data.buffers[index].alignment
125+
len(buffers), named_data.buffers[index].alignment # type: ignore[union-attr]
126126
)
127-
buffers.append(named_data.buffers[index].buffer)
127+
buffers.append(named_data.buffers[index].buffer) # type: ignore[union-attr]
128128

129129
# Serialize into PTD file.
130130
ptd_files[tag] = data_serializer.serialize(

exir/capture/_capture.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,10 @@ def _capture_legacy_do_not_use(f, args) -> ExirExportedProgram:
122122
outputs=[],
123123
# pyre-fixme[6]: For 3rd argument expected `TreeSpec` but got
124124
# `Union[Tensor, Module]`.
125-
in_spec=in_spec,
125+
in_spec=in_spec, # type: ignore[arg-type]
126126
# pyre-fixme[6]: For 4th argument expected `TreeSpec` but got
127127
# `Union[Tensor, Module]`.
128-
out_spec=out_spec,
128+
out_spec=out_spec, # type: ignore[arg-type]
129129
),
130130
)
131131
],
@@ -207,7 +207,7 @@ def capture( # noqa: C901
207207
if isinstance(f, MethodType) and isinstance(f.__self__, torch.nn.Module):
208208
with patch_forward(f.__self__, f):
209209
ep = export(
210-
cast(torch.nn.Module, f.__self__),
210+
f.__self__, # type: ignore[redundant-cast]
211211
args,
212212
dynamic_shapes=dynamic_shapes,
213213
strict=True,
@@ -272,7 +272,7 @@ def graph_with_interpreter(*args):
272272
graph_with_interpreter,
273273
remove="mutations_and_views",
274274
)
275-
assert isinstance(functionalized_callable, Callable)
275+
assert callable(functionalized_callable) # type: ignore[arg-type]
276276

277277
if config.enable_dynamic_shape:
278278
fake_tensor_mode = FakeTensorMode(
@@ -357,7 +357,7 @@ def convert_to_fake(x):
357357
in_spec=in_spec,
358358
# pyre-fixme[6]: For 4th argument expected `TreeSpec` but got
359359
# `Union[None, TreeSpec, Tensor, Module]`.
360-
out_spec=out_spec,
360+
out_spec=out_spec, # type: ignore[arg-type]
361361
),
362362
)
363363
],

exir/dialects/edge/_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,8 +317,8 @@ def to_out_variant(self) -> torch._ops.OpOverload:
317317
"""
318318

319319
# return if already found
320-
if "_out_variant" in self.__dict__ and self._out_variant:
321-
return self._out_variant
320+
if "_out_variant" in self.__dict__ and self._out_variant: # type: ignore[has-type]
321+
return self._out_variant # type: ignore[has-type]
322322
out_variant = to_variant(self._op, SchemaKind.out)
323323
self._out_variant = out_variant
324324
return out_variant
@@ -359,7 +359,7 @@ def __init__(
359359
self.__name__ = self._qualified_op_name.replace("::", ".")
360360
self._op = parent_overload_packet._op
361361
self._overload_names = parent_overload_packet._overload_names
362-
self._dir = []
362+
self._dir: List[str] = []
363363

364364
def __repr__(self):
365365
return "<EdgeOpOverloadPacket(op='{}', parent_op='{}')>".format(

exir/dialects/edge/arg/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def __init__(self, argtype, argname, **kwargs):
181181
self._kw = True
182182

183183
@property
184-
def kw(self):
184+
def kw(self): # type: ignore[misc]
185185
return super().kw
186186

187187

exir/emit/_emit_program.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def _get_training_metadata(methods: Dict[str, ExportedProgram]) -> Dict[str, int
110110
found_param = True
111111
i += 1
112112
if len(fqns) > 0:
113-
training_metadata[fqn_method_prefix + name] = fqns
113+
training_metadata[fqn_method_prefix + name] = fqns # type: ignore[assignment]
114114
return training_metadata
115115

116116

@@ -139,7 +139,7 @@ def emit_program(
139139
methods = {"forward": methods}
140140

141141
# validation
142-
bad_methods = []
142+
bad_methods: List[str] = []
143143
for name, exported_program in methods.items():
144144
if not isinstance(exported_program, ExportedProgram):
145145
bad_methods.append(name)
@@ -153,6 +153,7 @@ def emit_program(
153153
debug_handle_map = {}
154154
method_to_delegate_debug_id_map = {}
155155
program_state = _ProgramState()
156+
emitter: Optional[_TopLevelEmitter] = None
156157

157158
# emit each entry point in order according to name.
158159
for name, exported_program in sorted(methods.items()):
@@ -183,14 +184,14 @@ def emit_program(
183184

184185
training_metadata = _get_training_metadata(methods)
185186
if len(training_metadata) > 0:
186-
plans.extend(emitter._emit_prim_getters(training_metadata))
187+
plans.extend(emitter._emit_prim_getters(training_metadata)) # type: ignore[union-attr]
187188

188189
# emit any primitive getters
189190
if prim_getters is not None:
190-
plans.extend(emitter._emit_prim_getters(prim_getters))
191+
plans.extend(emitter._emit_prim_getters(prim_getters)) # type: ignore[union-attr]
191192

192193
return EmitterOutput(
193-
debug_handle_map=debug_handle_map,
194+
debug_handle_map=debug_handle_map, # type: ignore[arg-type]
194195
method_to_delegate_debug_id_map=method_to_delegate_debug_id_map,
195196
program=Program(
196197
version=EXECUTORCH_SCHEMA_VERSION,

0 commit comments

Comments
 (0)