Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backends/arm/_passes/_debug_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from typing import Set, Type

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.devtools.visualization.visualization_utils import visualize_graph
from executorch.exir import ExportedProgram
from executorch.exir.pass_base import ExportPass, PassResult


class VisualizePass(ExportPass):
class VisualizePass(ArmPass):
"""
This pass visualizes the graph at the point of insertion in the pass manager
"""
Expand Down
5 changes: 5 additions & 0 deletions backends/arm/_passes/add_bias_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
from executorch.backends.transforms.utils import create_constant_placeholder
from executorch.exir import ExportedProgram

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
Expand All @@ -26,6 +27,10 @@ class AddBiasPass(ArmPass):

targeted_ops = (exir_ops.edge.aten.convolution.default,)

def __init__(self, exported_program: ExportedProgram) -> None:
super().__init__()
self.exported_program = exported_program

def call(self, graph_module):
modified = False
for node in graph_module.graph.nodes:
Expand Down
3 changes: 2 additions & 1 deletion backends/arm/_passes/annotate_decomposed_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import cast, List, Set, Type

import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
FoldAndAnnotateQParamsPass,
Expand All @@ -23,7 +24,7 @@
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class AnnotateDecomposedMatmulPass(ExportPass):
class AnnotateDecomposedMatmulPass(ArmPass):
"""
torch.matmul and it's equivalent operator @ can be decomposed in many ways, for instance:
dq -> matmul -> q can become
Expand Down
5 changes: 0 additions & 5 deletions backends/arm/_passes/arm_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,12 @@
from abc import abstractmethod
from typing import List, Optional, Set, Type

import torch
from executorch.exir.pass_base import ExportPass, NodeMetadata


class ArmPass(ExportPass):
"""Base class for Arm passes"""

def __init__(self, exported_program: Optional[torch.export.ExportedProgram] = None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a step in the right direction.

Just a word of caution for the passes which do use exported_program mainly with an intention to add modes which ends up updating the graph signature i.e. using utils like, create_constant_placeholder is that if the graph inside the exported program and the graph the pass got are out of sync then all sorts of weird errors can happen. So it might be wise to assert something like id(exported_progra.graph_module) == if(pass_arg_graph_module)

super(ArmPass, self).__init__()
self.exported_program = exported_program

@property
@abstractmethod
def _passes_required_after(self) -> Set[Type[ExportPass]]:
Expand Down
8 changes: 4 additions & 4 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,10 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
# needs to happen before AddBiasPass, but after the table ops are inserted
# to be able to validate that conv2d has right dtype arguments.
self.add_pass(DecomposeConv2dWithInt16ActivationPass())
self.add_pass(RewriteUpsamplePass(exported_program))
self.add_pass(RewriteUpsamplePass())
self.add_pass(AddBiasPass(exported_program))

self.add_pass(RewriteMatmulPass(exported_program))
self.add_pass(RewriteMatmulPass())
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
self.add_pass(ToTosaMemoryFormatPass(exported_program))
self.add_pass(RemoveNoopPass())
Expand Down Expand Up @@ -298,10 +298,10 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(FuseViewCopyTransform())
self.add_pass(FuseConstantArgsPass(exported_program))
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
self.add_pass(RewriteUpsamplePass(exported_program))
self.add_pass(RewriteUpsamplePass())
self.add_pass(AddBiasPass(exported_program))
self.add_pass(InsertTableOpsPass(exported_program))
self.add_pass(RewriteMatmulPass(exported_program))
self.add_pass(RewriteMatmulPass())
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
self.add_pass(ToTosaMemoryFormatPass(exported_program))
self.add_pass(RemoveNoopPass())
Expand Down
3 changes: 2 additions & 1 deletion backends/arm/_passes/cast_bool_to_int8_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@

import torch

from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class CastBoolToInt8Pass(ExportPass):
class CastBoolToInt8Pass(ArmPass):
"""Casts the input to int8 if it is not already and casts back the output to the original input dtype."""

_passes_required_after: Set[Type[ExportPass]] = set()
Expand Down
8 changes: 5 additions & 3 deletions backends/arm/_passes/cast_int64_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,23 @@
from typing import Set, Type

import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.exir.pass_base import ExportPass, PassResult
from torch._export.utils import is_buffer
from torch.export import ExportedProgram

logger = logging.getLogger(__name__)


class CastInt64BuffersToInt32Pass(ExportPass):
class CastInt64BuffersToInt32Pass(ArmPass):
"""
Cast int64 buffers to int32 if the int64 data is in int32 range.
"""

_passes_required_after: Set[Type[ExportPass]] = set()

def __init__(self, exported_program: torch.export.ExportedProgram):
super(CastInt64BuffersToInt32Pass, self).__init__()
def __init__(self, exported_program: ExportedProgram):
super().__init__()
self.exported_program = exported_program

def _assert_within_int32(self, tensor: torch.Tensor, node: torch.fx.Node):
Expand Down
3 changes: 2 additions & 1 deletion backends/arm/_passes/cast_to_int32_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@

import torch

from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class CastToInt32Pass(ExportPass):
class CastToInt32Pass(ArmPass):
"""Casts the input to int32 if it is not already and casts back the output to the original input dtype."""

_passes_required_after: Set[Type[ExportPass]] = set()
Expand Down
24 changes: 19 additions & 5 deletions backends/arm/_passes/conv1d_unsqueeze_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@

from typing import Set, Type

from executorch.backends.arm._passes import ArmPass

from executorch.backends.arm._passes.add_bias_pass import AddBiasPass
from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class Conv1dUnsqueezePass(ExportPass):
class Conv1dUnsqueezePass(ArmPass):
"""
This pass is used to change conv1d ops into conv2d since TOSA only
supports 2d and 3d convolution. This is done by modifying the graph to do the
Expand All @@ -38,7 +40,11 @@ def call_operator(self, op, args, kwargs, meta):
x = args[0]
x_unsqueezed_shape = list(x.data.shape) + [1]
x = super().call_operator(
exir_ops.edge.aten.view_copy.default, (x, x_unsqueezed_shape), {}, meta
exir_ops.edge.aten.view_copy.default,
(x, x_unsqueezed_shape),
{},
meta,
updated=True,
)

