diff --git a/backends/arm/_passes/TARGETS b/backends/arm/_passes/TARGETS index 6ca59cfee27..843d6b159dc 100644 --- a/backends/arm/_passes/TARGETS +++ b/backends/arm/_passes/TARGETS @@ -7,6 +7,7 @@ python_library( deps = [ "//executorch/backends/arm:tosa_quant_utils", "//executorch/backends/arm:tosa_utils", + "//executorch/backends/transforms:replace_scalar_with_tensor", "//executorch/backends/xnnpack/_passes:xnnpack_passes", "//executorch/exir:lib", ], diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 78a78bbda30..2dd3c4dc49d 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -256,6 +256,7 @@ python_library( "//executorch/backends/cadence/aot:pass_utils", "//executorch/backends/cadence/aot:remove_ops", "//executorch/backends/cadence/aot:utils", + "//executorch/backends/transforms:replace_scalar_with_tensor", "//executorch/exir:pass_base", "//executorch/exir/dialects:lib", "//executorch/exir/dialects/edge:lib", diff --git a/backends/cadence/aot/pass_utils.py b/backends/cadence/aot/pass_utils.py index d0166061c7f..3d73e7f8c1e 100644 --- a/backends/cadence/aot/pass_utils.py +++ b/backends/cadence/aot/pass_utils.py @@ -7,7 +7,7 @@ # pyre-strict from dataclasses import dataclass -from typing import Callable, List, Optional, Set, Union +from typing import Callable, List, Optional, Set, Type, Union import torch from executorch.backends.cadence.aot.utils import get_edge_overload_packet @@ -32,33 +32,33 @@ class CadencePassAttribute: # A dictionary that maps an ExportPass to its attributes. -ALL_CADENCE_PASSES: dict[ExportPass, CadencePassAttribute] = {} +ALL_CADENCE_PASSES: dict[Type[ExportPass], CadencePassAttribute] = {} -def get_cadence_pass_attribute(p: ExportPass) -> CadencePassAttribute: +def get_cadence_pass_attribute(p: Type[ExportPass]) -> CadencePassAttribute: return ALL_CADENCE_PASSES[p] # A decorator that registers a pass. def register_cadence_pass( pass_attribute: CadencePassAttribute, -) -> Callable[[ExportPass], ExportPass]: - def wrapper(cls: ExportPass) -> ExportPass: +) -> Callable[[Type[ExportPass]], Type[ExportPass]]: + def wrapper(cls: Type[ExportPass]) -> Type[ExportPass]: ALL_CADENCE_PASSES[cls] = pass_attribute return cls return wrapper -def get_all_available_cadence_passes() -> Set[ExportPass]: +def get_all_available_cadence_passes() -> Set[Type[ExportPass]]: return set(ALL_CADENCE_PASSES.keys()) # Create a new filter to filter out relevant passes from all passes. def create_cadence_pass_filter( opt_level: int, debug: bool = False -) -> Callable[[ExportPass], bool]: - def _filter(p: ExportPass) -> bool: +) -> Callable[[Type[ExportPass]], bool]: + def _filter(p: Type[ExportPass]) -> bool: pass_attribute = get_cadence_pass_attribute(p) return ( pass_attribute.opt_level is not None diff --git a/backends/cadence/aot/passes.py b/backends/cadence/aot/passes.py index ab23149e60d..4e27f83c13e 100644 --- a/backends/cadence/aot/passes.py +++ b/backends/cadence/aot/passes.py @@ -6,7 +6,7 @@ # pyre-strict -from typing import Any, List, Optional, Type +from typing import Any, cast, List, Optional, Type import torch import torch.fx @@ -95,9 +95,9 @@ def get_cadence_passes( passes = get_passes_in_default_order() pass_filter = create_cadence_pass_filter(opt_level) filtered_passes = [ - # pyre-fixme[20]: Call `torch.fx.passes.infra.pass_base.PassBase.__call__` expects argument `graph_module`. filtered_pass() # pyre-fixme[6]: In call `filter.__new__` ... got `List[Type[typing.Callable[[GraphModule], Optional[PassResult]]]]`. for filtered_pass in list(filter(pass_filter, passes)) ] - return filtered_passes + # The type checker can't infer the proper type of the list comprehension. + return cast(List[Optional[PassResult]], filtered_passes) diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 120f69008c1..f91fb26ddc8 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -1719,9 +1719,9 @@ def call_operator(self, op, args, kwargs, meta): ) -@register_cadence_pass(CadencePassAttribute(opt_level=0))( - ReplaceScalarWithTensorArgPass() -) +register_cadence_pass(CadencePassAttribute(opt_level=0))(ReplaceScalarWithTensorArgPass) + + @register_cadence_pass(CadencePassAttribute(opt_level=0)) class ReplaceScalarTensorWithFullPass(ExportPass): """ diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl index c532798546d..ec4e1412862 100644 --- a/backends/transforms/targets.bzl +++ b/backends/transforms/targets.bzl @@ -201,6 +201,20 @@ def define_common_targets(): ], ) + runtime.python_library( + name = "replace_scalar_with_tensor", + srcs = [ + "replace_scalar_with_tensor.py", + ], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:pass_base", + ], + ) + runtime.python_test( name = "test_duplicate_dynamic_quant_chain", srcs = [