diff --git a/backends/arm/_passes/_debug_passes.py b/backends/arm/_passes/_debug_passes.py index 4c1661e50a9..e22c8a6cf2c 100644 --- a/backends/arm/_passes/_debug_passes.py +++ b/backends/arm/_passes/_debug_passes.py @@ -6,12 +6,13 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.devtools.visualization.visualization_utils import visualize_graph from executorch.exir import ExportedProgram from executorch.exir.pass_base import ExportPass, PassResult -class VisualizePass(ExportPass): +class VisualizePass(ArmPass): """ This pass visualizes the graph at the point of insertion in the pass manager """ diff --git a/backends/arm/_passes/add_bias_pass.py b/backends/arm/_passes/add_bias_pass.py index fd5476f51b8..2114d56ef5b 100644 --- a/backends/arm/_passes/add_bias_pass.py +++ b/backends/arm/_passes/add_bias_pass.py @@ -10,6 +10,7 @@ from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.transforms.utils import create_constant_placeholder +from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -26,6 +27,10 @@ class AddBiasPass(ArmPass): targeted_ops = (exir_ops.edge.aten.convolution.default,) + def __init__(self, exported_program: ExportedProgram) -> None: + super().__init__() + self.exported_program = exported_program + def call(self, graph_module): modified = False for node in graph_module.graph.nodes: diff --git a/backends/arm/_passes/annotate_decomposed_matmul.py b/backends/arm/_passes/annotate_decomposed_matmul.py index 72ae46c76c1..666214ec267 100644 --- a/backends/arm/_passes/annotate_decomposed_matmul.py +++ b/backends/arm/_passes/annotate_decomposed_matmul.py @@ -10,6 +10,7 @@ from typing import cast, List, Set, Type import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( FoldAndAnnotateQParamsPass, @@ -23,7 +24,7 @@ from torch.fx.passes.utils.source_matcher_utils import get_source_partitions -class AnnotateDecomposedMatmulPass(ExportPass): +class AnnotateDecomposedMatmulPass(ArmPass): """ torch.matmul and it's equivalent operator @ can be decomposed in many ways, for instance: dq -> matmul -> q can become diff --git a/backends/arm/_passes/arm_pass.py b/backends/arm/_passes/arm_pass.py index c76b5d157a7..3cc5e3ee0c0 100644 --- a/backends/arm/_passes/arm_pass.py +++ b/backends/arm/_passes/arm_pass.py @@ -9,17 +9,12 @@ from abc import abstractmethod from typing import List, Optional, Set, Type -import torch from executorch.exir.pass_base import ExportPass, NodeMetadata class ArmPass(ExportPass): """Base class for Arm passes""" - def __init__(self, exported_program: Optional[torch.export.ExportedProgram] = None): - super(ArmPass, self).__init__() - self.exported_program = exported_program - @property @abstractmethod def _passes_required_after(self) -> Set[Type[ExportPass]]: diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 325f667f0ac..d0d3aae148f 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -210,10 +210,10 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: # needs to happen before AddBiasPass, but after the table ops are inserted # to be able to validate that conv2d has right dtype arguments. self.add_pass(DecomposeConv2dWithInt16ActivationPass()) - self.add_pass(RewriteUpsamplePass(exported_program)) + self.add_pass(RewriteUpsamplePass()) self.add_pass(AddBiasPass(exported_program)) - self.add_pass(RewriteMatmulPass(exported_program)) + self.add_pass(RewriteMatmulPass()) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) self.add_pass(ToTosaMemoryFormatPass(exported_program)) self.add_pass(RemoveNoopPass()) @@ -298,10 +298,10 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(FuseViewCopyTransform()) self.add_pass(FuseConstantArgsPass(exported_program)) self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) - self.add_pass(RewriteUpsamplePass(exported_program)) + self.add_pass(RewriteUpsamplePass()) self.add_pass(AddBiasPass(exported_program)) self.add_pass(InsertTableOpsPass(exported_program)) - self.add_pass(RewriteMatmulPass(exported_program)) + self.add_pass(RewriteMatmulPass()) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) self.add_pass(ToTosaMemoryFormatPass(exported_program)) self.add_pass(RemoveNoopPass()) diff --git a/backends/arm/_passes/cast_bool_to_int8_pass.py b/backends/arm/_passes/cast_bool_to_int8_pass.py index 771b6d9e174..0987476a2ec 100644 --- a/backends/arm/_passes/cast_bool_to_int8_pass.py +++ b/backends/arm/_passes/cast_bool_to_int8_pass.py @@ -10,11 +10,12 @@ import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass -class CastBoolToInt8Pass(ExportPass): +class CastBoolToInt8Pass(ArmPass): """Casts the input to int8 if it is not already and casts back the output to the original input dtype.""" _passes_required_after: Set[Type[ExportPass]] = set() diff --git a/backends/arm/_passes/cast_int64_pass.py b/backends/arm/_passes/cast_int64_pass.py index d7b2a6b6b43..33d07f54af0 100644 --- a/backends/arm/_passes/cast_int64_pass.py +++ b/backends/arm/_passes/cast_int64_pass.py @@ -9,21 +9,23 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.exir.pass_base import ExportPass, PassResult from torch._export.utils import is_buffer +from torch.export import ExportedProgram logger = logging.getLogger(__name__) -class CastInt64BuffersToInt32Pass(ExportPass): +class CastInt64BuffersToInt32Pass(ArmPass): """ Cast int64 buffers to int32 if the int64 data is in int32 range. """ _passes_required_after: Set[Type[ExportPass]] = set() - def __init__(self, exported_program: torch.export.ExportedProgram): - super(CastInt64BuffersToInt32Pass, self).__init__() + def __init__(self, exported_program: ExportedProgram): + super().__init__() self.exported_program = exported_program def _assert_within_int32(self, tensor: torch.Tensor, node: torch.fx.Node): diff --git a/backends/arm/_passes/cast_to_int32_pass.py b/backends/arm/_passes/cast_to_int32_pass.py index 2e574568235..db626bf5695 100644 --- a/backends/arm/_passes/cast_to_int32_pass.py +++ b/backends/arm/_passes/cast_to_int32_pass.py @@ -7,11 +7,12 @@ import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass -class CastToInt32Pass(ExportPass): +class CastToInt32Pass(ArmPass): """Casts the input to int32 if it is not already and casts back the output to the original input dtype.""" _passes_required_after: Set[Type[ExportPass]] = set() diff --git a/backends/arm/_passes/conv1d_unsqueeze_pass.py b/backends/arm/_passes/conv1d_unsqueeze_pass.py index b228da6766f..7784c850278 100644 --- a/backends/arm/_passes/conv1d_unsqueeze_pass.py +++ b/backends/arm/_passes/conv1d_unsqueeze_pass.py @@ -8,6 +8,8 @@ from typing import Set, Type +from executorch.backends.arm._passes import ArmPass + from executorch.backends.arm._passes.add_bias_pass import AddBiasPass from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass @@ -15,7 +17,7 @@ from executorch.exir.pass_base import ExportPass -class Conv1dUnsqueezePass(ExportPass): +class Conv1dUnsqueezePass(ArmPass): """ This pass is used to change conv1d ops into conv2d since TOSA only supports 2d and 3d convolution. This is done by modifying the graph to do the @@ -38,7 +40,11 @@ def call_operator(self, op, args, kwargs, meta): x = args[0] x_unsqueezed_shape = list(x.data.shape) + [1] x = super().call_operator( - exir_ops.edge.aten.view_copy.default, (x, x_unsqueezed_shape), {}, meta + exir_ops.edge.aten.view_copy.default, + (x, x_unsqueezed_shape), + {}, + meta, + updated=True, ) w_meta = meta.copy() @@ -48,7 +54,11 @@ def call_operator(self, op, args, kwargs, meta): w = args[1] w_unsqueezed_shape = list(w.data.shape) + [1] w = super().call_operator( - exir_ops.edge.aten.view_copy.default, (w, w_unsqueezed_shape), {}, w_meta + exir_ops.edge.aten.view_copy.default, + (w, w_unsqueezed_shape), + {}, + w_meta, + updated=True, ) new_args = ( @@ -63,12 +73,16 @@ def call_operator(self, op, args, kwargs, meta): args[8], ) x = super().call_operator( - exir_ops.edge.aten.convolution.default, new_args, kwargs, meta + exir_ops.edge.aten.convolution.default, new_args, kwargs, meta, updated=True ) x_squeezed_shape = list(x.data.shape)[:-1] x = super().call_operator( - exir_ops.edge.aten.view_copy.default, (x, x_squeezed_shape), {}, meta + exir_ops.edge.aten.view_copy.default, + (x, x_squeezed_shape), + {}, + meta, + updated=True, ) return x diff --git a/backends/arm/_passes/convert_any_default_dim_dims_pass.py b/backends/arm/_passes/convert_any_default_dim_dims_pass.py index 8c8e5086b6d..d09cd22cbd4 100644 --- a/backends/arm/_passes/convert_any_default_dim_dims_pass.py +++ b/backends/arm/_passes/convert_any_default_dim_dims_pass.py @@ -6,6 +6,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.convert_squeezes_to_view import ( ConvertSqueezesToViewPass, ) @@ -18,7 +19,7 @@ ) -class ConvertAnyDefaultDimDimsPass(ExportPass): +class ConvertAnyDefaultDimDimsPass(ArmPass): """ Converts any.default, any.dim and any.dims to a sequence of any.dim by unrolling multi-dimensional reduction. Please refer to KeepDimsFalseToSqueezePass for an explanation of this coversion. diff --git a/backends/arm/_passes/convert_elu_params.py b/backends/arm/_passes/convert_elu_params.py index 7da58ae4bb4..86c1c52c5b7 100644 --- a/backends/arm/_passes/convert_elu_params.py +++ b/backends/arm/_passes/convert_elu_params.py @@ -3,13 +3,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult -class ConvertELUParamsPass(ExportPass): +class ConvertELUParamsPass(ArmPass): """ Pass to convert the input_scale kwarg of ELU operator from float to int. @@ -18,6 +21,8 @@ class ConvertELUParamsPass(ExportPass): the value of input_scale is, as long as that value is not 1. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule): modified_graph = False graph = graph_module.graph diff --git a/backends/arm/_passes/convert_expand_copy_to_repeat.py b/backends/arm/_passes/convert_expand_copy_to_repeat.py index 83b47d31755..7f66a4343b9 100644 --- a/backends/arm/_passes/convert_expand_copy_to_repeat.py +++ b/backends/arm/_passes/convert_expand_copy_to_repeat.py @@ -10,6 +10,7 @@ import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import ( UnsqueezeBeforeRepeatPass, ) @@ -48,7 +49,7 @@ def calculate_multiples(args): return multiples -class ConvertExpandCopyToRepeatPass(ExportPass): +class ConvertExpandCopyToRepeatPass(ArmPass): """ Replace expand copy with repeat since it is a repeat that can only repeat singleton dimensions. """ diff --git a/backends/arm/_passes/convert_int64_const_ops_to_int32.py b/backends/arm/_passes/convert_int64_const_ops_to_int32.py index 2bf305a13f6..798bbc6006f 100644 --- a/backends/arm/_passes/convert_int64_const_ops_to_int32.py +++ b/backends/arm/_passes/convert_int64_const_ops_to_int32.py @@ -10,6 +10,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT from executorch.exir.pass_base import ExportPass, PassResult @@ -19,7 +20,7 @@ INT32_MAX = torch.iinfo(torch.int32).max -class ConvertInt64ConstOpsToInt32Pass(ExportPass): +class ConvertInt64ConstOpsToInt32Pass(ArmPass): """ Rewrite constant ops that produce int64 to int32 where safe. diff --git a/backends/arm/_passes/convert_int64_output_ops_to_int32.py b/backends/arm/_passes/convert_int64_output_ops_to_int32.py index d0d29d14e30..7eb02493d50 100644 --- a/backends/arm/_passes/convert_int64_output_ops_to_int32.py +++ b/backends/arm/_passes/convert_int64_output_ops_to_int32.py @@ -10,6 +10,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( create_node, get_first_fake_tensor, @@ -22,7 +23,7 @@ logger = logging.getLogger(__name__) -class ConvertInt64OutputOpsToInt32Pass(ExportPass): +class ConvertInt64OutputOpsToInt32Pass(ArmPass): """ Rewrites or removes operations that produce int64 outputs, converting them to int32 where possible. diff --git a/backends/arm/_passes/convert_minmax_pass.py b/backends/arm/_passes/convert_minmax_pass.py index 79bb6e2db0c..34fcefa20e3 100644 --- a/backends/arm/_passes/convert_minmax_pass.py +++ b/backends/arm/_passes/convert_minmax_pass.py @@ -6,6 +6,7 @@ from typing import cast, Set, Type import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.arm._passes.convert_squeezes_to_view import ( ConvertSqueezesToViewPass, @@ -14,7 +15,7 @@ from executorch.exir.pass_base import ExportPass, PassResult -class ConvertMinMaxPass(ExportPass): +class ConvertMinMaxPass(ArmPass): """ Converts min/max to amin/amax and unrolls multi-dimensional reduction and keep-dims arg to be TOSA compliant. diff --git a/backends/arm/_passes/convert_split_to_slice.py b/backends/arm/_passes/convert_split_to_slice.py index 2cce0315c12..cd9f8bef2f7 100644 --- a/backends/arm/_passes/convert_split_to_slice.py +++ b/backends/arm/_passes/convert_split_to_slice.py @@ -8,6 +8,7 @@ from typing import Set, Type import torch.fx +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( create_node, get_first_fake_tensor, @@ -16,7 +17,7 @@ from executorch.exir.pass_base import ExportPass, PassResult -class ConvertSplitToSlicePass(ExportPass): +class ConvertSplitToSlicePass(ArmPass): """ Replace a split operation with many slice operations. """ diff --git a/backends/arm/_passes/convert_squeezes_to_view.py b/backends/arm/_passes/convert_squeezes_to_view.py index 70f4625f0ff..c7d02c27a36 100644 --- a/backends/arm/_passes/convert_squeezes_to_view.py +++ b/backends/arm/_passes/convert_squeezes_to_view.py @@ -8,13 +8,15 @@ from typing import Set, Type +from executorch.backends.arm._passes import ArmPass + from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass -class ConvertSqueezesToViewPass(ExportPass): +class ConvertSqueezesToViewPass(ArmPass): """ Replaces squeeze/unsqueeze operators with view. These are simply special cases of the view op, so removing them gives us less cases to handle in the node visitiors. """ diff --git a/backends/arm/_passes/convert_to_clamp.py b/backends/arm/_passes/convert_to_clamp.py index 0199d6798bc..1ada1efe69b 100644 --- a/backends/arm/_passes/convert_to_clamp.py +++ b/backends/arm/_passes/convert_to_clamp.py @@ -5,6 +5,8 @@ from typing import Set, Tuple, Type +from executorch.backends.arm._passes import ArmPass + from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( QuantizeOperatorArguments, ) @@ -27,7 +29,7 @@ def get_clamp_params(op, args) -> Tuple[float | None, float | None]: raise ValueError(f"Getting clamp parameters for op {op} is not implemented.") -class ConvertToClampPass(ExportPass): +class ConvertToClampPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = {QuantizeOperatorArguments} def call_operator(self, op, args, kwargs, meta): @@ -39,4 +41,5 @@ def call_operator(self, op, args, kwargs, meta): (args[0], *get_clamp_params(op, args)), {}, meta, + updated=True, ) diff --git a/backends/arm/_passes/decompose_avg_pool2d.py b/backends/arm/_passes/decompose_avg_pool2d.py index bbb8ceba129..0187ee45a1e 100644 --- a/backends/arm/_passes/decompose_avg_pool2d.py +++ b/backends/arm/_passes/decompose_avg_pool2d.py @@ -7,6 +7,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT from executorch.backends.arm.operators.operator_validation_utils import ( adjust_pooling_pad_if_needed, @@ -36,7 +37,7 @@ def get_decomposition(op) -> tuple: raise RuntimeError(f"Can't get avg_pool2d decomposition for op {op}") -class DecomposeAvgPool2d(ExportPass): +class DecomposeAvgPool2d(ArmPass): _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOT} def call_operator(self, op, args, kwargs, meta): @@ -69,7 +70,9 @@ def call_operator(self, op, args, kwargs, meta): # Add width padding manually if count_include_pad if count_include_pad and pad_w > 0: pre_pad_shape = [n, c, h, pad_w] - pre_pad = super().call_operator(full_op, (pre_pad_shape, 0.0), kwargs, meta) + pre_pad = super().call_operator( + full_op, (pre_pad_shape, 0.0), kwargs, meta, updated=True + ) if ceil_mode and divisor_override is None: post_pad_w = pad_w @@ -81,19 +84,23 @@ def call_operator(self, op, args, kwargs, meta): if post_pad_w > 0: post_pad_shape = [n, c, h, post_pad_w] post_pad = super().call_operator( - full_op, (post_pad_shape, 0.0), kwargs, meta + full_op, (post_pad_shape, 0.0), kwargs, meta, updated=True ) cat_nodes = [pre_pad, x, post_pad] else: cat_nodes = [pre_pad, x] - x = super().call_operator(cat_op, (cat_nodes, 3), kwargs, meta) + x = super().call_operator( + cat_op, (cat_nodes, 3), kwargs, meta, updated=True + ) new_pad_w = 0 # Add height padding manually if count_include_pad if count_include_pad and pad_h > 0: pre_pad_shape = [n, c, pad_h, w + pad_w + post_pad_w] - pre_pad = super().call_operator(full_op, (pre_pad_shape, 0.0), kwargs, meta) + pre_pad = super().call_operator( + full_op, (pre_pad_shape, 0.0), kwargs, meta, updated=True + ) if ceil_mode and divisor_override is None: post_pad_h = pad_h @@ -105,13 +112,15 @@ def call_operator(self, op, args, kwargs, meta): if post_pad_h > 0: post_pad_shape = [n, c, post_pad_h, w + pad_w + post_pad_w] post_pad = super().call_operator( - full_op, (post_pad_shape, 0.0), kwargs, meta + full_op, (post_pad_shape, 0.0), kwargs, meta, updated=True ) cat_nodes = [pre_pad, x, post_pad] else: cat_nodes = [pre_pad, x] - x = super().call_operator(cat_op, (cat_nodes, 2), kwargs, meta) + x = super().call_operator( + cat_op, (cat_nodes, 2), kwargs, meta, updated=True + ) new_pad_h = 0 avgpool_args = ( @@ -122,13 +131,19 @@ def call_operator(self, op, args, kwargs, meta): ceil_mode, False, ) - x = super().call_operator(avgpool_op, avgpool_args, kwargs, meta) + x = super().call_operator(avgpool_op, avgpool_args, kwargs, meta, updated=True) # Multiply by factor (kernel_size / divisor_override) if divisor_override if divisor_override is not None and divisor_override != kernel_size: override_multiplier = super().call_operator( - full_op, ([1, 1, 1, 1], kernel_size / divisor_override), kwargs, meta + full_op, + ([1, 1, 1, 1], kernel_size / divisor_override), + kwargs, + meta, + updated=True, + ) + x = super().call_operator( + mul_op, (x, override_multiplier), kwargs, meta, updated=True ) - x = super().call_operator(mul_op, (x, override_multiplier), kwargs, meta) return x diff --git a/backends/arm/_passes/decompose_cosine_similarity_pass.py b/backends/arm/_passes/decompose_cosine_similarity_pass.py index 965dad54697..96a95ee2a1c 100644 --- a/backends/arm/_passes/decompose_cosine_similarity_pass.py +++ b/backends/arm/_passes/decompose_cosine_similarity_pass.py @@ -6,6 +6,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.convert_full_like_to_full_pass import ( ConvertFullLikeToFullPass, ) @@ -18,7 +19,7 @@ torch_cosine_similarity = (torch.ops.aten.cosine_similarity.default,) -class DecomposeCosineSimilarityPass(ExportPass): +class DecomposeCosineSimilarityPass(ArmPass): """ Decomposition of aten.cosine_similarity: diff --git a/backends/arm/_passes/decompose_cumsum_pass.py b/backends/arm/_passes/decompose_cumsum_pass.py index 32c59f6d793..2111c654817 100644 --- a/backends/arm/_passes/decompose_cumsum_pass.py +++ b/backends/arm/_passes/decompose_cumsum_pass.py @@ -13,6 +13,7 @@ from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.transforms.utils import create_constant_placeholder +from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch.export.graph_signature import InputKind @@ -43,6 +44,10 @@ class DecomposeCumsumPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = {AddBiasPass} + def __init__(self, exported_program: ExportedProgram) -> None: + super().__init__() + self.exported_program = exported_program + def call(self, graph_module): graph = graph_module.graph targets = (exir_ops.edge.aten.cumsum.default, torch.ops.aten.cumsum.default) diff --git a/backends/arm/_passes/decompose_div_pass.py b/backends/arm/_passes/decompose_div_pass.py index b6db103930e..f2ae77514c5 100644 --- a/backends/arm/_passes/decompose_div_pass.py +++ b/backends/arm/_passes/decompose_div_pass.py @@ -9,6 +9,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -29,7 +30,7 @@ def get_div_decomposition(op) -> tuple: raise RuntimeError(f"Can't get div decomposition for op {op}") -class DecomposeDivPass(ExportPass): +class DecomposeDivPass(ArmPass): """ This pass decomposes div into a mul and a reciprocal node. @@ -50,6 +51,10 @@ def call_operator(self, op, args, kwargs, meta): numerator = args[0] denominator = args[1] - reciprocal = super().call_operator(reciprocal_op, (denominator,), {}, meta) + reciprocal = super().call_operator( + reciprocal_op, (denominator,), {}, meta, updated=True + ) - return super().call_operator(mul_op, (numerator, reciprocal), {}, meta) + return super().call_operator( + mul_op, (numerator, reciprocal), {}, meta, updated=True + ) diff --git a/backends/arm/_passes/decompose_div_tensor_mode.py b/backends/arm/_passes/decompose_div_tensor_mode.py index b5352475d51..07e57c60f1b 100644 --- a/backends/arm/_passes/decompose_div_tensor_mode.py +++ b/backends/arm/_passes/decompose_div_tensor_mode.py @@ -8,6 +8,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -42,7 +43,7 @@ def _get_opset(op): raise RuntimeError(f"div.Tensor_mode not supported for op {op}") -class DecomposeDivTensorModePass(ExportPass): +class DecomposeDivTensorModePass(ArmPass): """ Rewrites aten.div.Tensor_mode into @@ -64,13 +65,13 @@ def call_operator(self, op, args, kwargs, meta): if rounding_mode is None and len(args) > 2: rounding_mode = args[2] - q = super().call_operator(opset["div"], (a, b), {}, meta) + q = super().call_operator(opset["div"], (a, b), {}, meta, updated=True) if rounding_mode is None: return q if rounding_mode == "floor": - return super().call_operator(opset["floor"], (q,), {}, meta) + return super().call_operator(opset["floor"], (q,), {}, meta, updated=True) if rounding_mode == "trunc": zero = super().call_operator( @@ -78,11 +79,14 @@ def call_operator(self, op, args, kwargs, meta): args=((1,) * len(meta["val"].size()), 0.0), kwargs={"dtype": torch.float32}, meta=meta, + updated=True, + ) + lt0 = super().call_operator(opset["lt"], (q, zero), {}, meta, updated=True) + ceilq = super().call_operator(opset["ceil"], (q,), {}, meta, updated=True) + floorq = super().call_operator(opset["floor"], (q,), {}, meta, updated=True) + return super().call_operator( + opset["where"], (lt0, ceilq, floorq), {}, meta, updated=True ) - lt0 = self.call_operator(opset["lt"], (q, zero), {}, meta) - ceilq = self.call_operator(opset["ceil"], (q,), {}, meta) - floorq = self.call_operator(opset["floor"], (q,), {}, meta) - return self.call_operator(opset["where"], (lt0, ceilq, floorq), {}, meta) raise RuntimeError( f"Unsupported rounding_mode for div.Tensor_mode: {rounding_mode!r}" diff --git a/backends/arm/_passes/decompose_embedding_pass.py b/backends/arm/_passes/decompose_embedding_pass.py index 01226a7a38e..ac424230491 100644 --- a/backends/arm/_passes/decompose_embedding_pass.py +++ b/backends/arm/_passes/decompose_embedding_pass.py @@ -11,6 +11,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -21,7 +22,7 @@ logger.setLevel(logging.WARNING) -class DecomposeEmbeddingPass(ExportPass): +class DecomposeEmbeddingPass(ArmPass): """ This pass decomposes embedding into index_select. diff --git a/backends/arm/_passes/decompose_gelu_pass.py b/backends/arm/_passes/decompose_gelu_pass.py index 237b8199e82..532f5d859fe 100644 --- a/backends/arm/_passes/decompose_gelu_pass.py +++ b/backends/arm/_passes/decompose_gelu_pass.py @@ -6,6 +6,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass @@ -43,7 +44,7 @@ def _get_gelu_ops(op) -> tuple: raise RuntimeError(f"Can't get GeLU decomposition ops for op {op}") -class DecomposeGeluPass(ExportPass): +class DecomposeGeluPass(ArmPass): """ This pass decomposes the GELU operator into primitive ops. Aiming to adhere closely to the reference implementations built into diff --git a/backends/arm/_passes/decompose_grouped_conv.py b/backends/arm/_passes/decompose_grouped_conv.py index 916e43ee9a4..11d9f605127 100644 --- a/backends/arm/_passes/decompose_grouped_conv.py +++ b/backends/arm/_passes/decompose_grouped_conv.py @@ -7,13 +7,14 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass -class DecomposeGroupedConv(ExportPass): +class DecomposeGroupedConv(ArmPass): """ Splits a grouped convolution which is not supported by TOSA into multiple convolutions using slice->conv->cat. @@ -125,7 +126,9 @@ def call_operator(self, op, args, kwargs, meta): slice_args = (input_node, 1, start_index, stop_index) input_slices.append( - super().call_operator(slice_op, slice_args, kwargs, no_q_dq_meta) + super().call_operator( + slice_op, slice_args, kwargs, no_q_dq_meta, updated=True + ) ) filter_slices = [] @@ -135,7 +138,9 @@ def call_operator(self, op, args, kwargs, meta): slice_args = (weight_node, 0, start_index, stop_index) filter_slices.append( - super().call_operator(slice_op, slice_args, kwargs, no_q_dq_meta) + super().call_operator( + slice_op, slice_args, kwargs, no_q_dq_meta, updated=True + ) ) bias_slices = [] @@ -148,7 +153,9 @@ def call_operator(self, op, args, kwargs, meta): slice_args = (bias_node, 0, start_index, stop_index) bias_slices.append( - super().call_operator(slice_op, slice_args, kwargs, no_q_dq_meta) + super().call_operator( + slice_op, slice_args, kwargs, no_q_dq_meta, updated=True + ) ) output_slices = [] @@ -166,9 +173,11 @@ def call_operator(self, op, args, kwargs, meta): raise RuntimeError("Invalid op for grouped conv decomposition") output_slices.append( - super().call_operator(conv_op, conv_args, kwargs, meta_copy) + super().call_operator( + conv_op, conv_args, kwargs, meta_copy, updated=True + ) ) cat_args = (output_slices, 1) # propagate original metadata (including quantization params) to the concatenated output - return super().call_operator(cat_op, cat_args, kwargs, meta) + return super().call_operator(cat_op, cat_args, kwargs, meta, updated=True) diff --git a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py index ea5dd2d9b55..5c6c8fc0ec5 100644 --- a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py +++ b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py @@ -6,12 +6,13 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.decompose_sqrt_pass import DecomposeSqrtPass from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass from executorch.exir.pass_base import ExportPass -class DecomposeLinearVectorNormPass(ExportPass): +class DecomposeLinearVectorNormPass(ArmPass): """ This pass decomposes aten.linalg_vector_norm.default into more primitive ops. We need to add this pass before quantization for graph annotation. diff --git a/backends/arm/_passes/decompose_select.py b/backends/arm/_passes/decompose_select.py index 049409af6fd..73f8decf4a1 100644 --- a/backends/arm/_passes/decompose_select.py +++ b/backends/arm/_passes/decompose_select.py @@ -9,6 +9,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( create_node, get_first_fake_tensor, @@ -20,7 +21,7 @@ from executorch.exir.pass_base import ExportPass, PassResult -class DecomposeSelectPass(ExportPass): +class DecomposeSelectPass(ArmPass): """ This pass decomposes select into slice + squeeze to ensure that Aten and TOSA outputs has the same rank (input rank -1) """ diff --git a/backends/arm/_passes/decompose_silu_pass.py b/backends/arm/_passes/decompose_silu_pass.py index 3d31552cf35..413beb2625f 100644 --- a/backends/arm/_passes/decompose_silu_pass.py +++ b/backends/arm/_passes/decompose_silu_pass.py @@ -8,13 +8,14 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.pass_base import ExportPass aten_silu_ops = (torch.ops.aten.silu.default, torch.ops.aten.silu_.default) -class DecomposeSiluPass(ExportPass): +class DecomposeSiluPass(ArmPass): """ This pass decomposes silu into a mul and a sigmoid node. @@ -34,6 +35,8 @@ def call_operator(self, op, args, kwargs, meta): mul_op = torch.ops.aten.mul.Tensor original = args[0] - sigmoid = super().call_operator(sigmoid_op, (original,), {}, meta) + sigmoid = super().call_operator(sigmoid_op, (original,), {}, meta, updated=True) - return super().call_operator(mul_op, (original, sigmoid), {}, meta) + return super().call_operator( + mul_op, (original, sigmoid), {}, meta, updated=True + ) diff --git a/backends/arm/_passes/decompose_softmax_pass.py b/backends/arm/_passes/decompose_softmax_pass.py index 52df7cf6700..ee841e54f26 100644 --- a/backends/arm/_passes/decompose_softmax_pass.py +++ b/backends/arm/_passes/decompose_softmax_pass.py @@ -6,6 +6,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops @@ -53,7 +54,7 @@ def _get_logsoftmax_ops(op) -> tuple: raise RuntimeError(f"Can't get logsoftmax decomposition ops for op {op}") -class DecomposeSoftmaxPass(ExportPass): +class DecomposeSoftmaxPass(ArmPass): """ This pass decomposes log_softmax or softmax into more primitive ops. Example: @@ -79,12 +80,12 @@ def call_operator(self, op, args, kwargs, meta): ) _input = args[0] dim = [args[1]] - op1 = super().call_operator(max_op, (_input, dim, True), {}, meta) - op2 = super().call_operator(sub_op, (_input, op1), {}, meta) - op3 = super().call_operator(exp_op, (op2,), {}, meta) - op4 = super().call_operator(sum_op, (op3, dim, True), {}, meta) - op5 = super().call_operator(reciprocal_op, (op4,), {}, meta) - op6 = super().call_operator(mul_op, (op3, op5), {}, meta) + op1 = super().call_operator(max_op, (_input, dim, True), {}, meta, updated=True) + op2 = super().call_operator(sub_op, (_input, op1), {}, meta, updated=True) + op3 = super().call_operator(exp_op, (op2,), {}, meta, updated=True) + op4 = super().call_operator(sum_op, (op3, dim, True), {}, meta, updated=True) + op5 = super().call_operator(reciprocal_op, (op4,), {}, meta, updated=True) + op6 = super().call_operator(mul_op, (op3, op5), {}, meta, updated=True) if op in log_softmax: - op6 = super().call_operator(log_op, (op6,), {}, meta) + op6 = super().call_operator(log_op, (op6,), {}, meta, updated=True) return op6 diff --git a/backends/arm/_passes/decompose_sqrt_pass.py b/backends/arm/_passes/decompose_sqrt_pass.py index 3f4e608c4b9..e716de3b048 100644 --- a/backends/arm/_passes/decompose_sqrt_pass.py +++ b/backends/arm/_passes/decompose_sqrt_pass.py @@ -7,6 +7,7 @@ from typing import Set, Tuple, Type, Union import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -27,7 +28,7 @@ def get_sqrt_decomposition(op) -> Union[Tuple, torch._ops.OpOverload]: raise RuntimeError(f"Can't get sqrt decomposition for op {op}") -class DecomposeSqrtPass(ExportPass): +class DecomposeSqrtPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass} def call_operator(self, op, args, kwargs, meta): @@ -40,4 +41,4 @@ def call_operator(self, op, args, kwargs, meta): pow_op = get_sqrt_decomposition(op) - return super().call_operator(pow_op, (args[0], 0.5), {}, meta) + return super().call_operator(pow_op, (args[0], 0.5), {}, meta, updated=True) diff --git a/backends/arm/_passes/decompose_sum_pass.py b/backends/arm/_passes/decompose_sum_pass.py index 16027ccec2b..59c352a0e07 100644 --- a/backends/arm/_passes/decompose_sum_pass.py +++ b/backends/arm/_passes/decompose_sum_pass.py @@ -6,6 +6,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -23,7 +24,7 @@ def _get_sum_decomp(op): raise RuntimeError("Unvalid op in DecomposeSumPass") -class DecomposeSumPass(ExportPass): +class DecomposeSumPass(ArmPass): """ In Pytorch, the default behaviour of for example Tensor.sum is to squeeze the dimension that is summed (keep_dim = False). However, in TOSA, REDUCE_SUM always @@ -76,13 +77,13 @@ def call_operator(self, op, args, kwargs, meta): for dim in dims: input_node = super().call_operator( - sum_op, (input_node, dim, True), kwargs, meta + sum_op, (input_node, dim, True), kwargs, meta, updated=True ) if not keepdims: shape = list(meta["val"].size()) input_node = super().call_operator( - view_op, (input_node, shape), kwargs, meta + view_op, (input_node, shape), kwargs, meta, updated=True ) return input_node diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 477e007b8bf..4427e0357a0 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -8,7 +8,7 @@ import copy -from typing import cast, Dict, Set, Tuple, Type +from typing import cast, Optional, Set, Type from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( @@ -20,17 +20,12 @@ from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.arm._passes.remove_noop_pass import RemoveNoopPass from executorch.backends.arm.constants import DQ_OPS, Q_OPS +from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload -from executorch.exir.pass_base import ( - Argument, - ExportPass, - NodeMetadata, - PassResult, - ProxyValue, -) +from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import GraphModule, Node @@ -72,7 +67,7 @@ def get_output_qparams(node: Node) -> dict[int, QuantArgs]: return output_qparams -class RetraceFoldedDtypesPass(ExportPass): +class RetraceFoldedDtypesPass(ArmPass): """ FoldAndAnnotateQParamsPass folds dq and q nodes. When the graph is retraced some operators are retraced to types that cannot be handled by TOSA. One @@ -90,24 +85,18 @@ class RetraceFoldedDtypesPass(ExportPass): exir_ops.edge.aten.sum.dim_IntList, } - def call_operator( - self, - op, # pyre-ignore - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: + def call_operator(self, op, args, kwargs, meta): if op not in self.targeted_ops: - return super().call_operator(op, args, kwargs, meta) + return super().call_operator(op, args, kwargs, meta, False) node_kwargs = kwargs.copy() output_qparams = meta["output_qparams"] if len(output_qparams) == 0: - return super().call_operator(op, args, kwargs, meta) + return super().call_operator(op, args, kwargs, meta, False) output_dtype = output_qparams[0].dtype node_kwargs["dtype"] = output_dtype - return super().call_operator(op, args, node_kwargs, meta) + return super().call_operator(op, args, node_kwargs, meta, True) class FoldAndAnnotateQParamsPass(ArmPass): @@ -146,6 +135,10 @@ class FoldAndAnnotateQParamsPass(ArmPass): RemoveNoopPass, } + def __init__(self, exported_program: Optional[ExportedProgram] = None) -> None: + super().__init__() + self.exported_program = exported_program + def fold_and_annotate_arg( self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int ) -> None: @@ -249,7 +242,7 @@ def call(self, graph_module: GraphModule) -> PassResult: return PassResult(graph_module, True) -class QuantizeOperatorArguments(ExportPass): +class QuantizeOperatorArguments(ArmPass): """ This pass makes sure that the arguments to clamp.default are quantized correctly. More specifically, this pass: diff --git a/backends/arm/_passes/fuse_batchnorm2d_pass.py b/backends/arm/_passes/fuse_batchnorm2d_pass.py index 8be6b61d25c..5d4308ec3f6 100644 --- a/backends/arm/_passes/fuse_batchnorm2d_pass.py +++ b/backends/arm/_passes/fuse_batchnorm2d_pass.py @@ -8,6 +8,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( create_node, get_first_fake_tensor, @@ -26,7 +27,7 @@ from torch.nn.utils.fusion import fuse_conv_bn_weights -class FuseBatchnorm2DPass(ExportPass): +class FuseBatchnorm2DPass(ArmPass): """Fuses the pattern convolution -> batchnorm by updating the weights and bias of the convolution and removing the batchnorm. """ @@ -34,8 +35,8 @@ class FuseBatchnorm2DPass(ExportPass): _passes_required_after: Set[Type[ExportPass]] = set() def __init__(self, exported_program: ExportedProgram): - self.exported_program = exported_program super().__init__() + self.exported_program = exported_program def get_bias_name(self, weight_node: Node, bias_node: Node | None) -> str: if bias_node: diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index c48fc008b5d..efc140889d6 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -8,6 +8,7 @@ import torch._export.utils import torch.fx +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( get_constant_placeholder_kind, get_first_fake_tensor, @@ -29,7 +30,7 @@ logger = logging.getLogger(__name__) -class FuseConstantArgsPass(ExportPass): +class FuseConstantArgsPass(ArmPass): """ Fuses ops with only placeholder parameters into one placeholder parameter node with the op pre-calulcated on its data. @@ -162,7 +163,7 @@ def call(self, graph_module): return PassResult(graph_module, True) -class ComputeConstantOpsAOT(ExportPass): +class ComputeConstantOpsAOT(ArmPass): """ Evaluates call_functions that produce constant tensor outputs and replaces them with placeholders. diff --git a/backends/arm/_passes/fuse_equal_placeholders_pass.py b/backends/arm/_passes/fuse_equal_placeholders_pass.py index b8b8143e6c5..ed558e2cb4b 100644 --- a/backends/arm/_passes/fuse_equal_placeholders_pass.py +++ b/backends/arm/_passes/fuse_equal_placeholders_pass.py @@ -9,6 +9,7 @@ import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( get_constant_placeholder_kind, get_param_tensor, @@ -23,7 +24,7 @@ from executorch.exir.pass_base import ExportPass, PassResult -class FuseEqualPlaceholdersPass(ExportPass): +class FuseEqualPlaceholdersPass(ArmPass): """ This pass optimizes memory usage by finding constant placeholders pointing to identical tensors and fusing them to one single placeholder @@ -33,8 +34,8 @@ class FuseEqualPlaceholdersPass(ExportPass): _passes_required_after: Set[Type[ExportPass]] = set() def __init__(self, exported_program: ExportedProgram): - self.exported_program = exported_program super().__init__() + self.exported_program = exported_program def call(self, graph_module: torch.fx.GraphModule) -> PassResult: modified = False diff --git a/backends/arm/_passes/fuse_quantized_activation_pass.py b/backends/arm/_passes/fuse_quantized_activation_pass.py index 1076a3df658..46a1c0d66fe 100644 --- a/backends/arm/_passes/fuse_quantized_activation_pass.py +++ b/backends/arm/_passes/fuse_quantized_activation_pass.py @@ -8,6 +8,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.convert_to_clamp import ConvertToClampPass from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( FoldAndAnnotateQParamsPass, @@ -20,7 +21,7 @@ from torch.fx import Node -class FuseQuantizedActivationPass(ExportPass): +class FuseQuantizedActivationPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = { ConvertToClampPass, FoldAndAnnotateQParamsPass, diff --git a/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py b/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py index c6e6f70a630..a12388e65df 100644 --- a/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py +++ b/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py @@ -11,6 +11,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.backends.arm._passes.decompose_embedding_pass import ( DecomposeEmbeddingPass, @@ -23,7 +24,7 @@ logger = logging.getLogger(__name__) -class InsertInt32CastsAfterInt64PlaceholdersPass(ExportPass): +class InsertInt32CastsAfterInt64PlaceholdersPass(ArmPass): """ Insert an int64->int32 cast after each int64 placeholder. diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index d56e70e78b3..d43ddfd7391 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -12,6 +12,7 @@ from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_output_qparams, ) + from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.exir.dialects._ops import ops as exir_ops @@ -19,7 +20,7 @@ from torch.fx import GraphModule, Node -class InsertRescalePass(ExportPass): +class InsertRescalePass(ArmPass): """Finds patterns of dq -> q, and replaces them with backend dialect tosa::RESCALE op. diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index d838ddc823d..e77d0c64c71 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -9,6 +9,7 @@ from typing import Callable, cast, Dict, Iterator, Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.transforms.utils import create_constant_placeholder @@ -109,7 +110,7 @@ def included_ops() -> Iterator[EdgeOpOverload]: return chain(TableOps.unary_table_ops, TableOps.special_table_ops) -class InsertTableOpsPass(ExportPass): +class InsertTableOpsPass(ArmPass): """ For ops in self.table_ops they need to be serialized as a TOSA TABLE. This pass replaces these edge ops with a tosa._table(input: Tensor, target_str: str) where target_str == str(node.target). diff --git a/backends/arm/_passes/match_arg_dtype_pass.py b/backends/arm/_passes/match_arg_dtype_pass.py index d482614b03f..f0aaa0cf5f9 100644 --- a/backends/arm/_passes/match_arg_dtype_pass.py +++ b/backends/arm/_passes/match_arg_dtype_pass.py @@ -6,6 +6,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node, get_node_arg from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -28,7 +29,7 @@ def get_largest_dtype(dtype_1, dtype_2): return dtype_1 if DTYPE_RANK[dtype_1] > DTYPE_RANK[dtype_2] else dtype_2 -class MatchArgDtypePass(ExportPass): +class MatchArgDtypePass(ArmPass): """Pass to match data types of non-condition input tensors. Edge dialect allows different data types for non-condition tensors, while TOSA diff --git a/backends/arm/_passes/match_arg_ranks_pass.py b/backends/arm/_passes/match_arg_ranks_pass.py index c411f3b8083..e70e45c61b4 100644 --- a/backends/arm/_passes/match_arg_ranks_pass.py +++ b/backends/arm/_passes/match_arg_ranks_pass.py @@ -9,10 +9,13 @@ from typing import cast, Set, Type +from executorch.backends.arm._passes import ArmPass + from executorch.backends.arm._passes.arm_pass_utils import ( create_node, get_first_fake_tensor, ) +from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops @@ -20,7 +23,7 @@ from torch.fx import GraphModule, Node -class MatchArgRanksPass(ExportPass): +class MatchArgRanksPass(ArmPass): """ For ops in 'targeted_ops', make sure that the inputs share the same rank. New dimensions are inserted from the beginning of the inputs that have a @@ -38,7 +41,7 @@ class MatchArgRanksPass(ExportPass): _passes_required_after: Set[Type[ExportPass]] = set() - def __init__(self, exported_program): + def __init__(self, exported_program: ExportedProgram) -> None: super().__init__() self.exported_program = exported_program diff --git a/backends/arm/_passes/mm_to_bmm_pass.py b/backends/arm/_passes/mm_to_bmm_pass.py index c6f4786365d..353977fba0a 100644 --- a/backends/arm/_passes/mm_to_bmm_pass.py +++ b/backends/arm/_passes/mm_to_bmm_pass.py @@ -9,6 +9,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( create_node, get_first_fake_tensor, @@ -26,7 +27,7 @@ from torch.fx import Node -class ConvertMmToBmmPass(ExportPass): +class ConvertMmToBmmPass(ArmPass): """ This pass converts a MM node to a BMM one and turns input and output tensors from rank 2 to rank 3. The TOSA specification requires rank 3. The graph is diff --git a/backends/arm/_passes/remove_noop_pass.py b/backends/arm/_passes/remove_noop_pass.py index 55c4f71f0a8..5035e26bc47 100644 --- a/backends/arm/_passes/remove_noop_pass.py +++ b/backends/arm/_passes/remove_noop_pass.py @@ -9,13 +9,15 @@ import logging from typing import Set, Type +from executorch.backends.arm._passes import ArmPass + from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass logger = logging.getLogger(__name__) -class RemoveNoopPass(ExportPass): +class RemoveNoopPass(ArmPass): """Remove no-ops from graph_module""" _passes_required_after: Set[Type[ExportPass]] = set() diff --git a/backends/arm/_passes/replace_inf_values_pass.py b/backends/arm/_passes/replace_inf_values_pass.py index 506030d82d7..7a42d08dd61 100644 --- a/backends/arm/_passes/replace_inf_values_pass.py +++ b/backends/arm/_passes/replace_inf_values_pass.py @@ -10,19 +10,17 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.exir.pass_base import ExportPass, PassResult -class ReplaceInfValues(ExportPass): +class ReplaceInfValues(ArmPass): """ Due to limitation in Quantizer, we need to change inf/-inf to more quantizable values. """ _passes_required_after: Set[Type[ExportPass]] = set() - def __init__(self): - super(ReplaceInfValues, self).__init__() - def call(self, graph_module: torch.fx.GraphModule): modified = False for buf_name, tensor in graph_module.named_buffers(): diff --git a/backends/arm/_passes/scalars_to_attribute_pass.py b/backends/arm/_passes/scalars_to_attribute_pass.py index 9ad3e318011..5ca6a60e844 100644 --- a/backends/arm/_passes/scalars_to_attribute_pass.py +++ b/backends/arm/_passes/scalars_to_attribute_pass.py @@ -9,6 +9,7 @@ from typing import cast, Set, Type, Union import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass @@ -17,7 +18,7 @@ from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix -class ScalarsToAttributePass(ExportPass): +class ScalarsToAttributePass(ArmPass): """ For ops in 'targeted_ops', convert inputs that are scalar values to attribute Nodes that output the same value. diff --git a/backends/arm/_passes/size_adjust_input_pass.py b/backends/arm/_passes/size_adjust_input_pass.py index 5eb77dc56df..c82bcab947c 100644 --- a/backends/arm/_passes/size_adjust_input_pass.py +++ b/backends/arm/_passes/size_adjust_input_pass.py @@ -8,6 +8,7 @@ from typing import cast, Set, Type, TypeAlias import torch.fx +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -137,7 +138,7 @@ def is_valid_operator(node: torch.fx.Node) -> bool: return False -class SizeAdjustInputPass(ExportPass): +class SizeAdjustInputPass(ArmPass): """ Adjusts the input size to Conv2D and Pooling operators. PyTorch allows the input and kernel shape to not "match", in which case the remaining diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py index b906c06b329..75646ce4379 100644 --- a/backends/arm/_passes/to_tosa_memory_format_pass.py +++ b/backends/arm/_passes/to_tosa_memory_format_pass.py @@ -10,6 +10,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.annotate_decomposed_matmul import ( AnnotateDecomposedMatmulPass, ) @@ -44,7 +45,7 @@ def _is_input(node: torch.fx.Node, exported_program: ExportedProgram) -> bool: return node.op == "placeholder" and not is_param_node(exported_program, node) -class ToTosaMemoryFormatPass(ExportPass): +class ToTosaMemoryFormatPass(ArmPass): """ Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts backend.tosa.TRANSPOSE @@ -55,8 +56,8 @@ class ToTosaMemoryFormatPass(ExportPass): _passes_required_after: Set[Type[ExportPass]] = set() def __init__(self, exported_program: ExportedProgram) -> None: - self.exported_program = exported_program super().__init__() + self.exported_program = exported_program @staticmethod def _is_consumer_node_depthwise_conv2d(node: torch.fx.Node): diff --git a/backends/arm/_passes/unsqueeze_before_repeat_pass.py b/backends/arm/_passes/unsqueeze_before_repeat_pass.py index 66286b6a954..6384f001580 100644 --- a/backends/arm/_passes/unsqueeze_before_repeat_pass.py +++ b/backends/arm/_passes/unsqueeze_before_repeat_pass.py @@ -8,6 +8,7 @@ import torch import torch.fx +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( create_node, get_first_fake_tensor, @@ -16,7 +17,7 @@ from executorch.exir.pass_base import ExportPass, PassResult -class UnsqueezeBeforeRepeatPass(ExportPass): +class UnsqueezeBeforeRepeatPass(ArmPass): """ A TOSA TILE op only supports rank(in) == rank(out). To support Pytorch's repeat which can also add dimensions, diff --git a/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py b/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py index d3932dd1217..5691b04ff2f 100644 --- a/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py +++ b/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py @@ -8,11 +8,13 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass +from executorch.exir import ExportedProgram from executorch.exir.pass_base import ExportPass, PassResult from torch._export.utils import is_buffer, is_param -class UnsqueezeScalarPlaceholdersPass(ExportPass): +class UnsqueezeScalarPlaceholdersPass(ArmPass): """ Placeholders that have node.meta["val"].shape = () cause issues later in the lowering. This pass unsqueezes the placeholders to make sure shape is at least (1,). @@ -20,9 +22,9 @@ class UnsqueezeScalarPlaceholdersPass(ExportPass): _passes_required_after: Set[Type[ExportPass]] = set() - def __init__(self, exported_program): - self.exported_program = exported_program + def __init__(self, exported_program: ExportedProgram) -> None: super().__init__() + self.exported_program = exported_program def call(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: diff --git a/backends/arm/test/passes/test_to_tosa_memory_format.py b/backends/arm/test/passes/test_to_tosa_memory_format.py index 643a3bf5733..aed87c05799 100644 --- a/backends/arm/test/passes/test_to_tosa_memory_format.py +++ b/backends/arm/test/passes/test_to_tosa_memory_format.py @@ -179,11 +179,8 @@ def test_to_tosa_memory_format_tosa_INT(module): module.get_inputs(), ops_after_pass=module.ops_after_pass, ops_not_after_pass=module.ops_not_after_pass, - pass_list=[RemoveGetItemPass], - passes_with_exported_program=[ - AnnotateOutputDimOrderPass, - ToTosaMemoryFormatPass, - ], + pass_list=[RemoveGetItemPass, AnnotateOutputDimOrderPass], + passes_with_exported_program=[ToTosaMemoryFormatPass], ) pipeline.pop_stage( "run_method_and_compare_outputs"