w_meta = meta.copy()
Expand All @@ -48,7 +54,11 @@ def call_operator(self, op, args, kwargs, meta):
w = args[1]
w_unsqueezed_shape = list(w.data.shape) + [1]
w = super().call_operator(
exir_ops.edge.aten.view_copy.default, (w, w_unsqueezed_shape), {}, w_meta
exir_ops.edge.aten.view_copy.default,
(w, w_unsqueezed_shape),
{},
w_meta,
updated=True,
)

new_args = (
Expand All @@ -63,12 +73,16 @@ def call_operator(self, op, args, kwargs, meta):
args[8],
)
x = super().call_operator(
exir_ops.edge.aten.convolution.default, new_args, kwargs, meta
exir_ops.edge.aten.convolution.default, new_args, kwargs, meta, updated=True
)

x_squeezed_shape = list(x.data.shape)[:-1]
x = super().call_operator(
exir_ops.edge.aten.view_copy.default, (x, x_squeezed_shape), {}, meta
exir_ops.edge.aten.view_copy.default,
(x, x_squeezed_shape),
{},
meta,
updated=True,
)

return x
3 changes: 2 additions & 1 deletion backends/arm/_passes/convert_any_default_dim_dims_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Set, Type

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.convert_squeezes_to_view import (
ConvertSqueezesToViewPass,
)
Expand All @@ -18,7 +19,7 @@
)


