From 3645c7ec4d0923f9655e0e4da67cfc42df2c660e Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Mon, 9 Jun 2025 12:59:02 -0400 Subject: [PATCH 1/2] Typecheck 25% of exir directory --- .lintrunner.toml | 50 ++++++++++++++++++++++++++++++ .mypy.ini | 3 ++ exir/backend/backend_details.py | 6 ++-- exir/backend/partitioner.py | 4 +-- exir/common.py | 2 +- exir/control_flow.py | 6 ++-- exir/delegate.py | 8 +++-- exir/delegate.pyi | 2 +- exir/dialects/edge/dtype/runner.py | 18 ++++++++--- exir/graph_module.py | 4 +-- exir/operator/manip.py | 2 +- exir/serde/schema.py | 3 +- 12 files changed, 85 insertions(+), 23 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 8912e65d66d..fcf5b4a593f 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -315,6 +315,56 @@ include_patterns = [ # 'examples/**/*.py', 'examples/openvino/**/*.py', # 'exir/**/*.py', + # Phase 1: Start with simplest exir files (Batch 1) + 'exir/version.py', + 'exir/scalar_type.py', + 'exir/error.py', + 'exir/_warnings.py', + 'exir/types.py', + # Phase 1: Batch 2 - More utility files + 'exir/dynamic_shape.py', + 'exir/memory.py', + 'exir/dim_order_utils.py', + 'exir/wrap.py', + # Phase 1: Batch 3 - dialects subdirectory (5 files) + 'exir/dialects/__init__.py', + 'exir/dialects/_ops.py', + 'exir/dialects/backend/_ops.py', + 'exir/dialects/edge/dtype/supported.py', + 'exir/dialects/edge/dtype/utils.py', + # Phase 1: Batch 3+ - operator utility + 'exir/operator/util.py', + # Phase 1: Batch 4 - More subdirectories (6 files) + 'exir/program/__init__.py', + 'exir/program/_fake_program.py', + 'exir/emit/__init__.py', + 'exir/capture/__init__.py', + 'exir/capture/_config.py', + 'exir/verification/dev_html.py', + # Phase 1: Batch 5 - Fixed problematic files (3 files) + 'exir/operator/manip.py', + 'exir/dialects/edge/dtype/runner.py', + 'exir/serde/schema_check.py', + # Phase 1: Batch 6 - Final root-level fixes (3 files) + 'exir/common.py', + 'exir/sym_util.py', + 'exir/graph_module.py', + # Phase 1: Batch 7 - Clean files + fixed files (7 files) + 'exir/schema.py', + 'exir/print_program.py', + 'exir/pass_manager.py', + 'exir/graph.py', + 'exir/control_flow.py', + 'exir/delegate.py', + 'exir/backend/partitioner.py', + # Phase 1: Batch 8 - Clean files to reach 25% coverage (7 files) + 'exir/__init__.py', + 'exir/capture/_unlift.py', + 'exir/serde/__init__.py', + 'exir/serde/union.py', + 'exir/serde/schema.py', + 'exir/_serialize/__init__.py', + 'exir/_serialize/padding.py', # 'extension/**/*.py', 'kernels/**/*.py', 'profiler/**/*.py', diff --git a/.mypy.ini b/.mypy.ini index cd14cbac7ea..598da33371c 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -100,3 +100,6 @@ ignore_missing_imports = True [mypy-torchao.*] follow_untyped_imports = True + +[mypy-sympy.*] +ignore_missing_imports = True diff --git a/exir/backend/backend_details.py b/exir/backend/backend_details.py index 513ae7c64b3..afabd1f3594 100644 --- a/exir/backend/backend_details.py +++ b/exir/backend/backend_details.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from executorch.exir._serialize._named_data_store import NamedDataStoreOutput @@ -15,8 +15,8 @@ from torch.export.exported_program import ExportedProgram -def enforcedmethod(func): - func.__enforcedmethod__ = True +def enforcedmethod(func: Callable[..., Any]) -> Callable[..., Any]: + func.__enforcedmethod__ = True # type: ignore[attr-defined] return func diff --git a/exir/backend/partitioner.py b/exir/backend/partitioner.py index 68d5c246906..2eefbc1ab27 100644 --- a/exir/backend/partitioner.py +++ b/exir/backend/partitioner.py @@ -59,7 +59,7 @@ class Partitioner(ABC): def __init__( self, spec: Mapping[Union[str, int, float, bool], object] = MappingProxyType({}), - ): + ) -> None: self._spec = spec def __call__(self, exported_program: ExportedProgram) -> PartitionResult: @@ -69,7 +69,7 @@ def __call__(self, exported_program: ExportedProgram) -> PartitionResult: def spec(self) -> Mapping[Union[str, int, float, bool], object]: return self._spec - @enforcedmethod + @enforcedmethod # type: ignore[misc] @abstractmethod def partition(self, exported_program: ExportedProgram) -> PartitionResult: """ diff --git a/exir/common.py b/exir/common.py index 98daac9a82c..fa42a2bd5dd 100644 --- a/exir/common.py +++ b/exir/common.py @@ -104,9 +104,9 @@ def override_logger( try: oldLevel = logging.root.level logging.root.setLevel(newLevel) + oldFormatters = [] if fmtstr: newformatter = logging.Formatter(fmtstr, None, "%") - oldFormatters = [] for handler in logging.root.handlers: oldFormatters.append(handler.formatter) handler.formatter = newformatter diff --git a/exir/control_flow.py b/exir/control_flow.py index ff8016bd23e..64c0a8e9bbb 100644 --- a/exir/control_flow.py +++ b/exir/control_flow.py @@ -103,7 +103,7 @@ def _make_submodule( f"Expect function '{fn.__name__}' to be decorated with tracing_context.", ) # pyre-ignore - args = fn.__tracing_inputs__ + args = fn.__tracing_inputs__ # type: ignore[attr-defined] # TODO(yidi): we don't want to enable here because we are not gonna use this code path in the future anyways gm, _ = flattened_dispatch_trace(fn, args, set(), enable_functionalization=False) output = next(iter(reversed(gm.graph.nodes))) @@ -122,7 +122,7 @@ def _make_submodule( output.args = tuple(output.args[0]) gm.recompile() # pyre-fixme[16]: `GraphModule` has no attribute `__tracing_inputs__`. - gm.__tracing_inputs__ = args + gm.__tracing_inputs__ = args # type: ignore[attr-defined] return gm @@ -198,7 +198,7 @@ def wrapper( return f(*args) - wrapper.__tracing_inputs__ = inputs # pyre-ignore + wrapper.__tracing_inputs__ = inputs # type: ignore[attr-defined] return wrapper return decorator diff --git a/exir/delegate.py b/exir/delegate.py index 694ea6fa32f..c36fdc3f33b 100644 --- a/exir/delegate.py +++ b/exir/delegate.py @@ -42,9 +42,11 @@ LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule" # pyre-ignore - def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args): + def trace_call_delegate( + proxy_mode: Any, func_overload: Any, lowered_module: Any, *args: Any + ) -> Any: # pyre-ignore - def _unwrap_proxy(e): + def _unwrap_proxy(e: Any) -> Any: if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)): return e return get_proxy_slot( @@ -151,7 +153,7 @@ def is_lowered_module(obj: Any) -> bool: def get_lowered_module_name( root: torch.nn.Module, # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type. - lowered_module: LOWERED_BACKEND_MODULE_TYPE, # noqa + lowered_module: Any, # noqa ) -> str: """ Adds the given lowered_module into the given root module and returns the diff --git a/exir/delegate.pyi b/exir/delegate.pyi index 8a2a7d16b9f..31fdfebd267 100644 --- a/exir/delegate.pyi +++ b/exir/delegate.pyi @@ -17,5 +17,5 @@ def is_lowered_module(obj: Any) -> bool: ... def get_lowered_module_name( root: torch.nn.Module, # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type. - lowered_module: LOWERED_BACKEND_MODULE_TYPE, # noqa + lowered_module: Any, # noqa ) -> str: ... diff --git a/exir/dialects/edge/dtype/runner.py b/exir/dialects/edge/dtype/runner.py index 67982a164e2..497382c9023 100644 --- a/exir/dialects/edge/dtype/runner.py +++ b/exir/dialects/edge/dtype/runner.py @@ -30,7 +30,7 @@ def _get_types(inputs: Dict[str, List[BaseArg]]) -> List[ArgType]: @staticmethod def _get_args_kwargs( inputs: Dict[str, List[BaseArg]], - dtypes: Tuple[Optional[torch.dtype]], + dtypes: Tuple[Optional[torch.dtype], ...], mode: ArgMode, ) -> Tuple[List[BaseArg], Dict[str, BaseKwarg]]: """Construct args and kwargs for op given dtypes.""" @@ -71,16 +71,20 @@ def run_dtypes( self, name: str, inputs: Dict[str, List[BaseArg]], - dtypes: Tuple[Optional[torch.dtype]], + dtypes: Tuple[Optional[torch.dtype], ...], argmode: ArgMode = ArgMode.RANDOM, ) -> Tuple[ - bool, str, Tuple[Optional[torch.dtype]], List[BaseArg], Dict[str, BaseKwarg] + bool, + str, + Tuple[Optional[torch.dtype], ...], + List[BaseArg], + Dict[str, BaseKwarg], ]: args, kwargs = DtypeRunner._get_args_kwargs(inputs, dtypes, argmode) op = get_callable(name) try: res = op(*args, **kwargs) - ret_dtypes = () + ret_dtypes: Tuple[torch.dtype, ...] = () if "returns" in inputs: ret_dtypes = tuple(extract_return_dtype(res, inputs["returns"])) return (True, name, dtypes + ret_dtypes, args, kwargs) @@ -112,7 +116,11 @@ def run( argmode: ArgMode = ArgMode.ONES, ) -> List[ Tuple[ - bool, str, Tuple[Optional[torch.dtype]], List[BaseArg], Dict[str, BaseKwarg] + bool, + str, + Tuple[Optional[torch.dtype], ...], + List[BaseArg], + Dict[str, BaseKwarg], ] ]: results = [] diff --git a/exir/graph_module.py b/exir/graph_module.py index e26d22d8145..18afce81a4f 100644 --- a/exir/graph_module.py +++ b/exir/graph_module.py @@ -42,7 +42,7 @@ def _get_submodule( assert submod_node.op == "get_attr" assert isinstance(submod_node.target, str) submodule = graph_module.get_submodule(submod_node.target) - # pyre-ignore + assert isinstance(submodule, torch.nn.Module) return submod_node.target, submodule, node @@ -67,7 +67,7 @@ def get_control_flow_submodules( if node.target is torch.ops.higher_order.map_impl: control_flow_submodules.append(_get_submodule(graph_module, node, 0)) - return control_flow_submodules + return control_flow_submodules # type: ignore[return-value] def bfs_trace_with_node_process( diff --git a/exir/operator/manip.py b/exir/operator/manip.py index 8c27b12bd0b..a8807cd68cd 100644 --- a/exir/operator/manip.py +++ b/exir/operator/manip.py @@ -68,7 +68,7 @@ def wrapper(*args: TensorSpec, **kwargs: TensorSpec) -> Dict[str, TensorSpec]: def wrapper(get_scratch_metas_fn: ScratchCallableType) -> ScratchCallableType: # pyre-fixme[16]: `OpOverload` has no attribute `get_scratch_metas`. - out_variant.get_scratch_metas = adapt_return_value(get_scratch_metas_fn) + out_variant.get_scratch_metas = adapt_return_value(get_scratch_metas_fn) # type: ignore[attr-defined] return get_scratch_metas_fn return wrapper diff --git a/exir/serde/schema.py b/exir/serde/schema.py index 6d250ee7923..c3263795b7d 100644 --- a/exir/serde/schema.py +++ b/exir/serde/schema.py @@ -383,9 +383,8 @@ class ExportedProgram: opset_version: Dict[str, int] range_constraints: Dict[str, RangeConstraint] schema_version: SchemaVersion - dialect: str verifiers: List[str] = field(default_factory=list) - dialect: str = "" # TODO deprecated + dialect: str = "" # TODO deprecated @dataclass From 27a1da6fa718d7611d3fe3a1d2c6b6d9881dbda6 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Mon, 9 Jun 2025 16:02:15 -0400 Subject: [PATCH 2/2] Typecheck 42% of exir directory --- .lintrunner.toml | 96 +++++++++---------- .mypy.ini | 3 + exir/_serialize/_dataclass.py | 4 +- exir/_serialize/_flatbuffer.py | 8 +- exir/_serialize/_named_data_store.py | 4 +- exir/_serialize/_serialize.py | 12 +-- exir/backend/backend_api.py | 16 ++-- .../config_partitioner.py | 7 +- .../duplicate_constant_node_pass.py | 4 +- exir/capture/_capture.py | 10 +- exir/dialects/edge/_ops.py | 6 +- exir/dialects/edge/arg/model.py | 2 +- exir/emit/_emit_program.py | 11 ++- exir/operator/convert.py | 33 ++++--- exir/program/_program.py | 32 +++---- exir/verification/arg_validator.py | 8 +- exir/verification/interpreter.py | 14 +-- exir/verification/verifier.py | 8 +- 18 files changed, 136 insertions(+), 142 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index fcf5b4a593f..418256a8fbf 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -312,59 +312,9 @@ include_patterns = [ # 'devtools/**/*.py', 'devtools/visualization/**/*.py', 'docs/**/*.py', + 'exir/**/*.py', # 'examples/**/*.py', 'examples/openvino/**/*.py', - # 'exir/**/*.py', - # Phase 1: Start with simplest exir files (Batch 1) - 'exir/version.py', - 'exir/scalar_type.py', - 'exir/error.py', - 'exir/_warnings.py', - 'exir/types.py', - # Phase 1: Batch 2 - More utility files - 'exir/dynamic_shape.py', - 'exir/memory.py', - 'exir/dim_order_utils.py', - 'exir/wrap.py', - # Phase 1: Batch 3 - dialects subdirectory (5 files) - 'exir/dialects/__init__.py', - 'exir/dialects/_ops.py', - 'exir/dialects/backend/_ops.py', - 'exir/dialects/edge/dtype/supported.py', - 'exir/dialects/edge/dtype/utils.py', - # Phase 1: Batch 3+ - operator utility - 'exir/operator/util.py', - # Phase 1: Batch 4 - More subdirectories (6 files) - 'exir/program/__init__.py', - 'exir/program/_fake_program.py', - 'exir/emit/__init__.py', - 'exir/capture/__init__.py', - 'exir/capture/_config.py', - 'exir/verification/dev_html.py', - # Phase 1: Batch 5 - Fixed problematic files (3 files) - 'exir/operator/manip.py', - 'exir/dialects/edge/dtype/runner.py', - 'exir/serde/schema_check.py', - # Phase 1: Batch 6 - Final root-level fixes (3 files) - 'exir/common.py', - 'exir/sym_util.py', - 'exir/graph_module.py', - # Phase 1: Batch 7 - Clean files + fixed files (7 files) - 'exir/schema.py', - 'exir/print_program.py', - 'exir/pass_manager.py', - 'exir/graph.py', - 'exir/control_flow.py', - 'exir/delegate.py', - 'exir/backend/partitioner.py', - # Phase 1: Batch 8 - Clean files to reach 25% coverage (7 files) - 'exir/__init__.py', - 'exir/capture/_unlift.py', - 'exir/serde/__init__.py', - 'exir/serde/union.py', - 'exir/serde/schema.py', - 'exir/_serialize/__init__.py', - 'exir/_serialize/padding.py', # 'extension/**/*.py', 'kernels/**/*.py', 'profiler/**/*.py', @@ -377,9 +327,49 @@ include_patterns = [ exclude_patterns = [ 'third-party/**', '**/third-party/**', - 'scripts/check_binary_dependencies.py', - 'profiler/test/test_profiler_e2e.py', 'backends/arm/test/**', + # exir exclusions (sorted alphabetically) + 'exir/_serialize/test/**', + 'exir/backend/test/**', + 'exir/backend/utils.py', + 'exir/dialects/backend/test/**', + 'exir/dialects/edge/arg/model.py', + 'exir/dialects/edge/op/test/**', + 'exir/dialects/edge/spec/**', + 'exir/dialects/edge/test/**', + 'exir/dialects/test/**', + 'exir/emit/_emitter.py', + 'exir/emit/test/**', + 'exir/lowered_backend_module.py', + 'exir/memory_planning.py', + 'exir/operator/test/**', + 'exir/pass_base.py', + 'exir/passes/__init__.py', + 'exir/passes/_quant_patterns_and_replacements.py', + 'exir/passes/const_prop_pass.py', + 'exir/passes/constant_prop_pass.py', + 'exir/passes/dynamic_shape_prop_pass.py', + 'exir/passes/executorch_prim_ops_registry.py', + 'exir/passes/memory_planning_pass.py', + 'exir/passes/prune_empty_tensors_pass.py', + 'exir/passes/quant_fusion_pass.py', + 'exir/passes/quantize_io_pass.py', + 'exir/passes/remove_mixed_type_operators.py', + 'exir/passes/remove_noop_pass.py', + 'exir/passes/replace_view_copy_with_view_pass.py', + 'exir/passes/spec_prop_pass.py', + 'exir/passes/sym_shape_eval_pass.py', + 'exir/passes/sym_to_tensor_pass.py', + 'exir/passes/weights_to_outputs_pass.py', + 'exir/program/test/**', + 'exir/serde/export_serialize.py', + 'exir/serde/serialize.py', + 'exir/tensor.py', + 'exir/tests/**', + 'exir/tracer.py', + 'exir/verification/test/**', + 'profiler/test/test_profiler_e2e.py', + 'scripts/check_binary_dependencies.py', ] command = [ 'python3', diff --git a/.mypy.ini b/.mypy.ini index 598da33371c..61dc777d3a9 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -103,3 +103,6 @@ follow_untyped_imports = True [mypy-sympy.*] ignore_missing_imports = True + +[mypy-executorch.exir.verification.bindings] +ignore_missing_imports = True diff --git a/exir/_serialize/_dataclass.py b/exir/_serialize/_dataclass.py index 013d733bcda..04e58f96e0a 100644 --- a/exir/_serialize/_dataclass.py +++ b/exir/_serialize/_dataclass.py @@ -141,5 +141,5 @@ class Example if isinstance(T, enum.EnumMeta): data[key] = T[value] else: - data[key] = T(value) - return cls(**data) + data[key] = T(value) # type: ignore[operator] + return cls(**data) # type: ignore[operator] diff --git a/exir/_serialize/_flatbuffer.py b/exir/_serialize/_flatbuffer.py index 4599249f00c..5226de0bd9e 100644 --- a/exir/_serialize/_flatbuffer.py +++ b/exir/_serialize/_flatbuffer.py @@ -193,10 +193,10 @@ def _run_flatc(args: Sequence[str]) -> None: subprocess.run([flatc_path] + list(args), check=True) else: # Expect the `flatc` tool to be on the system path or set as an env var. - flatc_path = os.getenv("FLATC_EXECUTABLE") - if not flatc_path: - flatc_path = "flatc" - subprocess.run([flatc_path] + list(args), check=True) + flatc_executable = os.getenv("FLATC_EXECUTABLE") + if not flatc_executable: + flatc_executable = "flatc" + subprocess.run([flatc_executable] + list(args), check=True) def _flatc_compile(output_dir: str, schema_path: str, json_path: str) -> None: diff --git a/exir/_serialize/_named_data_store.py b/exir/_serialize/_named_data_store.py index 2c2d975937e..12f627d5744 100644 --- a/exir/_serialize/_named_data_store.py +++ b/exir/_serialize/_named_data_store.py @@ -121,8 +121,8 @@ def _add_named_data_to_map( if self.data_hash_to_buffer_idx.get(hashed, -1) != buffer_idx: raise ValueError( f"Duplicate key {key} with different data. " - f"Existing data: {self.buffers[buffer_idx].buffer}. " - f"New data: {data}." + f"Existing data: {self.buffers[buffer_idx].buffer!r}. " + f"New data: {data!r}." # type: ignore[str-bytes-safe] ) self.buffers[buffer_idx].alignment = math.lcm( self.buffers[buffer_idx].alignment, alignment diff --git a/exir/_serialize/_serialize.py b/exir/_serialize/_serialize.py index 1b36dac1743..eb748eb9ab8 100644 --- a/exir/_serialize/_serialize.py +++ b/exir/_serialize/_serialize.py @@ -6,7 +6,7 @@ # pyre-strict -from typing import Dict, Optional, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple from executorch.exir._serialize import _serialize_pte_binary @@ -102,10 +102,10 @@ def serialize_for_executorch( ) for tag in all_external_tags: - buffers = [] + buffers: List[bytes] = [] fqn_to_tensor_entry: Dict[str, TensorEntry] = {} # pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`. - fqn_to_index = emitter_output.external_constant_map.get(tag, {}) + fqn_to_index = emitter_output.external_constant_map.get(tag, {}) # type: ignore[union-attr] # Create a TensorEntry for each external tensor. for fqn, index in fqn_to_index.items(): assert fqn in fqn_to_tensor_layout @@ -118,13 +118,13 @@ def serialize_for_executorch( # Extract external data. key_to_data: Dict[str, DataEntry] = {} # pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`. - key_to_buffer_index = named_data.external_data.get(tag, {}) + key_to_buffer_index = named_data.external_data.get(tag, {}) # type: ignore[union-attr] for key, index in key_to_buffer_index.items(): # pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`. key_to_data[key] = DataEntry( - len(buffers), named_data.buffers[index].alignment + len(buffers), named_data.buffers[index].alignment # type: ignore[union-attr] ) - buffers.append(named_data.buffers[index].buffer) + buffers.append(named_data.buffers[index].buffer) # type: ignore[union-attr] # Serialize into PTD file. ptd_files[tag] = data_serializer.serialize( diff --git a/exir/backend/backend_api.py b/exir/backend/backend_api.py index 91df0409051..41abd654454 100644 --- a/exir/backend/backend_api.py +++ b/exir/backend/backend_api.py @@ -123,7 +123,7 @@ def to_backend( compile_specs=compile_specs, named_data_store_output=preprocess_result.data_store_output, ) - lowered_module.meta = { + lowered_module.meta = { # type: ignore[assignment] "debug_handle_map": preprocess_result.debug_handle_map } return lowered_module @@ -311,7 +311,7 @@ def _partition_and_lower_one_graph_module( is_submodule, ) - lowered_submodule = to_backend( + lowered_submodule = to_backend( # type: ignore[call-arg] delegation_spec.backend_id, submodule_program, delegation_spec.compile_specs, @@ -449,7 +449,7 @@ def _create_partitions_in_graph_module( owning_program: ExportedProgram, is_submodule: bool, ) -> Dict[str, List[torch.fx.Node]]: - backend_id_to_submodule_name = {} + backend_id_to_submodule_name: Dict[str, List[str]] = {} for tag, delegation_spec in partition_result.partition_tags.items(): # Create partition with nodes containing this tag. There should only be # one contained submodule per tag @@ -517,10 +517,12 @@ def _create_partitions_in_graph_module( # in future edits to the graph. As a result, we just keep track of the node's name # and at the end we search for this node in our final graph module backend_id_to_submodule_name[delegation_spec.backend_id].append( - call_module_node.target + call_module_node.target # type: ignore[arg-type] ) - created_submodule_nodes = {key: [] for key in backend_id_to_submodule_name.keys()} + created_submodule_nodes: Dict[str, List[torch.fx.Node]] = { + key: [] for key in backend_id_to_submodule_name.keys() + } for backend_id, submodule_name in backend_id_to_submodule_name.items(): for node in tagged_graph_module.graph.nodes: if node.op == "call_module" and node.target in submodule_name: @@ -615,7 +617,7 @@ def lower_all_submodules_to_backend( compile_specs=compile_spec, named_data_store_output=preprocess_result.data_store_output, ) - lowered_module.meta = { + lowered_module.meta = { # type: ignore[assignment] "debug_handle_map": preprocess_result.debug_handle_map, } is_submodule = call_submodule_node.meta["is_submodule"] @@ -698,7 +700,7 @@ def to_backend( method_to_partitioner = method_edge_program_partitioners.method_to_partitioner partitioned_and_lowered_exported_programs = {} - backend_id_to_method_submodules_map = {} + backend_id_to_method_submodules_map: Dict[str, Dict[str, List[torch.fx.Node]]] = {} method_to_tagged_exported_program = {} for method_name, partitioner_instance in method_to_partitioner.items(): diff --git a/exir/backend/canonical_partitioners/config_partitioner.py b/exir/backend/canonical_partitioners/config_partitioner.py index 1a9bcc33e80..b7903435e42 100644 --- a/exir/backend/canonical_partitioners/config_partitioner.py +++ b/exir/backend/canonical_partitioners/config_partitioner.py @@ -52,10 +52,9 @@ class PartitionerConfig(ABC): the specified backend. """ - @classmethod - @property + @property # type: ignore[misc] @abstractmethod - def target_name(cls) -> str: + def target_name(self) -> str: """ Target name for this partitioner config. When the Config-Based Partitioner encounters a node with a matching target name, it uses this config's methods to @@ -138,7 +137,7 @@ def filter_fn(node: torch.fx.Node) -> bool: """ if node.op != "call_function": return False - target_name = format_target_name(node.target.__name__) # pyre-ignore + target_name = format_target_name(node.target.__name__) # type: ignore[union-attr] if target_name in self.target_partitioner_configs: config = self.target_partitioner_configs[target_name] diff --git a/exir/backend/canonical_partitioners/duplicate_constant_node_pass.py b/exir/backend/canonical_partitioners/duplicate_constant_node_pass.py index 961bd741205..00fb32175fe 100644 --- a/exir/backend/canonical_partitioners/duplicate_constant_node_pass.py +++ b/exir/backend/canonical_partitioners/duplicate_constant_node_pass.py @@ -28,9 +28,9 @@ def _get_attribute_or_constants( if maybe_param is not None: constant_or_attribute = maybe_param elif maybe_buffer is not None: - constant_or_attribute = maybe_buffer + constant_or_attribute = maybe_buffer # type: ignore[assignment] elif maybe_lifted_tensor is not None: - constant_or_attribute = maybe_lifted_tensor + constant_or_attribute = maybe_lifted_tensor # type: ignore[assignment] return constant_or_attribute diff --git a/exir/capture/_capture.py b/exir/capture/_capture.py index 975191f0744..c1ea39bd12c 100644 --- a/exir/capture/_capture.py +++ b/exir/capture/_capture.py @@ -122,10 +122,10 @@ def _capture_legacy_do_not_use(f, args) -> ExirExportedProgram: outputs=[], # pyre-fixme[6]: For 3rd argument expected `TreeSpec` but got # `Union[Tensor, Module]`. - in_spec=in_spec, + in_spec=in_spec, # type: ignore[arg-type] # pyre-fixme[6]: For 4th argument expected `TreeSpec` but got # `Union[Tensor, Module]`. - out_spec=out_spec, + out_spec=out_spec, # type: ignore[arg-type] ), ) ], @@ -207,7 +207,7 @@ def capture( # noqa: C901 if isinstance(f, MethodType) and isinstance(f.__self__, torch.nn.Module): with patch_forward(f.__self__, f): ep = export( - cast(torch.nn.Module, f.__self__), + f.__self__, # type: ignore[redundant-cast] args, dynamic_shapes=dynamic_shapes, strict=True, @@ -272,7 +272,7 @@ def graph_with_interpreter(*args): graph_with_interpreter, remove="mutations_and_views", ) - assert isinstance(functionalized_callable, Callable) + assert callable(functionalized_callable) # type: ignore[arg-type] if config.enable_dynamic_shape: fake_tensor_mode = FakeTensorMode( @@ -357,7 +357,7 @@ def convert_to_fake(x): in_spec=in_spec, # pyre-fixme[6]: For 4th argument expected `TreeSpec` but got # `Union[None, TreeSpec, Tensor, Module]`. - out_spec=out_spec, + out_spec=out_spec, # type: ignore[arg-type] ), ) ], diff --git a/exir/dialects/edge/_ops.py b/exir/dialects/edge/_ops.py index 1915bf3d318..fcb6b973deb 100644 --- a/exir/dialects/edge/_ops.py +++ b/exir/dialects/edge/_ops.py @@ -317,8 +317,8 @@ def to_out_variant(self) -> torch._ops.OpOverload: """ # return if already found - if "_out_variant" in self.__dict__ and self._out_variant: - return self._out_variant + if "_out_variant" in self.__dict__ and self._out_variant: # type: ignore[has-type] + return self._out_variant # type: ignore[has-type] out_variant = to_variant(self._op, SchemaKind.out) self._out_variant = out_variant return out_variant @@ -359,7 +359,7 @@ def __init__( self.__name__ = self._qualified_op_name.replace("::", ".") self._op = parent_overload_packet._op self._overload_names = parent_overload_packet._overload_names - self._dir = [] + self._dir: List[str] = [] def __repr__(self): return "".format( diff --git a/exir/dialects/edge/arg/model.py b/exir/dialects/edge/arg/model.py index b7a81f62c6d..1aafde83d71 100644 --- a/exir/dialects/edge/arg/model.py +++ b/exir/dialects/edge/arg/model.py @@ -181,7 +181,7 @@ def __init__(self, argtype, argname, **kwargs): self._kw = True @property - def kw(self): + def kw(self): # type: ignore[misc] return super().kw diff --git a/exir/emit/_emit_program.py b/exir/emit/_emit_program.py index f456626feed..eaf152cfac8 100644 --- a/exir/emit/_emit_program.py +++ b/exir/emit/_emit_program.py @@ -110,7 +110,7 @@ def _get_training_metadata(methods: Dict[str, ExportedProgram]) -> Dict[str, int found_param = True i += 1 if len(fqns) > 0: - training_metadata[fqn_method_prefix + name] = fqns + training_metadata[fqn_method_prefix + name] = fqns # type: ignore[assignment] return training_metadata @@ -139,7 +139,7 @@ def emit_program( methods = {"forward": methods} # validation - bad_methods = [] + bad_methods: List[str] = [] for name, exported_program in methods.items(): if not isinstance(exported_program, ExportedProgram): bad_methods.append(name) @@ -153,6 +153,7 @@ def emit_program( debug_handle_map = {} method_to_delegate_debug_id_map = {} program_state = _ProgramState() + emitter: Optional[_TopLevelEmitter] = None # emit each entry point in order according to name. for name, exported_program in sorted(methods.items()): @@ -183,14 +184,14 @@ def emit_program( training_metadata = _get_training_metadata(methods) if len(training_metadata) > 0: - plans.extend(emitter._emit_prim_getters(training_metadata)) + plans.extend(emitter._emit_prim_getters(training_metadata)) # type: ignore[union-attr] # emit any primitive getters if prim_getters is not None: - plans.extend(emitter._emit_prim_getters(prim_getters)) + plans.extend(emitter._emit_prim_getters(prim_getters)) # type: ignore[union-attr] return EmitterOutput( - debug_handle_map=debug_handle_map, + debug_handle_map=debug_handle_map, # type: ignore[arg-type] method_to_delegate_debug_id_map=method_to_delegate_debug_id_map, program=Program( version=EXECUTORCH_SCHEMA_VERSION, diff --git a/exir/operator/convert.py b/exir/operator/convert.py index 74bd686c542..e7d693cf1bc 100644 --- a/exir/operator/convert.py +++ b/exir/operator/convert.py @@ -86,12 +86,12 @@ def _get_overload_schema(op_overload: OpOverload) -> Optional[FunctionSchema]: native_schema = _op_overload_to_schema_cache.get(op_overload) if not native_schema: native_schema = _pybind_schema_to_native_schema(op_overload._schema) - _op_overload_to_schema_cache[op_overload] = native_schema # pyre-ignore + _op_overload_to_schema_cache[op_overload] = native_schema # type: ignore[assignment] return native_schema def get_out_args_from_opoverload(op_overload: OpOverload) -> Tuple[str]: - return get_out_args_from_schema(_get_overload_schema(op_overload)) # pyre-ignore + return get_out_args_from_schema(_get_overload_schema(op_overload)) # type: ignore[arg-type] def get_out_args_from_schema(out_var_schema: FunctionSchema) -> Tuple[str]: @@ -102,7 +102,7 @@ def get_out_args_from_schema(out_var_schema: FunctionSchema) -> Tuple[str]: assert ( out_var_schema.is_out_fn() ), f"Expect an out variant, but get: {out_var_schema}" - return tuple(arg.name for arg in out_var_schema.arguments.out) + return tuple(arg.name for arg in out_var_schema.arguments.out) # type: ignore[return-value] def parse_qualified_opname(qualified_opname: str) -> Tuple[str, str]: @@ -113,7 +113,7 @@ def parse_qualified_opname(qualified_opname: str) -> Tuple[str, str]: ns_and_opname = qualified_opname.split("::") if len(ns_and_opname) != 2: raise RuntimeError(f"Invalid qualified_opname {qualified_opname}") - return tuple(ns_and_opname) + return tuple(ns_and_opname) # type: ignore[return-value] def get_op_overload(qualified_opname: str, overload: str) -> OpOverload: @@ -146,19 +146,19 @@ def set_mapping_for_op(op: OpOverload) -> None: """ native_schema = _pybind_schema_to_native_schema(op._schema) # pyre-fixme[16]: `Optional` has no attribute `kind`. - assert native_schema.kind() in ( + assert native_schema.kind() in ( # type: ignore[union-attr] SchemaKind.functional, SchemaKind.out, SchemaKind.mutable, ) assert not ( - native_schema.kind() == SchemaKind.functional and op in _func_to_out_variant_map + native_schema.kind() == SchemaKind.functional and op in _func_to_out_variant_map # type: ignore[union-attr] ) assert not ( - native_schema.kind() == SchemaKind.out and op in _out_variant_to_scratch_map + native_schema.kind() == SchemaKind.out and op in _out_variant_to_scratch_map # type: ignore[union-attr] ) assert not ( - native_schema.kind() == SchemaKind.mutable and op in _mutable_to_out_variant_map + native_schema.kind() == SchemaKind.mutable and op in _mutable_to_out_variant_map # type: ignore[union-attr] ) qualified_opname = str(op._schema.name) @@ -186,8 +186,7 @@ def set_mapping_for_op(op: OpOverload) -> None: signature = dataclasses.replace(signature, returns=()) kind = schema.kind() - # pyre-fixme[6]: For 1st argument expected `str` but got `FunctionSchema`. - group_by_kind = group_by_signature.setdefault(signature, {}) + group_by_kind = group_by_signature.setdefault(signature, {}) # type: ignore[arg-type] assert ( kind not in group_by_kind ), f"Schema of kind {kind} already exist for {schema}" @@ -237,14 +236,14 @@ def to_out_variant(op_overload: OpOverload) -> Tuple[OpOverload, Tuple[str]]: arguments. """ schema = _get_overload_schema(op_overload) - if schema.is_out_fn(): # pyre-ignore - return op_overload, get_out_args_from_schema(schema) # pyre-ignore[6] + if schema.is_out_fn(): # type: ignore[union-attr] + return op_overload, get_out_args_from_schema(schema) # type: ignore[arg-type] # should be a functionalish op here assert ( - schema.kind() == SchemaKind.functional # pyre-ignore[16] - or schema.kind() == SchemaKind.mutable - ), f"Expect a functionalish op, but get {schema.kind()} {schema}" + schema.kind() == SchemaKind.functional # type: ignore[union-attr] + or schema.kind() == SchemaKind.mutable # type: ignore[union-attr] + ), f"Expect a functionalish op, but get {schema.kind()} {schema}" # type: ignore[union-attr] if ( op_overload not in _func_to_out_variant_map @@ -279,8 +278,8 @@ def to_scratch_op(op_overload: OpOverload) -> Optional[OpOverload]: # pass. Return immediately rather than throwing an exception since the user must have ignores # errors for some reason (e.g. desigin some special unit tests, or unblock new # use cases). - if schema.kind() != SchemaKind.out: # pyre-ignore - logging.debug(f"Expect an out variant op as input, got: {schema.kind()}") + if schema.kind() != SchemaKind.out: # type: ignore[union-attr] + logging.debug(f"Expect an out variant op as input, got: {schema.kind()}") # type: ignore[union-attr] return None if op_overload not in _out_variant_to_scratch_map: diff --git a/exir/program/_program.py b/exir/program/_program.py index b9fa83a668f..339e3f728bc 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -104,7 +104,7 @@ from torch.library import Library try: - from executorch.exir.program.fb.logger import et_logger + from executorch.exir.program.fb.logger import et_logger # type: ignore[import-not-found] except ImportError: # Define a stub decorator that does nothing def et_logger(api_name: str) -> Callable[[Any], Any]: @@ -122,9 +122,9 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: edge_no_decomp_namespace = "EDGE_DO_NOT_DECOMP" lib = Library(edge_no_decomp_namespace, "DEF") # Map from aten ops to the transformed ops registered in the edge_no_decomp_namespace. -aten_op_to_transform_op = {} +aten_op_to_transform_op: Dict[Any, Any] = {} # Map from the transformed ops registered in the edge_no_decomp_namespace to aten ops. -transform_op_to_aten_op = {} +transform_op_to_aten_op: Dict[Any, Any] = {} def _get_updated_range_constraints(gm): @@ -181,7 +181,7 @@ def _get_updated_graph_signature( old_input_spec.arg if isinstance(old_input_spec.arg, ConstantArgument) # pyre-fixme[20]: Argument `class_fqn` expected. - else type(old_input_spec.arg)(node.name) + else type(old_input_spec.arg)(node.name) # type: ignore[call-arg] ) new_input_specs.append( InputSpec( @@ -206,7 +206,7 @@ def _get_updated_graph_signature( old_output_spec.arg if isinstance(old_output_spec.arg, ConstantArgument) # pyre-fixme[20]: Argument `class_fqn` expected. - else type(old_output_spec.arg)(node.name) + else type(old_output_spec.arg)(node.name) # type: ignore[call-arg] ) new_output_specs.append( OutputSpec(old_output_spec.kind, arg, old_output_spec.target) @@ -502,7 +502,7 @@ def to_executorch( # Existing user passes dont use run so Im just cheating here because they dont need to work on mutable buffers yet. # After exir.capture is gone I will clean up the memory planning infra to be consistent. # Frankly all of exir has big code quality issues because of the migrations that need to be addressed. - new_gm_res = config.memory_planning_pass(new_gm) # pyre-ignore[29] + new_gm_res = config.memory_planning_pass(new_gm) # type: ignore[operator] assert new_gm_res is not None new_gm = new_gm_res.graph_module new_prog = ExirExportedProgram( @@ -750,7 +750,7 @@ def pre_memory_planning_passes( if not name: sym_shape_eval_pass = default_pass # pyre-ignore: Undefined attribute [16] - sym_shape_eval_pass = config.sym_shape_eval_pass.get(name, default_pass) + sym_shape_eval_pass = config.sym_shape_eval_pass.get(name, default_pass) # type: ignore[arg-type] elif isinstance(config.sym_shape_eval_pass, PassBase): sym_shape_eval_pass = config.sym_shape_eval_pass else: @@ -780,7 +780,7 @@ def edge_to_executorch_passes( Get the pre memory planning passes based on the method name, if the pass is not in the dict, use the default pass. """ passes: List[PassType] = [ - *config.passes, + *config.passes, # type: ignore[assignment] SpecPropPass(), # ExecuTorch backend ops are unable to handle unbacked symints. So after # this pass, passes cannot be Interpreter-based, because it will fail if @@ -827,9 +827,9 @@ def _generate_edge_program( ) # TODO move inside aten_to_edge passes after all users are migrated off v1 capture passes.extend(pre_op_replace_passes) if config._use_edge_ops: - passes.append(OpReplacePass()) + passes.append(OpReplacePass()) # type: ignore[arg-type] if not config._skip_dim_order: - passes.append(MemoryFormatOpsPass()) + passes.append(MemoryFormatOpsPass()) # type: ignore[arg-type] for p in passes: gm_res = p(gm) @@ -1176,7 +1176,7 @@ def collect_named_data_store_outputs( collect_named_data_store_outputs(exported_program.graph_module) -@et_logger("to_edge_transform_and_lower") +@et_logger("to_edge_transform_and_lower") # type: ignore[misc] def to_edge_transform_and_lower( programs: Union[ExportedProgram, Dict[str, ExportedProgram]], transform_passes: Optional[ @@ -1337,7 +1337,7 @@ def to_edge_with_preserved_ops( ) -@et_logger("to_edge") +@et_logger("to_edge") # type: ignore[misc] def to_edge( programs: Union[ExportedProgram, Dict[str, ExportedProgram]], constant_methods: Optional[Dict[str, Any]] = None, @@ -1441,7 +1441,7 @@ def exported_program(self, method_name: str = "forward") -> ExportedProgram: return self._edge_programs[method_name] - @et_logger("transform") + @et_logger("transform") # type: ignore[misc] def transform( self, passes: Union[Sequence[PassType], Dict[str, Sequence[PassType]]], @@ -1489,7 +1489,7 @@ def transform( new_programs, copy.deepcopy(self._config_methods), compile_config ) - @et_logger("to_backend") + @et_logger("to_backend") # type: ignore[misc] def to_backend( self, partitioner: Union[Partitioner, Dict[str, Partitioner]], @@ -1537,7 +1537,7 @@ def to_backend( config, ) - @et_logger("to_executorch") + @et_logger("to_executorch") # type: ignore[misc] def to_executorch( self, config: Optional[ExecutorchBackendConfig] = None, @@ -1604,7 +1604,7 @@ def to_executorch( new_gm, new_signature ) else: - new_gm_res = memory_planning_pass(new_gm) # pyre-ignore[29] + new_gm_res = memory_planning_pass(new_gm) # type: ignore[operator] # WARNING: DO NOT ADD ANY MORE PASSES AFTER MEMORY PLANNING PASS. # THERE ARE A LOT OF ASSUMPTIONS IN THE STACK THAT MEMORY PLANNING IS THE LAST PASS BEFORE THE EMITTER. diff --git a/exir/verification/arg_validator.py b/exir/verification/arg_validator.py index c02e0dfa507..a89806a00e0 100644 --- a/exir/verification/arg_validator.py +++ b/exir/verification/arg_validator.py @@ -39,7 +39,7 @@ def __init__(self, graph_module: torch.fx.GraphModule) -> None: super().__init__(graph_module) self.violating_ops: Dict[ EdgeOpOverload, Tuple[Dict[str, Optional[torch.dtype]], torch.fx.Node] - ] = defaultdict(dict) + ] = defaultdict(dict) # type: ignore[arg-type] def run_node(self, n: torch.fx.Node) -> None: self.node = n @@ -63,7 +63,7 @@ def _get_kernel_arg(self, schema_arg, schema_arg_idx, args, kwargs): return kernel_arg def call_function( # noqa: C901 # pyre-fixme[14] - self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument] + self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument] # type: ignore[override] ) -> Any: """ Go through all the node.target and validate their Tensor arguments are having the allowed dtypes. @@ -73,7 +73,7 @@ def call_function( # noqa: C901 # pyre-fixme[14] ): if isinstance(target, HigherOrderOperator): raise RunHigherOrderOperatorError("Can't run delegate") - return super().call_function(target, args, kwargs) # pyre-fixme[6] + return super().call_function(target, args, kwargs) # type: ignore[arg-type] # TODO(gasoonjia): Update Optional[torch.dtype] to a concrete class to support mixed dtypes in tensorlist. tensor_arg_types: Dict[str, Optional[torch.dtype]] = {} @@ -137,4 +137,4 @@ def call_function( # noqa: C901 # pyre-fixme[14] valid = target._schema.dtype_constraint.validate(tensor_arg_types) if not valid: self.violating_ops[target] = (tensor_arg_types, self.node) - return super().call_function(target, args, kwargs) # pyre-fixme[6] + return super().call_function(target, args, kwargs) # type: ignore[arg-type] diff --git a/exir/verification/interpreter.py b/exir/verification/interpreter.py index fff6a6d79bc..da290f7b428 100644 --- a/exir/verification/interpreter.py +++ b/exir/verification/interpreter.py @@ -7,10 +7,10 @@ # pyre-strict import copy -from typing import List, Optional, Union +from typing import Any, List, Optional, Union # pyre-fixme[21]: Could not find module `executorch.exir.verification.bindings`. -import executorch.exir.verification.bindings as bindings # @manual=//executorch/exir/verification:bindings +import executorch.exir.verification.bindings as bindings # @manual=//executorch/exir/verification:bindings # type: ignore[import-not-found] import executorch.extension.pytree as ex_pytree import torch @@ -227,17 +227,17 @@ def load_from_value_list(self, idx: int) -> None: # noqa assert isinstance(self.execution_plan.values[item].val, Int) # pyre-fixme [16] Undefined attribute [16]: Item `Bool` has no # attribute `int_val`. - unboxed_list.append(self.execution_plan.values[item].val.int_val) + unboxed_list.append(self.execution_plan.values[item].val.int_val) # type: ignore[union-attr] self._value_list[idx] = unboxed_list elif isinstance(val, (TensorList, OptionalTensorList)): - tensor_list = [] + tensor_list: List[Any] = [] for i in val.items: if i == -1: tensor_list.append(None) continue self.load_value(i) tensor_list.append(self._value_list[i]) - self._value_list[idx] = tensor_list + self._value_list[idx] = tensor_list # type: ignore[assignment] elif isinstance(val, Tensor): if val.data_buffer_idx == 0: # TODO(zhengxu) Verify that argument is actually an out variant @@ -300,7 +300,7 @@ def set_value(self, idx: int, input_val: ValueType) -> None: val_idx = val.items[i] self._value_list[val_idx] = input_val[i] tensor_list.append(input_val[i]) - self._value_list[idx] = tensor_list + self._value_list[idx] = tensor_list # type: ignore[assignment] else: raise TypeError( f"Unexpected type, {type(val)}, with value, {val}, in Execution Plan values." @@ -401,7 +401,7 @@ def run(self, *raw_args: torch.Tensor) -> PyTree: ip = ( ip + 1 # pyre-ignore - if self._value_list[instruction.instr_args.cond_val_index] + if self._value_list[instruction.instr_args.cond_value_index] # type: ignore[attr-defined] # pyre-ignore else instruction.instr_args.destination_instruction ) diff --git a/exir/verification/verifier.py b/exir/verification/verifier.py index bc510ff6849..fb487601ec1 100644 --- a/exir/verification/verifier.py +++ b/exir/verification/verifier.py @@ -142,7 +142,7 @@ def check_valid_op(self, op): ret = _EXIRATenDialectVerifier if not class_only: - ret = ret() + ret = ret() # type: ignore[assignment] return ret @@ -234,9 +234,9 @@ def __init__(self) -> None: self.check_valid_aten_op = self.aten_op_verifier.check_valid_op if self.check_edge_ops: - self.check_valid_op = self.check_valid_edge_op + self.check_valid_op = self.check_valid_edge_op # type: ignore[method-assign] else: - self.check_valid_op = self.check_valid_aten_op + self.check_valid_op = self.check_valid_aten_op # type: ignore[method-assign] self._exception_list = exception_list if exception_list else [] def allowed_getattr_types(self) -> Tuple[Type[Any], ...]: @@ -304,7 +304,7 @@ def __call__(self, ep_or_gm): ret = _EXIREdgeDialectVerifier if not class_only: - ret = ret() + ret = ret() # type: ignore[assignment] return ret