class ConvertAnyDefaultDimDimsPass(ExportPass):
class ConvertAnyDefaultDimDimsPass(ArmPass):
"""
Converts any.default, any.dim and any.dims to a sequence of any.dim by unrolling multi-dimensional reduction.
Please refer to KeepDimsFalseToSqueezePass for an explanation of this coversion.
Expand Down
7 changes: 6 additions & 1 deletion backends/arm/_passes/convert_elu_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Type

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class ConvertELUParamsPass(ExportPass):
class ConvertELUParamsPass(ArmPass):
"""
Pass to convert the input_scale kwarg of ELU operator from float to
int.
Expand All @@ -18,6 +21,8 @@ class ConvertELUParamsPass(ExportPass):
the value of input_scale is, as long as that value is not 1.
"""

_passes_required_after: Set[Type[ExportPass]] = set()

def call(self, graph_module: torch.fx.GraphModule):
modified_graph = False
graph = graph_module.graph
Expand Down
3 changes: 2 additions & 1 deletion backends/arm/_passes/convert_expand_copy_to_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import torch

from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import (
UnsqueezeBeforeRepeatPass,
)
Expand Down Expand Up @@ -48,7 +49,7 @@ def calculate_multiples(args):
return multiples


class ConvertExpandCopyToRepeatPass(ExportPass):
class ConvertExpandCopyToRepeatPass(ArmPass):
"""
Replace expand copy with repeat since it is a repeat that can only repeat singleton dimensions.
"""
Expand Down
3 changes: 2 additions & 1 deletion backends/arm/_passes/convert_int64_const_ops_to_int32.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Set, Type

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
from executorch.exir.pass_base import ExportPass, PassResult

Expand All @@ -19,7 +20,7 @@
INT32_MAX = torch.iinfo(torch.int32).max


class ConvertInt64ConstOpsToInt32Pass(ExportPass):
class ConvertInt64ConstOpsToInt32Pass(ArmPass):
"""
Rewrite constant ops that produce int64 to int32 where safe.

Expand Down
3 changes: 2 additions & 1 deletion backends/arm/_passes/convert_int64_output_ops_to_int32.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Set, Type

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_first_fake_tensor,
Expand All @@ -22,7 +23,7 @@
logger = logging.getLogger(__name__)


class ConvertInt64OutputOpsToInt32Pass(ExportPass):
class ConvertInt64OutputOpsToInt32Pass(ArmPass):
"""
Rewrites or removes operations that produce int64 outputs, converting them
to int32 where possible.
Expand Down
3 changes: 2 additions & 1 deletion backends/arm/_passes/convert_minmax_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import cast, Set, Type

import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm._passes.convert_squeezes_to_view import (
ConvertSqueezesToViewPass,
Expand All @@ -14,7 +15,7 @@
from executorch.exir.pass_base import ExportPass, PassResult


class ConvertMinMaxPass(ExportPass):
class ConvertMinMaxPass(ArmPass):
"""
Converts min/max to amin/amax and unrolls multi-dimensional reduction and keep-dims arg to be
TOSA compliant.
Expand Down
3 changes: 2 additions & 1 deletion backends/arm/_passes/convert_split_to_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Set, Type

import torch.fx
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_first_fake_tensor,
Expand All @@ -16,7 +17,7 @@
from executorch.exir.pass_base import ExportPass, PassResult


class ConvertSplitToSlicePass(ExportPass):
class ConvertSplitToSlicePass(ArmPass):
"""
Replace a split operation with many slice operations.
"""
Expand Down
4 changes: 3 additions & 1 deletion backends/arm/_passes/convert_squeezes_to_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@

from typing import Set, Type

from executorch.backends.arm._passes import ArmPass

from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class ConvertSqueezesToViewPass(ExportPass):
class ConvertSqueezesToViewPass(ArmPass):
"""
Replaces squeeze/unsqueeze operators with view. These are simply special cases of the view op, so removing them gives us less cases to handle in the node visitiors.
"""
Expand Down
5 changes: 4 additions & 1 deletion backends/arm/_passes/convert_to_clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from typing import Set, Tuple, Type

from executorch.backends.arm._passes import ArmPass

from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
QuantizeOperatorArguments,
)
Expand All @@ -27,7 +29,7 @@ def get_clamp_params(op, args) -> Tuple[float | None, float | None]:
raise ValueError(f"Getting clamp parameters for op {op} is not implemented.")


class ConvertToClampPass(ExportPass):
class ConvertToClampPass(ArmPass):
_passes_required_after: Set[Type[ExportPass]] = {QuantizeOperatorArguments}

def call_operator(self, op, args, kwargs, meta):
Expand All @@ -39,4 +41,5 @@ def call_operator(self, op, args, kwargs, meta):
(args[0], *get_clamp_params(op, args)),
{},
meta,
updated=True,
)
Loading
Loading