From 36ae21ff1b61b6ea7e8672ec091f3bdc1551bb3b Mon Sep 17 00:00:00 2001 From: Zonglin Peng Date: Wed, 20 Nov 2024 15:46:31 -0800 Subject: [PATCH] add remove ops to oss and callsites, [cadence][8/X] add reorder ops to oss and callsites, [cadence][9/X] add replace ops to oss and callsites, [cadence][10/X] merge passes with replace remove ops passes and update default pass order... (#6993) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/6993 ...in OSS, add replace pass testing, add fusion pass testing, add remove pass testing, add simplify pass testing, add reorder pass testing, [Cadence]remove internal names titled titled titled titled jarvis -> meta Differential Revision: D66264166 --- backends/cadence/README.md | 2 +- backends/cadence/aot/TARGETS | 160 +- backends/cadence/aot/compiler.py | 26 +- backends/cadence/aot/pass_utils.py | 48 +- backends/cadence/aot/passes.py | 332 +-- backends/cadence/aot/remove_ops.py | 731 ++++++ backends/cadence/aot/reorder_ops.py | 824 +++++++ backends/cadence/aot/replace_ops.py | 2111 +++++++++++++++++ .../aot/tests/test_fusion_ops_passes.py | 594 +++++ .../aot/tests/test_remove_ops_passes.py | 674 ++++++ .../aot/tests/test_reorder_ops_passes.py | 355 +++ .../aot/tests/test_replace_ops_passes.py | 1683 +++++++++++++ .../aot/tests/test_simplify_ops_passes.py | 108 + backends/cadence/aot/utils.py | 16 +- backends/cadence/runtime/runtime.py | 18 +- 15 files changed, 7349 insertions(+), 333 deletions(-) create mode 100644 backends/cadence/aot/remove_ops.py create mode 100644 backends/cadence/aot/reorder_ops.py create mode 100644 backends/cadence/aot/replace_ops.py create mode 100644 backends/cadence/aot/tests/test_fusion_ops_passes.py create mode 100644 backends/cadence/aot/tests/test_remove_ops_passes.py create mode 100644 backends/cadence/aot/tests/test_reorder_ops_passes.py create mode 100644 backends/cadence/aot/tests/test_replace_ops_passes.py create mode 100644 backends/cadence/aot/tests/test_simplify_ops_passes.py diff --git a/backends/cadence/README.md b/backends/cadence/README.md index 867dbe31db4..998ac55ddf0 100644 --- a/backends/cadence/README.md +++ b/backends/cadence/README.md @@ -2,7 +2,7 @@ ## Supported DSPs (in progress) - HiFi Audio -- ... 
+- Fusion G3 ## Tutorial diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index c0374faa7e9..24b02669113 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -39,6 +39,7 @@ python_library( ":passes", ":utils", ":ops_registrations", + ":replace_ops", "//caffe2:torch", "//executorch/backends/cadence/aot/quantizer:fusion_pass", "//executorch/backends/cadence/aot/quantizer:quantizer", @@ -74,12 +75,14 @@ python_library( ":utils", ":fuse_ops", ":simplify_ops", + ":replace_ops", + ":reorder_ops", + ":remove_ops", "//caffe2:torch", "//executorch/exir:pass_base", "//executorch/exir/dialects:lib", "//executorch/exir/passes:lib", "//executorch/exir/passes:spec_prop_pass", - "//executorch/backends/transforms:remove_clone_ops" ], ) @@ -180,6 +183,63 @@ python_library( ], ) +python_library( + name = "remove_ops", + srcs = [ + "remove_ops.py", + ], + typing = True, + deps = [ + "//caffe2:torch", + "//executorch/backends/cadence/aot:pass_utils", + "//executorch/backends/cadence/aot:simplify_ops", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + "//executorch/exir/dialects/edge:lib", + "//executorch/exir/passes:spec_prop_pass", + "//executorch/backends/transforms:remove_clone_ops" + ], +) + +python_library( + name = "reorder_ops", + srcs = [ + "reorder_ops.py", + ], + typing = True, + deps = [ + "//caffe2:torch", + "//executorch/backends/cadence/aot:compiler_utils", + "//executorch/backends/cadence/aot:pass_utils", + "//executorch/backends/cadence/aot:utils", + "//executorch/exir:pass_base", + "//executorch/exir:tensor", + "//executorch/exir/dialects:lib", + "//executorch/exir/dialects/edge:lib", + ], +) + +python_library( + name = "replace_ops", + srcs = [ + "replace_ops.py", + ], + typing = True, + deps = [ + ":pass_utils", + "//caffe2:torch", + "//executorch/backends/cadence/aot:compiler_utils", + "//executorch/backends/cadence/aot:fuse_ops", + "//executorch/backends/cadence/aot:pass_utils", + "//executorch/backends/cadence/aot:remove_ops", + "//executorch/backends/cadence/aot:utils", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + "//executorch/exir/dialects/edge:lib", + "//executorch/exir/passes:spec_prop_pass", + ], +) + python_unittest( name = "test_graph_builder", srcs = [ @@ -196,3 +256,101 @@ python_unittest( ":ops_registrations" ], ) + +python_unittest( + name = "test_replace_ops_passes", + srcs = [ + "tests/test_replace_ops_passes.py", + ], + supports_static_listing = False, + typing = True, + deps = [ + "fbsource//third-party/pypi/parameterized:parameterized", + ":compiler", + ":replace_ops", + "//caffe2:torch", + "//executorch/backends/cadence/aot:compiler", + "//executorch/backends/cadence/aot:graph_builder", + "//executorch/backends/cadence/aot:pass_utils", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + "//executorch/exir/passes:lib", + ], +) + +python_unittest( + name = "test_fusion_ops_passes", + srcs = [ + "tests/test_fusion_ops_passes.py", + ], + typing = True, + deps = [ + ":compiler", + "//caffe2:torch", + "//executorch/backends/cadence/aot:compiler", + "//executorch/backends/cadence/aot:fuse_ops", + "//executorch/backends/cadence/aot:graph_builder", + "//executorch/backends/cadence/aot:ops_registrations", + "//executorch/backends/cadence/aot:pass_utils", + "//executorch/exir/dialects:lib", + "//executorch/exir/dialects/edge:lib", + ], +) + +python_unittest( + name = "test_remove_ops_passes", + srcs = [ + "tests/test_remove_ops_passes.py", + ], + 
supports_static_listing = False, + typing = True, + deps = [ + "fbsource//third-party/pypi/parameterized:parameterized", + "fbsource//third-party/pypi/pyre-extensions:pyre-extensions", + ":compiler", + "//caffe2:torch", + "//executorch/backends/cadence/aot:compiler", + "//executorch/backends/cadence/aot:ops_registrations", + "//executorch/backends/cadence/aot:pass_utils", + "//executorch/backends/cadence/aot:remove_ops", + "//executorch/backends/cadence/aot/quantizer:quantizer", + "//executorch/exir/dialects:lib", + ], +) + +python_unittest( + name = "test_simplify_ops_passes", + srcs = [ + "tests/test_simplify_ops_passes.py", + ], + supports_static_listing = False, + typing = True, + deps = [ + "fbsource//third-party/pypi/parameterized:parameterized", + "//caffe2:torch", + "//executorch/backends/cadence/aot:compiler", + "//executorch/backends/cadence/aot:ops_registrations", + "//executorch/backends/cadence/aot:pass_utils", + "//executorch/backends/cadence/aot:simplify_ops", + "//executorch/exir/dialects:lib", + ], +) + +python_unittest( + name = "test_reorder_ops_passes", + srcs = [ + "tests/test_reorder_ops_passes.py", + ], + typing = True, + deps = [ + ":compiler", + ":pass_utils", + "//caffe2:torch", + "//executorch/backends/cadence/aot:compiler", + "//executorch/backends/cadence/aot:fuse_ops", + "//executorch/backends/cadence/aot:ops_registrations", + "//executorch/backends/cadence/aot:pass_utils", + "//executorch/backends/cadence/aot:reorder_ops", + "//executorch/exir/dialects:lib", + ], +) diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index e53826b7b98..3f98745093c 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -12,10 +12,10 @@ import executorch.backends.cadence.aot.ops_registrations # noqa import torch - -from executorch.backends.cadence.aot.passes import ReplaceSafeSoftmaxWithSoftmax from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer + +from executorch.backends.cadence.aot.replace_ops import ReplaceSafeSoftmaxWithSoftmax from executorch.backends.cadence.aot.utils import model_gm_has_SDPA, model_is_quantized from executorch.backends.transforms.decompose_sdpa import ( DecomposeScaledDotProductAttention, @@ -194,9 +194,6 @@ def export_to_edge( return edge_prog_manager -# Export the model and lower it to an EdgeProgramManager (in edge IR), and -# apply passes specific to Cadence DSP execution. Return both to print the -# differences. def export_to_cadence( model: torch.nn.Module, inputs: tuple[object, ...], @@ -216,6 +213,25 @@ def export_to_cadence( return cadence_prog_manager +def quantize_and_export_to_cadence( + model: torch.nn.Module, + inputs: tuple[object, ...], + dump_graphs: bool = False, + opt_level: int = 1, +) -> EdgeProgramManager: + quantized_model = quantize_pt2(model, inputs) + + return export_to_cadence( + quantized_model, + inputs, + opt_level=opt_level, + dump_graphs=dump_graphs, + ) + + +# Export the model and lower it to an EdgeProgramManager (in edge IR), and +# apply passes specific to Cadence DSP execution. Return both to print the +# differences. 
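+# Illustrative usage of the new quantize_and_export_to_cadence helper defined above
+# (names `model` and `example_inputs` are placeholders, not part of this change):
+#   edge_prog_manager = quantize_and_export_to_cadence(model, example_inputs, opt_level=1)
+#   # i.e. quantize_pt2(model, example_inputs) followed by export_to_cadence(...)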
def export_to_executorch_gen_etrecord( model: torch.nn.Module, inputs: tuple[object, ...], diff --git a/backends/cadence/aot/pass_utils.py b/backends/cadence/aot/pass_utils.py index ed56a1b85fb..c17e854af18 100644 --- a/backends/cadence/aot/pass_utils.py +++ b/backends/cadence/aot/pass_utils.py @@ -3,7 +3,7 @@ # pyre-strict from dataclasses import dataclass -from typing import Callable, Optional, Set, Union +from typing import Callable, List, Optional, Set, Union import torch from executorch.backends.cadence.aot.utils import get_edge_overload_packet @@ -50,7 +50,7 @@ def get_all_available_cadence_passes() -> Set[ExportPass]: return set(ALL_CADENCE_PASSES.keys()) -# Create a new filter to filter out relevant passes from all Jarvis passes. +# Create a new filter to filter out relevant passes from all passes. def create_cadence_pass_filter( opt_level: int, debug: bool = False ) -> Callable[[ExportPass], bool]: @@ -98,3 +98,47 @@ def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target) if node.op == "call_function" and node.target == target: total += 1 return total + + +# Testing utils +# Return the compute/function nodes in the graph +def get_compute_nodes_in_gm(graph_module: torch.fx.GraphModule) -> List[torch.fx.Node]: + nodes = [] + for x in graph_module.graph.nodes: + if x.op == "call_function": + if isinstance(x.target, torch._ops.OpOverload): + nodes.append(x.target.overloadpacket) + elif isinstance(x.target, EdgeOpOverload): + nodes.append(get_edge_overload_packet(x.target)) + return nodes + + +# Return true if there is no edge from a node with target pred_target to a +# node with target succ_target in the graph. +def nodes_not_connected_in_gm( + graph_module: torch.fx.GraphModule, + pred_target: torch.fx.Node, + succ_target: torch.fx.Node, +) -> bool: + for node in graph_module.graph.nodes: + if node.target != pred_target: + continue + for user in node.users: + if user.target == succ_target: + return False + return True + + +# Returns true if there is no instance of a node with target succ_target +# positioned immediately after a node with target pred_target in the graph +def nodes_not_adjacent_in_gm( + graph_module: torch.fx.GraphModule, + pred_target: torch.fx.Node, + succ_target: torch.fx.Node, +) -> bool: + for node in graph_module.graph.nodes: + if node.target != pred_target: + continue + if node.next.target == succ_target: + return False + return True diff --git a/backends/cadence/aot/passes.py b/backends/cadence/aot/passes.py index e23e53bd2b1..ab23149e60d 100644 --- a/backends/cadence/aot/passes.py +++ b/backends/cadence/aot/passes.py @@ -6,43 +6,41 @@ # pyre-strict -from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Type +from typing import Any, List, Optional, Type import torch import torch.fx import torch.utils._pytree as pytree -from executorch.backends.cadence.aot.fuse_ops import CadenceFuseOpsInGraph +from executorch.backends.cadence.aot.fuse_ops import ( + CadenceFuseOpsInGraph, + FuseFullThenReshapePass, + FuseTransposeOpPairsPass, +) from executorch.backends.cadence.aot.pass_utils import ( CadencePassAttribute, create_cadence_pass_filter, register_cadence_pass, ) + +from executorch.backends.cadence.aot.remove_ops import ( + CadenceRemoveNops, + RemoveNopSliceOrViewOpPass, + RemoveRedundantOps, +) +from executorch.backends.cadence.aot.reorder_ops import CadenceReorderOpsInGraph +from executorch.backends.cadence.aot.replace_ops import CadenceReplaceOpsInGraph from executorch.backends.cadence.aot.simplify_ops import 
CadenceSimplifyOpsInGraph -from executorch.backends.cadence.aot.utils import get_edge_overload_packet -from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue +from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.pass_manager import PassManager, PassType from executorch.exir.passes import dead_code_elimination_pass from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass from executorch.exir.passes.spec_prop_pass import SpecPropPass -@register_cadence_pass(CadencePassAttribute(opt_level=0)) -class RemoveCloneOpsTransformImported(ExportPass): - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - finalize_passes: List[PassType] = [ - RemoveCloneOpsTransform(), - ] - result = PassManager(passes=finalize_passes)(graph_module) - dead_code_elimination_pass(result.graph_module) - return result - - @register_cadence_pass(CadencePassAttribute(opt_level=0)) class InitializePipeline(ExportPass): """ - Initialize the Jarvis pipeline. This should invariably be the first pass to + Initialize the pass pipeline. This should invariably be the first pass to run. """ @@ -56,7 +54,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: @register_cadence_pass(CadencePassAttribute(opt_level=0)) class FinalizePipeline(ExportPass): """ - The final cleanup pass after running the Jarvis pipeline. + The final cleanup pass after running the pass pipeline. """ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: @@ -73,298 +71,20 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: Argument = Any # pyre-ignore -@register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplacePT2QuantWithCadenceQuantPass(ExportPass): - """ - Replace the pt2 quantization ops with cadence quantization ops. - We do not link kernels to the PT2 quantization ops, so we need to - replace them with cadence ops at all optimization levels. - """ - - def call_operator( - self, - op, # pyre-ignore - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in {exir_ops.edge.quantized_decomposed.quantize_per_tensor.default}: - return super().call_operator(op, args, kwargs, meta) - - return super().call_operator( - exir_ops.edge.cadence.quantize_per_tensor.default, - args, - kwargs, - meta, - ) - - -@register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplacePT2DequantWithCadenceDequantPass(ExportPass): - """ - Replace the pt2 dequantization ops with cadence dequantization ops. - We do not link kernels to the PT2 quantization ops, so we need to - replace them with cadence ops at all optimization levels. - """ - - def call_operator( - self, - op, # pyre-ignore - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in {exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default}: - return super().call_operator(op, args, kwargs, meta) - - return super().call_operator( - exir_ops.edge.cadence.dequantize_per_tensor.default, - args, - kwargs, - meta, - ) - - -@register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceScalarTensorWithFullPass(ExportPass): - """ - aten.scalar_tensor can be replaced by aten.full with a shape of [1]. 
- """ - - def call_operator( - self, - op, # pyre-ignore - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in { - exir_ops.edge.aten.scalar_tensor.default, - torch.ops.aten.scalar_tensor.default, - }: - return super().call_operator(op, args, kwargs, meta) - - return super().call_operator( - exir_ops.edge.aten.full.default, - ( - [1], - args[0], - ), - {"dtype": torch.float32}, - meta, - ) - - -@register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceSqueezeAndUnsqueezeWithViewPass(ExportPass): - """ - When the shape is static, replace squeeze_copy and unsqueeze_copy ops with - view_copy op - """ - - def call_operator( - self, - op, # pyre-ignore - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - # Instead of testing EdgeOpOverload, test EdgeOpOverloadPacket, - # which allows us to cover all overloads. - if get_edge_overload_packet(op) not in { - exir_ops.edge.aten.squeeze_copy, - exir_ops.edge.aten.unsqueeze_copy, - }: - return super().call_operator(op, args, kwargs, meta) - # Get the output tensor shape - out_shape = meta["val"].shape - - # Bail out if any dim is not an int (dynamic shape) - for dim in list(out_shape): - if not isinstance(dim, int): - return super().call_operator(op, args, kwargs, meta) - - # Return a view op with the new shape - view_args = (args[0], list(out_shape)) - return super().call_operator( - exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta - ) - - -@register_cadence_pass(CadencePassAttribute(opt_level=0)) -class RemoveZeroSizedCatArgsPass(ExportPass): - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.cat.default: - return super().call_operator(op, args, kwargs, meta) - - # Remove any zero-sized tensor arg to form a new args list. - cat_inputs: list[ProxyValue] = [] - for arg in cast(Sequence[ProxyValue], args[0]): - if arg.to_tensor().numel() > 0: - cat_inputs.append(arg) - - # If all the tensors were empty, we just return an empty tensor with - # the right shape. - if not cat_inputs: - empty_shape = meta["val"].shape - dtype = meta["val"].dtype - return super().call_operator( - exir_ops.edge.aten.full.default, - (tuple(empty_shape), 0), - {"dtype": dtype}, - meta, - ) - - # If there was only one tensor in the cat_inputs list, - # we can safely erase this cat op. - if len(cat_inputs) == 1: - return cat_inputs[0] - - # Otherwise, we replace args[0] with cat_inputs. - new_args = list(args) - new_args[0] = cat_inputs - return super().call_operator(op, tuple(new_args), kwargs, meta) - - -@register_cadence_pass(CadencePassAttribute(opt_level=0)) -class RemoveNopExpandOpPass(ExportPass): - """ - For an expand op, if the operator shape matches the expand shape, then the - expand is a nop. 
- """ - - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if get_edge_overload_packet(op) not in { - exir_ops.edge.aten.expand_copy, - exir_ops.edge.aten.expand, - }: - return super().call_operator(op, args, kwargs, meta) - - # Parse the args, and check for nop condition - arg0 = cast(ProxyValue, args[0]) - arg1 = cast(Sequence[int], args[1]) - in_tensor = arg0.to_tensor() - if list(in_tensor.shape) == list(arg1): - return arg0 - - return super().call_operator(op, args, kwargs, meta) - - -@register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass): - """ - A where op with a logical_not and a boolean tensor can be replaced - by a where op with flipped inputs and the initial boolean tensor. - """ - - def replace_logical_nop_where_with_where( - self, graph_module: torch.fx.GraphModule - ) -> None: - graph = graph_module.graph - for node in graph.nodes: - # We are only interested in where nodes - if node.target != exir_ops.edge.aten.where.self: - continue - - # If the third arg is not a logical_not, bail. - if node.args[0].target != exir_ops.edge.aten.logical_not.default: - continue - - # Get the third arg node and its input - logical_not_node = node.args[0] - logical_not_input_tensor = ( - logical_not_node.args[0].to_tensor() - if isinstance(logical_not_node.args[0], ProxyValue) - else logical_not_node.args[0] - ) - - # If the logical_not input is not a boolean tensor, bail. - if logical_not_input_tensor.meta["spec"].dtype != torch.bool: - continue - - # Replace the where op with another one, flipping the inputs and using the boolean - # tensor from logical_not. - with graph.inserting_before(node): - linear_node = graph.call_function( - exir_ops.edge.aten.where.self, - args=(logical_not_node.args[0], node.args[2], node.args[1]), - ) - # Replace all the uses - node.replace_all_uses_with(linear_node) - - graph_module.recompile() - graph_module.graph.eliminate_dead_code() - - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self.replace_logical_nop_where_with_where(graph_module) - result = super().call(graph_module) - return result - - -@register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceSafeSoftmaxWithSoftmax(ExportPass): # keep - """ - Replace _safe_softmax with _softmax - """ - - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != torch.ops.aten._safe_softmax.default: - return super().call_operator(op, args, kwargs, meta) - - # Add False for the half_to_float argument of softmax - softmax_args = list(args) + [False] - - return super().call_operator( - torch.ops.aten._softmax.default, - tuple(softmax_args), - kwargs, - meta, - ) - - def get_passes_in_default_order() -> List[Type[PassType]]: passes = [ InitializePipeline, - RemoveZeroSizedCatArgsPass, - ReplaceLogicalNotBooleanWhereWithWherePass, - ReplaceScalarTensorWithFullPass, - RemoveCloneOpsTransformImported, - RemoveNopExpandOpPass, + RemoveRedundantOps.passes, + CadenceReorderOpsInGraph.passes, + # Phase ordering: remove -> fusion -> replacement passes. 
+ CadenceRemoveNops.passes, CadenceFuseOpsInGraph.passes, - ReplaceSqueezeAndUnsqueezeWithViewPass, - ReplacePT2QuantWithCadenceQuantPass, - ReplacePT2DequantWithCadenceDequantPass, + CadenceReplaceOpsInGraph.passes, CadenceSimplifyOpsInGraph.passes, - # TODO: add the rest of the passes here. - # InitializePipeline, - # RemoveRedundantOps.passes, - # ReorderOpsInGraph.passes, - # RemoveJarvisNops.passes, - # CadenceFuseOpsInGraph.passes, - # ReplaceOpsInGraph.passes, - # SimplifyOpsInGraph.passes, - # FinalizePipeline, - # FuseFullThenReshapePass, - # FuseTransposeOpPairsPass, - # RemoveNopSliceOrViewOpPass, + FinalizePipeline, + FuseFullThenReshapePass, + FuseTransposeOpPairsPass, + RemoveNopSliceOrViewOpPass, ] return pytree.tree_flatten(passes)[0] diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py new file mode 100644 index 00000000000..d2251bd9c08 --- /dev/null +++ b/backends/cadence/aot/remove_ops.py @@ -0,0 +1,731 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-strict + + +# This file contains functions to remove operators from the graph. The removed +# ops should belong to either of the following categories: +# 1. The op should be redundant for inference (e.g., dropout). Such ops are grouped +# together in 'RemoveRedundantOps'. Anyone running inference can add this class +# in their pass list, and it should semantic-preserving transformation. +# 2. The op should be redundant for Jarvis (e.g., contiguous). Such ops are grouped +# together in 'CadenceRemoveNops'. The ops removed in this class might not be nop +# in a context outside of Jarvis', so exercise caution while invoking this in a +# pass list outside of Jarvis. + +import itertools +import logging +from dataclasses import dataclass, field +from typing import Callable, cast, Dict, List, Optional, Sequence + +import torch +import torch.fx +from executorch.backends.cadence.aot.pass_utils import ( + CadencePassAttribute, + register_cadence_pass, +) + +from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass +from executorch.backends.cadence.aot.utils import get_edge_overload_packet +from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue +from executorch.exir.pass_manager import PassManager, PassType +from executorch.exir.passes import dead_code_elimination_pass +from executorch.exir.passes.spec_prop_pass import SpecPropPass +from torch.fx.node import Argument + + +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class RemoveCloneOpsTransformImported(ExportPass): + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + finalize_passes: List[PassType] = [ + RemoveCloneOpsTransform(), + ] + result = PassManager(passes=finalize_passes)(graph_module) + dead_code_elimination_pass(result.graph_module) + return result + + +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class RemoveDetachCopyPass(ExportPass): + def call_operator( + self, + op, # pyre-ignore + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op != exir_ops.edge.aten.detach_copy.default: + return super().call_operator(op, args, kwargs, meta) + + assert len(args) == 1 + return cast(ProxyValue, args[0]) + + +# The following class consolidates passes to remove 
ops that are redundant: +# either by the virtue of the operation they perform, or redundant in the +# context of inference. +class RemoveRedundantOps: + passes = [ + RemoveDetachCopyPass, + ] + + +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class RemoveZeroSizedCatArgsPass(ExportPass): + def call_operator( + self, + op, # pyre-ignore + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op != exir_ops.edge.aten.cat.default: + return super().call_operator(op, args, kwargs, meta) + + # Remove any zero-sized tensor arg to form a new args list. + cat_inputs: list[ProxyValue] = [] + for arg in cast(Sequence[ProxyValue], args[0]): + if arg.to_tensor().numel() > 0: + cat_inputs.append(arg) + + # If all the tensors were empty, we just return an empty tensor with + # the right shape. + if not cat_inputs: + empty_shape = meta["val"].shape + dtype = meta["val"].dtype + return super().call_operator( + exir_ops.edge.aten.full.default, + (tuple(empty_shape), 0), + {"dtype": dtype}, + meta, + ) + + # If there was only one tensor in the cat_inputs list, + # we can safely erase this cat op. + if len(cat_inputs) == 1: + return cat_inputs[0] + + # Otherwise, we replace args[0] with cat_inputs. + new_args = list(args) + new_args[0] = cat_inputs + return super().call_operator(op, tuple(new_args), kwargs, meta) + + +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class RemoveNopExpandOpPass(ExportPass): + """ + For an expand op, if the operator shape matches the expand shape, then the + expand is a nop. + """ + + def call_operator( + self, + op, # pyre-ignore + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if get_edge_overload_packet(op) not in { + exir_ops.edge.aten.expand_copy, + exir_ops.edge.aten.expand, + }: + return super().call_operator(op, args, kwargs, meta) + + # Parse the args, and check for nop condition + arg0 = cast(ProxyValue, args[0]) + arg1 = cast(Sequence[int], args[1]) + in_tensor = arg0.to_tensor() + if list(in_tensor.shape) == list(arg1): + return arg0 + + return super().call_operator(op, args, kwargs, meta) + + +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class RemoveToOpsPass(ExportPass): + # aten.to.* as of now are all nops for Jarvis + def call_operator( + self, + op, # pyre-ignore + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op not in ( + exir_ops.edge.aten.to.dtype, + exir_ops.edge.aten.to.dtype_layout, + ): + return super().call_operator(op, args, kwargs, meta) + + logging.debug(f"Erasing to.dtype node (target = {op})") + return cast(ProxyValue, args[0]) + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class RemoveZeroSizedConstantPadNd(ExportPass): + def call_operator( + self, + op, # pyre-ignore + args: tuple[ProxyValue, tuple[int, ...], Argument], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op != exir_ops.edge.aten.constant_pad_nd.default: + return super().call_operator(op, args, kwargs, meta) + + input_tensor = args[0] + padding = args[1] + + if any(x != 0 for x in padding): + return super().call_operator(op, args, kwargs, meta) + + logging.debug(f"Erasing 0 sized constant pad nd node with {input_tensor}") + return input_tensor + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class RemoveNopSliceOrViewOpPass(ExportPass): + """ + Remove slice ops that are more like views, and view ops that do not change the 
shape + """ + + def call_operator( + self, + op, # pyre-ignore + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op not in { + exir_ops.edge.aten.slice_copy.Tensor, + exir_ops.edge.aten.view_copy.default, + }: + return super().call_operator(op, args, kwargs, meta) + + arg0 = cast(ProxyValue, args[0]) + out_shape = meta["val"].shape + + # If both arg_shape and out_shape are the same, this slice is a nop + return ( + arg0 + if arg0.to_tensor().shape == out_shape + else super().call_operator(op, args, kwargs, meta) + ) + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class RemoveNopLinalgVectorNormOpPass(ExportPass): + """ + If the norm is applied over a dimension that is size 1, it can be eliminated. + """ + + def call_operator( + self, + op, # pyre-ignore + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op not in { + exir_ops.edge.aten.linalg_vector_norm.default, + exir_ops.edge.cadence.linalg_vector_norm.default, + }: + return super().call_operator(op, args, kwargs, meta) + + # If the op has three args or less, it can't be a nop + if len(args) <= 3: + return super().call_operator(op, args, kwargs, meta) + # If dim is None, or keepdim is False, it is not a nop + dim = cast(Optional[tuple[int, ...]], args[2]) + keepdim = cast(bool, args[3]) + if dim is None or not keepdim: + return super().call_operator(op, args, kwargs, meta) + + # If the norm has 4 args and keepdim is True, check if dim is not None + # and if the dimensions in dim are size 1. If not, the norm is not a nop. + t = cast(ProxyValue, args[0]) + shape = t.to_tensor().shape + if len(args) < 4: + for d in dim: + if shape[d] != 1: + return super().call_operator(op, args, kwargs, meta) + + return t + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class RemoveNopSelectOpPass(ExportPass): + """ + A select op that selects from a dimension that is size 1 can be eliminated + in a few cases. For example, + ``` + x = view (x, [1, 3, 16]) + y = select(x, 0, 0) + z = add(m, y) + ``` + The special thing about this pattern is the add op, which allows + broadcasting. So adding an operand with shape [3, 16] is the same as + adding an operand with shape [1, 3, 16]. Therefore, if m has the same + shape as x, then this select op is a nop, and can be eliminated: + ``` + x = view (x, [1, 3, 16]) + z = add(x, m) + ``` + """ + + # A set of binary operators that could require broadcasting, and are + # critical to this transformation if their operand is select op. + binary_broadcast_ops: set[EdgeOpOverload] = { + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.div.Tensor, + } + + def __init__(self) -> None: + super().__init__() + self.op_sizes: dict[str, tuple[torch.Size, torch.Size]] = {} + + # For select, view, or any op in binary_broadcast_ops, record the shapes of + # input and output tensors. 
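+    # Illustrative example (assumed shapes): after visiting `y = select_copy(x, 0, 0)`
+    # where x has shape [1, 3, 16], op_sizes["y"] would hold
+    # (torch.Size([1, 3, 16]), torch.Size([3, 16])), i.e. (input shape, output shape).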
+ def call_operator( + self, + op, # pyre-ignore + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + res = super().call_operator(op, args, kwargs, meta) + # Unary ops: input and output + if op in { + exir_ops.edge.aten.select_copy.int, + exir_ops.edge.aten.view_copy.default, + }: + arg0 = cast(ProxyValue, args[0]) + self.op_sizes[res.node.name] = (arg0.to_tensor().shape, meta["val"].shape) + # Binary ops: two inputs, output shape can be inferred + elif op in self.binary_broadcast_ops: + arg0 = cast(ProxyValue, args[0]) + arg1 = cast(ProxyValue, args[1]) + self.op_sizes[res.node.name] = ( + arg0.to_tensor().shape, + arg1.to_tensor().shape, + ) + return res + + # Eliminate nop select ops. We begin by inspecting the binary_broadcast_ops, + # and check if their arg is a select op. + def eliminate_nop_select_op(self, graph_module: torch.fx.GraphModule) -> None: + for sel_node in graph_module.graph.nodes: + # We are only interested in select ops + if sel_node.target != exir_ops.edge.aten.select_copy.int: + continue + # The shape of the input/output operands for this select op should + # have been precomputed. + assert sel_node.name in self.op_sizes + (sel_in_shape, sel_out_shape) = self.op_sizes[sel_node.name] + # Get the select dimension + sel_dim = ( + sel_node.args[1] + if sel_node.args[1] >= 0 + else sel_node.args[1] + len(sel_in_shape) + ) + # If the input size along select dimension is not 1, bail. + if sel_in_shape[sel_dim] != 1: + continue + + # Get all the users of the select op that are either view, or + # binary_broadcast_ops. + users = [x for x in list(sel_node.users.keys()) if x.name in self.op_sizes] + sel_in = sel_node.args[0] + + # Iterate over the users of select op, and remove the use of the + # select op in the user if feasible. + for node in users: + args = list(node.args) + for idx, sel_arg in enumerate(args): + # Check if the arg is the select op + if sel_arg != sel_node: + continue + # If the input of select has the same shape as the other arg + # of the binary op, the select op can be bypassed. + if sel_in_shape == self.op_sizes[node.name][(idx + 1) % 2]: + args[idx] = sel_in + # update the node's args + node.args = tuple(args) + + graph_module.recompile() + graph_module.graph.eliminate_dead_code() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + result = SpecPropPass()(graph_module) + assert result is not None + result = super().call(result.graph_module) + self.eliminate_nop_select_op(result.graph_module) + return result + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class RemoveCloneOpPass(ExportPass): + # If the op is a clone op, return the input and eliminate the op + def call_operator( + self, + op, # pyre-ignore + args: tuple[ProxyValue], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op != exir_ops.edge.aten.clone.default: + return super().call_operator(op, args, kwargs, meta) + + return args[0] + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class RemoveContiguousOpPass(ExportPass): + """ + This is based on the assumption that all tensors are contiguous in ExecuTorch + and after cadence passes, and we should revisit this if that assumption is no longer true. + This causes the model to not be runnable with the arguments given to the + original graph module. 
+ """ + + def call_operator( + self, + op, # pyre-ignore + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op != exir_ops.edge.aten.contiguous.default: + return super().call_operator(op, args, kwargs, meta) + + assert len(args) == 1 + return cast(ProxyValue, args[0]) + + +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class RemoveAliasCopyOpPass(ExportPass): + """ + + alias_copy is a no-op for Jarvis and can be removed. + """ + + def call_operator( + self, + op, # pyre-ignore + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op != exir_ops.edge.aten.alias_copy.default: + return super().call_operator(op, args, kwargs, meta) + + assert len(args) == 1 + return cast(ProxyValue, args[0]) + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class RemoveNopRequantizeOpPass(ExportPass): + """ + For a requantize op, if the following three conditions are satisfied: + 1. the in_scale matches the out_scale + 2. the in_zero_point matches the out_zero_point + 3. the dtypes of the input and output tensors are the same + then the requantize op is redundant, and can be eliminated + """ + + def call_operator( + self, + op, # pyre-ignore + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op != exir_ops.edge.cadence.requantize.default: + return super().call_operator(op, args, kwargs, meta) + + # Parse the args + (X, in_scale, in_zero_point, out_scale, out_zero_point, out_dtype) = cast( + tuple[ProxyValue, int, float, int, float, torch.dtype], args + ) + in_dtype = X.to_tensor().dtype + # Check the three conditions + if ( + in_scale == out_scale + and in_zero_point == out_zero_point + and in_dtype == out_dtype + ): + return cast(ProxyValue, args[0]) + + return super().call_operator(op, args, kwargs, meta) + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class RemoveNopMulOpPass(ExportPass): + """ + If a mul op is multiplying two tensors with the same shape and one + of those tensors is all zeros, return the zero tensor instead. + """ + + def call_operator( + self, + op, # pyre-ignore + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op != exir_ops.edge.aten.mul.Tensor: + return super().call_operator(op, args, kwargs, meta) + + # Parse the args + (input1, input2) = cast(tuple[ProxyValue, ProxyValue], args) + + # Check if both inputs have the same shape + if input1.to_tensor().shape != input2.to_tensor().shape: + return super().call_operator(op, args, kwargs, meta) + + # Check if one of the inputs is a zero tensor + if input1.node.target == exir_ops.edge.aten.full.default: + if input1.node.args[1] == 0: + return input1 + elif input2.node.target == exir_ops.edge.aten.full.default: + if input2.node.args[1] == 0: + return input2 + + return super().call_operator(op, args, kwargs, meta) + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class RemoveNopAddOpPass(ExportPass): + """ + If an add op is adding two tensors with the same shape and one + of those tensors is all zeros, return the other tensor instead. 
+ """ + + def call_operator( + self, + op, # pyre-ignore + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op != exir_ops.edge.aten.add.Tensor: + return super().call_operator(op, args, kwargs, meta) + + # Parse the args + (input1, input2) = cast(tuple[ProxyValue, ProxyValue], args) + + # Check if both inputs have the same shape + if input1.to_tensor().shape != input2.to_tensor().shape: + return super().call_operator(op, args, kwargs, meta) + + # Check if one of the inputs is a zero tensor + if input1.node.target == exir_ops.edge.aten.full.default: + if input1.node.args[1] == 0: + return input2 + elif input2.node.target == exir_ops.edge.aten.full.default: + if input2.node.args[1] == 0: + return input1 + + return super().call_operator(op, args, kwargs, meta) + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class RemovePermutesAroundElementwiseOps(ExportPass): + """ + Looks for subgraphs of elementwise ops sandwiched between permutes and removes those + permutes if possible. This pass is targeted at models where delegated subgraphs + must be in NHWC format, so there's usually a to_NHWC permute before each delegate and + a to_NCHW permute after it. If all the ops between two delegates are elementwise ops + then these permutes can be safely removed. + Allows special handling for certain non-elementwise ops that can be easily updated based on + the permute's parameter, such as mean and cat + """ + + @dataclass() + class Subgraph: + """ + Keeps track of nodes grouped as a subgraph between two sets of permutes + """ + + start_permutes: set[torch.fx.Node] = field(default_factory=set) + end_permutes: set[torch.fx.Node] = field(default_factory=set) + intermediate_nodes: set[torch.fx.Node] = field(default_factory=set) + is_valid: bool = True + + elementwise_ops: set[EdgeOpOverload] = { + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.mean.dim, + exir_ops.edge.aten.cat.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + } + + # must be initialized in the constructor + special_handling: Dict[EdgeOpOverload, Callable[[torch.fx.Node], None]] = {} + + to_NCHW = [0, 3, 1, 2] + to_NHWC = [0, 2, 3, 1] + + def __init__(self) -> None: + super().__init__() + self.visited: set[object] = set() + self.special_handling = { + exir_ops.edge.aten.mean.dim: self.handle_mean_dim, + exir_ops.edge.aten.cat.default: self.handle_cat, + } + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + self.visited = set() + for node in graph_module.graph.nodes: + sg = self.Subgraph() + self.start_search(node, sg) + if self.is_valid_subgraph(sg): + logging.debug(f"Found valid subgraph: {sg}") + self.handle_subgraph(graph_module, sg) + + result = super().call(graph_module) + return result + + def handle_mean_dim(self, mean_dim: torch.fx.Node) -> None: + assert mean_dim.target == exir_ops.edge.aten.mean.dim + args = list(mean_dim.args) + args[1] = [self.to_NCHW[dim] for dim in cast(list[int], args[1])] + mean_dim.args = tuple(args) + + def handle_cat(self, cat: torch.fx.Node) -> None: + assert cat.target == exir_ops.edge.aten.cat.default + args = list(cat.args) + args[1] = self.to_NCHW[cast(int, args[1])] + cat.args = tuple(args) + + def is_valid_subgraph(self, sg: Subgraph) -> bool: + return ( + sg.is_valid + and len(sg.start_permutes) > 0 + and len(sg.end_permutes) > 0 + and len(sg.intermediate_nodes) > 0 + ) + + def 
handle_subgraph(self, graph_module: torch.fx.GraphModule, sg: Subgraph) -> None: + for permute in itertools.chain(sg.start_permutes, sg.end_permutes): + permute.replace_all_uses_with(permute.args[0]) # pyre-fixme[6] + + for node in sg.intermediate_nodes: + if node.target in self.special_handling: + self.special_handling[node.target](node) + + graph_module.recompile() + graph_module.graph.eliminate_dead_code() + + def start_search(self, node: torch.fx.Node, sg: Subgraph) -> None: + if node in self.visited: + return + + if self.is_starting_permute(node): + sg.start_permutes.add(node) + self.visited.add(node) + for user in node.users: + self.search_down(user, sg) + + def search_up(self, node: object, sg: Subgraph) -> None: + # non-nodes can be ignored. These would be arguments like integers or lists + # of integers, which don't affect the subgraph validity or inclusion set. + if not isinstance(node, torch.fx.Node): + return + + if node.op == "placeholder": + # If we reach a placeholder or other terminal node without encountering + # a start permute, then the subgraph is invalid. + # This could be because in the add(x, y) case where x is permuted and + # y is a graph input, we can't remove the permute on x because it might + # become two different shapes that don't broadcast together. + # TODO: Adding a permute on y could be the more optimal solution, + # but perhaps not in all cases, say if x is small and y is very large. + # This transform prefers to be safe over optimal for now. + sg.is_valid = False + return + + if node in self.visited: + return + + self.visited.add(node) + + if self.is_starting_permute(node): + sg.start_permutes.add(node) + for user in node.users: + self.search_down(user, sg) + else: + self.traverse_intermediate_node(node, sg) + + def search_down(self, node: torch.fx.Node, sg: Subgraph) -> None: + if node in self.visited or self.is_starting_permute(node): + return + + self.visited.add(node) + + if self.is_ending_permute(node): + sg.end_permutes.add(node) + for arg in node.args: + if isinstance(arg, list): + for elem in arg: + self.search_up(elem, sg) + else: + self.search_up(arg, sg) + else: + self.traverse_intermediate_node(node, sg) + + def traverse_intermediate_node(self, node: torch.fx.Node, sg: Subgraph) -> None: + if node.target in self.elementwise_ops: + sg.intermediate_nodes.add(node) + for arg in node.args: + if isinstance(arg, list): + for elem in arg: + self.search_up(elem, sg) + else: + self.search_up(arg, sg) + + for user in node.users: + self.search_down(user, sg) + + else: + sg.is_valid = False + + def is_starting_permute(self, node: torch.fx.Node) -> bool: + return ( + node.target == exir_ops.edge.aten.permute_copy.default + and cast(list[int], node.args[1]) == self.to_NCHW + ) + + def is_ending_permute(self, node: torch.fx.Node) -> bool: + return ( + node.target == exir_ops.edge.aten.permute_copy.default + and cast(list[int], node.args[1]) == self.to_NHWC + ) + + +# The following class consolidates functions to remove ops that are redundant +# in Jarvis. Currently, each function in this class iterates over each node of +# the graph module once. In future, we could consolidate them into a monolithic +# function. 
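+# Illustrative (not part of this change) consumption of the pass list below:
+# get_passes_in_default_order() in passes.py includes CadenceRemoveNops.passes and
+# flattens it with pytree.tree_flatten; the result is typically filtered with
+# create_cadence_pass_filter(opt_level) before the passes are instantiated and run.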
+class CadenceRemoveNops: + passes = [ + SimplifySliceOpPass, + RemoveCloneOpsTransformImported, + RemoveToOpsPass, + RemoveNopRequantizeOpPass, + RemoveZeroSizedCatArgsPass, + RemoveNopSliceOrViewOpPass, + RemoveNopExpandOpPass, + RemoveZeroSizedConstantPadNd, + RemoveCloneOpPass, + RemoveContiguousOpPass, + RemoveAliasCopyOpPass, + RemoveNopMulOpPass, + RemoveNopAddOpPass, + RemoveNopLinalgVectorNormOpPass, + ] diff --git a/backends/cadence/aot/reorder_ops.py b/backends/cadence/aot/reorder_ops.py new file mode 100644 index 00000000000..313edae5f4c --- /dev/null +++ b/backends/cadence/aot/reorder_ops.py @@ -0,0 +1,824 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-unsafe + + +# This file contains all the functions that reorder ops in the graph module. + +import copy +from collections import defaultdict +from math import prod +from typing import cast, DefaultDict, List, Set, Tuple + +import torch +import torch.fx +from executorch.backends.cadence.aot.compiler_utils import get_placeholders, get_shape +from executorch.backends.cadence.aot.pass_utils import ( + CadencePassAttribute, + get_overload_packet, + register_cadence_pass, +) +from executorch.backends.cadence.aot.utils import get_edge_overload_packet +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.tensor import num_bytes_from_shape_and_dtype + +# A list of ops that can be trivially quantized +trivially_quantizable_ops_overloadpkt = { + torch.ops.aten.slice_copy, + torch.ops.aten.slice, + torch.ops.aten.view_copy, + torch.ops.aten.view, + torch.ops.aten.clone, + torch.ops.aten.transpose_copy, + torch.ops.aten.transpose, + torch.ops.aten.permute_copy, + torch.ops.aten.permute, + torch.ops.aten.squeeze_copy, + torch.ops.aten.squeeze, + torch.ops.aten.unsqueeze_copy, + torch.ops.aten.unsqueeze, + torch.ops.aten.chunk, + torch.ops.aten.contiguous, + torch.ops.aten.select_copy, + exir_ops.edge.aten.slice_copy, + exir_ops.edge.aten.view_copy, + exir_ops.edge.aten.clone, + exir_ops.edge.aten.transpose_copy, + exir_ops.edge.aten.permute_copy, + exir_ops.edge.aten.squeeze_copy, + exir_ops.edge.aten.unsqueeze_copy, + exir_ops.edge.aten.unfold_copy, + exir_ops.edge.aten.chunk, + exir_ops.edge.aten.contiguous, + exir_ops.edge.aten.select_copy, +} + +# slice-equivalent ops +slice_or_select_overloadpkt = { + torch.ops.aten.slice_copy, + torch.ops.aten.select_copy, + exir_ops.edge.aten.slice_copy, + exir_ops.edge.aten.select_copy, +} + + +@register_cadence_pass(CadencePassAttribute(opt_level=2)) +class AdvanceQuantizeOpAboveDefInBranchPass(ExportPass): + """ + If the graph is branched with the following pattern: + I = ... + S1 = slice(I) + Q1 = quantize(S1) + S2 = slice(I) + Q2 = quantize(S2) + S3 = slice(I) + Q3 = quantize(S3) + ... + such that the elements in the slices S1 + S2 + S3 is greater than I, + we can advance the quantize above their defs (i.e., all the slice nodes), + and reorder the pattern to the following: + I = ... + Q1 = quantize(I) + S1 = slice(Q1) + Q1 = requantize(S1) + S2 = slice(Q1) + Q2 = requantize(S2) + S3 = slice(Q1) + Q3 = requantize(S3) + ... + Note that the other passes won't do this transformation because they expect + a linear chain of def-use, which is not true here; the uses of I are + branched. 
+ """ + + def __init__(self): + super().__init__() + self.graph_module = None + + # Starting at node, iterate through its successors, bypassing any trivially + # quantizable op. If all the descendents are quantize ops, return them. + def get_descendent_quant_ops(self, node: torch.fx.Node) -> List[torch.fx.Node]: + # The list of quant ops that are descendents of node, such that the only + # nodes in the path from node --> quant are trivially quantizable ops. + descendent_quant_ops = [] + # The list of trivially quantizable ops in the path from node --> quant op. + trivial_quantized_ops = [] + + users = list(node.users.keys()) + while users: + user = users.pop(0) + user_target = get_overload_packet(user.target) + # Record a quant op successor + if user_target in { + torch.ops.quantized_decomposed.quantize_per_tensor, + exir_ops.edge.quantized_decomposed.quantize_per_tensor, + }: + descendent_quant_ops.append(user) + # If the successor is a trivially quantizable op, consider its users + # instead. + elif user_target in trivially_quantizable_ops_overloadpkt: + trivial_quantized_ops.append(user) + users.extend(list(user.users.keys())) + # Otherwise all successors of node are not quant op, so break the loop. + else: + descendent_quant_ops.clear() + break + + # If all the nodes in trivial_quantize_ops of the node were slice ops, + # ensure that the advance is still profitable. + if descendent_quant_ops and all( + get_overload_packet(x.target) in slice_or_select_overloadpkt + for x in trivial_quantized_ops + ): + # Profitability metric: the sum of all the output slices must be at + # least half the input node slice. + slice_sizes = [ + prod(list(y)) + for x in trivial_quantized_ops + if (y := get_shape(self.graph_module, x)) is not None + ] + node_shape = get_shape(self.graph_module, node) + node_size = prod(list(node_shape)) if node_shape is not None else 0 + if node_size > 2 * sum(slice_sizes): + descendent_quant_ops.clear() + + return descendent_quant_ops + + def advance_quantize_op(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + for node in graph.nodes: + # We are only interested in call functions and placeholders + if node.op not in {"placeholder", "call_function"}: + continue + # If the node is trivially quantizable, skip it + if ( + get_overload_packet(node.target) + in trivially_quantizable_ops_overloadpkt + ): + continue + # Get the descendent quant ops that are connected to the current + # node via trivially quantizable ops. + descendent_quant_ops = self.get_descendent_quant_ops(node) + if not descendent_quant_ops: + continue + + # Get the insertion point below which we need to insert anything. + # if node is a placeholder, we will only insert a new node after + # all the placeholders in the graph. + insertion_pt = ( + get_placeholders(graph)[-1] if node.op == "placeholder" else node + ) + + # If the node only has a single quant op as descendent, we can + # simply hoist the quant op below the current node as its single + # child. + if len(descendent_quant_ops) == 1: + quant_node = descendent_quant_ops.pop() + # Replace the uses of quant node with its predecessor + quant_node.replace_all_uses_with(quant_node.args[0]) # pyre-fixme[6] + # Hoist the quant node after the current node. 
Make sure that + # the insertion is after placeholders + with graph.inserting_after(insertion_pt): + dom_quant_args = (node,) + quant_node.args[1:] + dom_quant_node = graph.call_function( + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default + ) + dom_quant_node.meta = node.meta + node.replace_all_uses_with(dom_quant_node) + dom_quant_node.args = dom_quant_args + graph.erase_node(quant_node) + continue + + # Otherwise we have the quant descendents. Cluster them into sets + # that have the same scale, zero_point, and dtype. We use quant_dict + # for the clustering + quant_dict: DefaultDict[Tuple, int] = defaultdict(int) + for quant_node in descendent_quant_ops: + quant_dict[quant_node.args[1:]] += 1 + rep_args = sorted(quant_dict.keys(), key=lambda x: x[1]).pop() + + # Create a new quant node that dominates all the nodes in + # descendent_quant_ops. Make sure that the insertion is after + # all the placeholders. + with graph.inserting_after(insertion_pt): + dom_quant_args = (node,) + rep_args + dom_quant_node = graph.call_function( + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default + ) + dom_quant_node.meta = node.meta + node.replace_all_uses_with(dom_quant_node) + dom_quant_node.args = dom_quant_args + + # Finally, convert each of the quant node to a dequant/quant pair that + # requantizes the data flowing through dom_quant_node. + # TODO: Once requantize is implemented for PT2, replace the + # dequant/quant pair here with a single requantize node + for quant_node in descendent_quant_ops: + with graph.inserting_before(quant_node): + dequant_node = graph.call_function( + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default + ) + dequant_node.args = (quant_node.args[0],) + rep_args + quant_node.args = (dequant_node,) + quant_node.args[1:] + + graph_module.recompile() + graph_module.graph.eliminate_dead_code() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + self.graph_module = graph_module + self.advance_quantize_op(graph_module) + result = super().call(graph_module) + return result + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class AdvanceQuantizeOpAboveDefChainPass(ExportPass): + """ + If the input to quantize op is linear chain of view, transpose, permute, or + slice ops that are trivially quantized, we can convert the pattern + view/transpose/permute/slice(fp32) -> quantize(int8/uint8) to + quantize(int8/uint8) -> view/transpose/permute/slice(int8/uint8). + The benefit of such reordering is that the view/transpose/permute/slice + will move far less data. + """ + + def __init__(self): + super().__init__() + self.graph_module = None + + # Return true if advancing the quantize node is feasible + def advancing_feasible(self, quant_node: torch.fx.Node): + assert quant_node.op == "call_function" and len(quant_node.args) >= 1 + # Get the input of the quant node. Only proceed if it's a torch node. + inp = quant_node.args[0] + if not isinstance(inp, torch.fx.Node): + return False + + # Return false if the input to the quantize node is (1) not trivially + # quantizable, or (2) has more than one user. + inp_users = list(inp.users.keys()) + inp_overloadpkt = None + if isinstance(inp.target, EdgeOpOverload): + inp_overloadpkt = get_edge_overload_packet(inp.target) + else: + inp_overloadpkt = get_overload_packet(inp.target) + + if ( + inp_overloadpkt not in trivially_quantizable_ops_overloadpkt + or len(inp_users) != 1 + ): + return False + + # Advancing quantize op above slice nodes is tricky. 
If we advance the + # quantize node above slice, then we will quantize the input to the slice + # op, which can be expensive. We only bypass nop slice at present. + if inp_overloadpkt in slice_or_select_overloadpkt: + sliced_tensor = inp.args[0] + assert isinstance(sliced_tensor, torch.fx.Node) + slice_input_shape = get_shape(self.graph_module, sliced_tensor) + slice_output_shape = get_shape(self.graph_module, inp) + # If we could not glean the shapes, or the slice op is a nop, bail + if ( + slice_output_shape is None + or slice_input_shape is None + or prod(list(slice_output_shape)) < prod(list(slice_input_shape)) + ): + return False + + # All the conditions satisfied, we advance. + return True + + def advance_quantize_op(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + for node in reversed(graph.nodes): + if get_overload_packet(node.target) not in ( + exir_ops.edge.quantized_decomposed.quantize_per_tensor, + torch.ops.quantized_decomposed.quantize_per_tensor, + ): + continue + + if not self.advancing_feasible(node): + continue + + trivially_quantizable_op = node.args[0] + # The input to the quant node must now be the input to the trivially + # quantizable op. + quant_args = list(node.args) + quant_args[0] = trivially_quantizable_op.args[0] + + # Insert the new quant node with updated args before the current + # quant node. + with graph.inserting_before(node): + quant_node = graph.call_function(node.target, args=tuple(quant_args)) + quant_node.meta = node.meta + # Move the trivially quantizable node after the quant node + with graph.inserting_after(node): + tq_args = list(trivially_quantizable_op.args) + tq_args[0] = quant_node + tq_node = graph.call_function( + trivially_quantizable_op.target, + args=tuple(tq_args), + kwargs=trivially_quantizable_op.kwargs, + ) + tq_node.meta = trivially_quantizable_op.meta + # Replace all uses of node with newly created tq_node + node.replace_all_uses_with(tq_node) + # We can safely remove the quant node and trivially quantizable op + graph.erase_node(node) + graph.erase_node(trivially_quantizable_op) + + graph_module.recompile() + graph_module.graph.eliminate_dead_code() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + self.graph_module = graph_module + self.advance_quantize_op(graph_module) + result = super().call(graph_module) + return result + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class PostponeDequantizeOpBelowUseChainPass(ExportPass): + """ + If the consumer of dequantize is a linear chain of view, transpose, permute, + or slice ops that are trivially quantized, we can convert the pattern + dequantize(int8/uint8) -> view/transpose/permute/slice(fp32) to + view/transpose/permute/slice(int8/uint8) -> dequantize(int8/uint8) + The benefit of such reordering is that the view/transpose/permute/slice + will move far less data. + """ + + def __init__(self): + super().__init__() + self.graph_module = None + + # Return true if postponing the dequantize node is feasible + def postponing_feasible(self, dequant_node: torch.fx.Node): + users = list(dequant_node.users.keys()) + # Check if the dequantize op has a single user, and that user is + # trivially quantizable. 
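+        # Illustrative example (assumed single-user chain): postponing rewrites
+        #   d = dequantize_per_tensor(q, ...); v = view_copy(d, [n]); use(v)
+        # into
+        #   v = view_copy(q, [n]); d = dequantize_per_tensor(v, ...); use(d)
+        # so the view moves int8/uint8 data instead of fp32 data.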
+ trivially_quantizable_users = all( + get_overload_packet(user.target) in trivially_quantizable_ops_overloadpkt + for user in users + ) + if len(users) == 1: + return trivially_quantizable_users + + # Otherwise check if all the users are slice op + if not all( + get_overload_packet(user.target) in slice_or_select_overloadpkt + for user in users + ): + return False + + dequant_shape = get_shape(self.graph_module, dequant_node) + slice_shapes = [ + shape + for user in users + if (shape := get_shape(self.graph_module, user)) + and ( + # skip slices that are the size of the sliced tensor itself. + # They should technically get removed in the later passes as nop. + shape is None + or dequant_shape is None + or prod(list(shape)) != prod(list(dequant_shape)) + ) + ] + + if dequant_shape is not None and all( + shape is not None for shape in slice_shapes + ): + dequant_bytes = num_bytes_from_shape_and_dtype(dequant_shape, torch.float32) + slice_bytes = sum( + [ + num_bytes_from_shape_and_dtype(shape, torch.float32) + for shape in slice_shapes + ] + ) + if slice_bytes <= dequant_bytes: + return True + + # If the users of each slice op is quantize op, then we can postpone + # dequantize, and convert slice -> dequantize -> quantize to + # slice -> requantize. + users = [x for y in users for x in y.users if x.op != "output"] + return all( + get_overload_packet(x.target) + in { + exir_ops.edge.quantized_decomposed.quantize_per_tensor, + exir_ops.edge.quantized_decomposed.quantize_per_channel, + } + for x in users + ) + + def postpone_dequantize_op(self, graph_module: torch.fx.GraphModule) -> bool: + # Different supported dequant ops have their own default variants + packet_to_overload_map = { + exir_ops.edge.quantized_decomposed.dequantize_per_tensor: "default", + exir_ops.edge.quantized_decomposed.dequantize_per_channel: "default", + } + graph = graph_module.graph + modified = False + for node in graph.nodes: + overload_packet = get_overload_packet(node.target) + if ( + overload_packet not in packet_to_overload_map.keys() + or not self.postponing_feasible(node) + ): + continue + + for user in node.users: + with graph.inserting_after(user): + dequant_node = graph.call_function( + getattr( + overload_packet, packet_to_overload_map[overload_packet] + ), + args=(user, *node.args[1:]), + ) + dequant_node.meta = user.meta.copy() + # Remove meta["debug_handle"] on new node. Reassign it at the + # caller level by calling generate_missing_debug_handles + dequant_node.meta.pop("debug_handle") + user.replace_all_uses_with(dequant_node) + dequant_node.args = (user, *node.args[1:]) + + pred = node.args[0] + node.replace_all_uses_with(pred) + graph.erase_node(node) + modified = True + + graph_module.recompile() + return modified + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + # The logic in postpone_dequantize_op that handles branching checks the shape + # of the dequant node, which isn't available if that node was already postponed + # in the same pass invokation. The shape information is recreated by tracing in + # super().call(), meaning that every branch in the graph that we wish to postpone + # dequant past requires retracing. 
We iterate the pass until it no longer modifies + # the graph (up to 3 times max, to avoid potential infinite loops) + self.graph_module = graph_module + iter_count = 0 + modified = True + + while modified and iter_count < 3: + modified = self.postpone_dequantize_op(self.graph_module) + self.graph_module = super().call(self.graph_module).graph_module + iter_count += 1 + + return super().call(self.graph_module) + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class SinkOpsCloserToUsePass(ExportPass): + """ + Assume that the dequantize op D = dequantize(I) has only a single user. + If the current graph looks like + I = ...; + D = dequantize(I); + ... + Y = use(D); + then we can postpone the dequantize op closer to its use, and convert the + graph to: + I = ...; + ... + D = dequantize(I); + Y = use(D); + + The transformation is valid since D had a single user. The benfit comes from + the fact that now we have I in the live range instead of D, which has a + much smaller size. + """ + + sinkable_ops: Set[EdgeOpOverload] = { + exir_ops.edge.aten.dequantize, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor, + exir_ops.edge.quantized_decomposed.dequantize_per_channel, + } + + def sink_ops_closer_to_use(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + # We are only interested in sinkable nodes + sinkable_nodes = [ + node + for node in graph.nodes + if isinstance(node.target, EdgeOpOverload) + and get_edge_overload_packet(node.target) in self.sinkable_ops + ] + for node in sinkable_nodes: + # The sinkable node must have a single user + users = list(node.users.keys()) + if len(users) != 1: + continue + + # Insert the dequant node just before its user + with graph.inserting_before(users[0]): + new_node = graph.call_function( + node.target, args=node.args, kwargs=node.kwargs + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + graph_module.recompile() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + self.sink_ops_closer_to_use(graph_module) + result = super().call(graph_module) + return result + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class HoistOpsCloserToDefPass(ExportPass): + """ + Assume that the input I to a quantize op Q = quantize(I) has only a single + use, the quantize node itself. + If the current graph looks like + I = ...; + ... + Q = quantize(I); + X = use(Q); + then we can hoist the quantize op closer to its def, and convert the + graph to: + I = ...; + Q = quantize(I); + ... + X = use(Q); + + The transformation is valid since I had a single user. The benefit comes from + the fact that now we have Q in the live range instead of I, which has a + much smaller size. The same transformation also applies to slice/select op. 
+ """ + + hoistable_ops: Set[EdgeOpOverload] = { + exir_ops.edge.quantized_decomposed.quantize_per_tensor, + exir_ops.edge.aten.slice_copy, + exir_ops.edge.aten.select_copy, + } + + def hoist_ops_closer_to_def(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + # We are only interested in hoistable nodes + hoistable_nodes = [ + node + for node in graph.nodes + if isinstance(node.target, EdgeOpOverload) + and get_edge_overload_packet(node.target) in self.hoistable_ops + ] + for node in hoistable_nodes: + def_node = node.args[0] + if not isinstance(def_node, torch.fx.Node): + continue + # The def node must have a single user + users = list(def_node.users.keys()) + if len(users) != 1: + continue + + # Get the node args as list + args = list(node.args) + + # If the graph has placeholders, we do not want to hoist above the + # last placeholder. Otherwise we will shrink the live range of the + # def_node considerably, which could lead to reuse of input memory. + def_node = ( + get_placeholders(graph)[-1] + if def_node.op == "placeholder" + else def_node + ) + + # If the node is quantize_per_channel, we need to hoist the scale + # and zero_point tensors as well. + if ( + node.target + == exir_ops.edge.quantized_decomposed.quantize_per_channel.default + ): + scale, zero_point = args[1], args[2] + with graph.inserting_after(def_node): + zero_point_copy = graph.node_copy(zero_point) + scale_copy = graph.node_copy(scale) + args[1], args[2] = scale_copy, zero_point_copy + def_node = zero_point_copy + + # Insert the quant node just after def_node + with graph.inserting_after(def_node): + new_node = graph.call_function( + node.target, args=tuple(args), kwargs=node.kwargs + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + # Eliminate dead code + graph_module.recompile() + graph_module.graph.eliminate_dead_code() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + self.hoist_ops_closer_to_def(graph_module) + result = super().call(graph_module) + return result + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(ExportPass): + """ + A common pattern seen in transformer models. If the consumer of permute + is a view op, swap their order so permute is below view. + Change "permute -> view" to "view -> permute" + This is to optimize a chain of view->permute->view->permute... + so that the chain will be become view->v...->view->permute->p...->permute. + The chain can be optimized by FuseCascadedTransposeOrPermuteOps() and + FuseCascadedViewOps(). + Notice the class name has ViewSqueeze to indicate the View is + functionally the same as a squeeze or unsqueeze. It does not necessarily + mean the view_copy is normalized from squeeze or unsqueeze. + """ + + def __init__(self): + super().__init__() + self.graph_module = None + + # If list1 and list2 are same (same values and in same order) except + # list1 has one more element with value of 1. Return index of the extra 1. + # Otherwise return -1. 
+ def check_if_shapes_differ_in_single_dim_of_size_1(self, list1, list2) -> int: + if len(list1) != len(list2) + 1: + return -1 + for i in range(len(list2)): + if list1[i] != list2[i]: + # Return index of the extra 1 if the remaining parts are the same + if list1[i] == 1 and list2[i:] == list1[i + 1 :]: + return i + else: + return -1 + # If no difference was found, the extra element is at the end + if list1[-1] == 1: + return len(list2) + else: + return -1 + + def insert_nodes( + self, + graph: torch.fx.Graph, + pred: torch.fx.Node, + permute_node: torch.fx.Node, + view_node: torch.fx.Node, + new_view_shape: List, + new_permute_dims: List, + ): + with graph.inserting_after(view_node): + new_view_node = graph.call_function( + view_node.target, # pyre-fixme[6] + args=(pred, new_view_shape), + ) + + with graph.inserting_after(new_view_node): + new_permute_node = graph.call_function( + permute_node.target, # pyre-fixme[6] + args=(new_view_node, new_permute_dims), + ) + new_permute_node.meta = view_node.meta + view_node.replace_all_uses_with(new_permute_node) + + # view_node is user of permute_node, so must erase view_node first + graph.erase_node(view_node) + graph.erase_node(permute_node) + + # flake8: noqa 'PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView.postpone_permute_op' is too complex (13) + def postpone_permute_op(self, graph_module: torch.fx.GraphModule): + packet_to_overload_map = { + exir_ops.edge.aten.permute_copy: "default", + } + graph = graph_module.graph + changed = True + modified = False + # Loop iteratively until no more changes are made + while changed: + changed = False + for permute_node in graph.nodes: + permute_overload_packet = get_overload_packet(permute_node.target) + if permute_overload_packet not in packet_to_overload_map.keys(): + continue + + users = list(permute_node.users.keys()) + # Transform only for pattern permute_copy->view_copy, and + # view_copy op is the only user of permute_copy. + if len(users) == 1 and users[0].target in ( + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.view.default, + ): + # If the permute_node/view_node was newly added to the + # graph, it may not have the meta["val"] FakeTensor. + # Skip in this case. + if permute_node.meta.get("val") is None: + continue + permute_node_shape = [ + *cast(list, get_shape(graph_module, permute_node)) + ] + permute_dims = permute_node.args[1] + view_node = users[0] + if view_node.meta.get("val") is None: + continue + view_node_shape = [*cast(list, get_shape(graph_module, view_node))] + pred = permute_node.args[0] + if pred.meta.get("val") is None: + continue + pred_shape = [*cast(list, get_shape(graph_module, pred))] + # Handle two cases + # 1. view_node_shape is almost same as permute_node_shape + # except the view_node has one more dim somewhere + # and the extra dim has value of 1. + # 2. view_node_shape is almost same as permute_node_shape + # except permute_node_shape has one more dim somewhere + # and the extra dim has value of 1. + # 3. view_node_shape is the same as permute_node_shape. + if len(permute_node_shape) + 1 == len(view_node_shape): + index = self.check_if_shapes_differ_in_single_dim_of_size_1( + view_node_shape, permute_node_shape + ) + if index != -1: + # view_node_shape is almost same as permute_node_shape + # except it has one more dim somewhere + # and the extra dim has value of 1. 
+ new_view_shape = copy.deepcopy(pred_shape) + new_view_shape.insert(index, 1) + new_permute_dims = [ + x + 1 if x >= index else x for x in permute_dims + ] + new_permute_dims.insert(index, index) + self.insert_nodes( + graph, + pred, + permute_node, + view_node, + new_view_shape, + new_permute_dims, + ) + changed = True + modified = True + elif len(view_node_shape) + 1 == len(permute_node_shape): + index = self.check_if_shapes_differ_in_single_dim_of_size_1( + permute_node_shape, view_node_shape + ) + if index != -1: + # view_node_shape is almost same as permute_node_shape + # except permute_node_shape has one more dim somewhere + # and the extra dim has value of 1. + index_to_remove = permute_dims[index] + new_view_shape = copy.deepcopy(pred_shape) + del new_view_shape[index_to_remove] + new_permute_dims = [ + x - 1 if x > index_to_remove else x + for x in permute_dims + ] + del new_permute_dims[index] + self.insert_nodes( + graph, + pred, + permute_node, + view_node, + new_view_shape, + new_permute_dims, + ) + changed = True + modified = True + elif permute_node_shape == view_node_shape: + # view_node_shape is the same as permute_node_shape + # Replace the uses of view_node with permute_node + view_node.replace_all_uses_with(permute_node) + changed = True + modified = True + + graph_module.recompile() + return modified + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + self.graph_module = graph_module + iter_count = 0 + modified = True + + while modified and iter_count <= 3: + modified = self.postpone_permute_op(self.graph_module) + self.graph_module = super().call(self.graph_module).graph_module + iter_count += 1 + + return super().call(self.graph_module) + + +# The following class consolidates functions to reoder ops (i.e., either hoist +# or sink some ops in the graph). +class CadenceReorderOpsInGraph: + passes = [ + # Hoist/sink nodes closer to their SSA def/use + HoistOpsCloserToDefPass, + SinkOpsCloserToUsePass, + # For quantize/dequantize ops, move them above/below their def chain. + # This is a more aggressive optimization than just hoisting/sinking + # nodes closer to their def/use. + AdvanceQuantizeOpAboveDefChainPass, + PostponeDequantizeOpBelowUseChainPass, + # These passes work on branches instead of linear chains to advance + # quantize op beyond their def. + AdvanceQuantizeOpAboveDefInBranchPass, + ] diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py new file mode 100644 index 00000000000..fd51385bcd9 --- /dev/null +++ b/backends/cadence/aot/replace_ops.py @@ -0,0 +1,2111 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + + +# This file contains all the functions that replace one op with another in the +# graph. The functions replacing ops for models deployed with Jarvis are grouped +# together in class 'ReplaceOpsInGraph'. Some examples of functions in the class are +# 1. functions that replace an ATen op with a custom op that accepts extra arguments +# 2. functions that replace in-place variants of ATen ops with out-of-place version. +# 3. functions that replace an ATen op with another semantically equivalent ATen op. +# 4. functions that concretize optional args. 
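+# A few illustrative examples of the categories above:
+#   1. aten.convolution -> cadence.convolution (custom op with extra args)
+#   2. aten.relu_ -> aten.relu (in-place variant to out-of-place variant)
+#   3. aten.mm(X, W) -> aten.addmm(zero_bias, X, W) (semantically equivalent ATen op)
+#   4. convolution with bias=None -> convolution with an explicit zero bias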
+ +import math +from operator import neg +from typing import cast, Dict, Iterable, Sequence, Set, Tuple + +import torch +import torch.fx +from executorch.backends.cadence.aot.compiler_utils import ( + get_shape, + get_tensor_from_attr, + get_transposed_dims, + get_zero_point, + is_node_with_op, + is_quantized_tensor, + quantize_tensor_multiplier, +) +from executorch.backends.cadence.aot.fuse_ops import FuseCascadedViewOps +from executorch.backends.cadence.aot.pass_utils import ( + CadencePassAttribute, + register_cadence_pass, +) +from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass +from executorch.backends.cadence.aot.utils import get_edge_overload_packet +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket +from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue +from torch._subclasses import FakeTensor +from torch.fx.node import Argument + +# A map to represent ops that: +# (a) are functionally equivalent wrt. Jarvis; and +# (b) have identical arguments +# An op whose target is 'key' in this dict can be replaced by the functionally euivalent +# op whose target is 'value'. The replacement would just involve changing the op target. +functionally_equivalent_op_targets: Dict[EdgeOpOverload, EdgeOpOverload] = { + exir_ops.edge.aten.relu_.default: exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.unsafe_split.Tensor: exir_ops.edge.aten.split_copy.Tensor, +} + + +def contains_placeholder_or_param(nodes: Iterable[torch.fx.Node]) -> bool: + """ + Return true if any of the node in the incoming nodes list is a placeholder + or parameter + """ + return any( + is_node_with_op(node, "placeholder") or is_node_with_op(node, "get_attr") + for node in nodes + ) + + +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass): + """ + A where op with a logical_not and a boolean tensor can be replaced + by a where op with flipped inputs and the initial boolean tensor. + """ + + def replace_logical_nop_where_with_where( + self, graph_module: torch.fx.GraphModule + ) -> None: + graph = graph_module.graph + for node in graph.nodes: + # We are only interested in where nodes + if node.target != exir_ops.edge.aten.where.self: + continue + + # If the third arg is not a logical_not, bail. + if node.args[0].target != exir_ops.edge.aten.logical_not.default: + continue + + # Get the third arg node and its input + logical_not_node = node.args[0] + logical_not_input_tensor = ( + logical_not_node.args[0].to_tensor() + if isinstance(logical_not_node.args[0], ProxyValue) + else logical_not_node.args[0] + ) + + # If the logical_not input is not a boolean tensor, bail. + if logical_not_input_tensor.meta["spec"].dtype != torch.bool: + continue + + # Replace the where op with another one, flipping the inputs and using the boolean + # tensor from logical_not. 
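+ # i.e., where(logical_not(cond), a, b) becomes where(cond, b, a).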
+ with graph.inserting_before(node): + linear_node = graph.call_function( + exir_ops.edge.aten.where.self, + args=(logical_not_node.args[0], node.args[2], node.args[1]), + ) + # Replace all the uses + node.replace_all_uses_with(linear_node) + + graph_module.recompile() + graph_module.graph.eliminate_dead_code() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + self.replace_logical_nop_where_with_where(graph_module) + result = super().call(graph_module) + return result + + +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class ReplaceSafeSoftmaxWithSoftmax(ExportPass): # keep + """ + Replace _safe_softmax with _softmax + """ + + def call_operator( + self, + op, + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op != torch.ops.aten._safe_softmax.default: + return super().call_operator(op, args, kwargs, meta) + + # Add False for the half_to_float argument of softmax + softmax_args = list(args) + [False] + + return super().call_operator( + torch.ops.aten._softmax.default, + tuple(softmax_args), + kwargs, + meta, + ) + + +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class ReplacePT2QuantWithCadenceQuantPass(ExportPass): + """ + Replace the pt2 quantization ops with cadence quantization ops. + We do not link kernels to the PT2 quantization ops, so we need to + replace them with cadence ops at all optimization levels. + """ + + def call_operator( + self, + op, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op not in {exir_ops.edge.quantized_decomposed.quantize_per_tensor.default}: + return super().call_operator(op, args, kwargs, meta) + + return super().call_operator( + exir_ops.edge.cadence.quantize_per_tensor.default, + args, + kwargs, + meta, + ) + + +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class ReplacePT2DequantWithCadenceDequantPass(ExportPass): + """ + Replace the pt2 dequantization ops with cadence dequantization ops. + We do not link kernels to the PT2 quantization ops, so we need to + replace them with cadence ops at all optimization levels. + """ + + def call_operator( + self, + op, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op not in {exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default}: + return super().call_operator(op, args, kwargs, meta) + + return super().call_operator( + exir_ops.edge.cadence.dequantize_per_tensor.default, + args, + kwargs, + meta, + ) + + +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class ReplaceSqueezeAndUnsqueezeWithViewPass(ExportPass): + """ + When the shape is static, replace squeeze_copy and unsqueeze_copy ops with + view_copy op + """ + + def call_operator( + self, + op, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + # Instead of testing EdgeOpOverload, test EdgeOpOverloadPacket, + # which allows us to cover all overloads. 
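+ # (e.g., every overload of squeeze_copy resolves to the same packet).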
+ if get_edge_overload_packet(op) not in { + exir_ops.edge.aten.squeeze_copy, + exir_ops.edge.aten.unsqueeze_copy, + }: + return super().call_operator(op, args, kwargs, meta) + # Get the output tensor shape + out_shape = meta["val"].shape + + # Bail out if any dim is not an int (dynamic shape) + for dim in list(out_shape): + if not isinstance(dim, int): + return super().call_operator(op, args, kwargs, meta) + + # Return a view op with the new shape + view_args = (args[0], list(out_shape)) + return super().call_operator( + exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta + ) + + +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class ReplaceFunctionallyEquivalentOpTargets(ExportPass): + """ + Replace an op with a functionally equivalent op by just switching the op + target, but without incurring any change to the op args. + """ + + def call_operator(self, op, args, kwargs, meta): + if op not in functionally_equivalent_op_targets: + return super().call_operator(op, args, kwargs, meta) + return super().call_operator( + functionally_equivalent_op_targets[op], args, kwargs, meta + ) + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class ReplaceSelectWithViewOpPass(ExportPass): + """ + If the size along the select dim is 1, then the select op can be replaced + by view op. + """ + + def call_operator(self, op, args, kwargs, meta): + if op != exir_ops.edge.aten.select_copy.int: + return super().call_operator(op, args, kwargs, meta) + + # Glean the shape of input and output tensor + in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0] + in_shape = in_tensor.shape + out_shape = meta["val"].shape + # Get the select dimension + select_dim = args[1] if args[1] >= 0 else args[1] + len(in_shape) + + if in_shape[select_dim] == 1: + # Return a view op with the new shape + view_args = (args[0], list(out_shape)) + return super().call_operator( + exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta + ) + return super().call_operator(op, args, kwargs, meta) + + +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class ReplaceTCopyWithTransposePass(ExportPass): + """ + Replace t_copy with transpose_copy.int. If the input is 1D, the t_copy is + a nop. t_copy is not supported, so this is an opt_level=0 pass. + """ + + def call_operator(self, op, args, kwargs, meta): + if get_edge_overload_packet(op) != exir_ops.edge.aten.t_copy: + return super().call_operator(op, args, kwargs, meta) + + # Get the input tensor shape + in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0] + + # If the input is a 1D tensor, this t_copy is a nop, so return the input + if in_tensor.dim() <= 1: + return args[0] + + assert in_tensor.dim() == 2, "t_copy expects a tensor with <= 2 dimensions" + transpose_args = (args[0], 0, 1) + return super().call_operator( + exir_ops.edge.aten.transpose_copy.int, transpose_args, kwargs, meta + ) + + +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class ReplaceMMWithAddMMPass(ExportPass): + """ + This pass replaces mm with addmm by introducing a zero bias. + mm is not supported, so this is an opt_level=0 pass. 
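+ For example, mm(X, W) with W of shape [K, N] becomes
+ addmm(full([N], 0.0), X, W).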
+ """ + + def call_operator(self, op, args, kwargs, meta): + if op != exir_ops.edge.aten.mm.default: + return super().call_operator(op, args, kwargs, meta) + + # The mm op has two args: input, mat2 + assert len(args) == 2 + X, mat2 = args + + # Create a zero bias tensor, and insert it as a graph buffer before the + # current node + mat2_tensor = mat2.to_tensor() if isinstance(mat2, ProxyValue) else mat2 + bias_size = mat2_tensor.size(1) + zero_bias = super().call_operator( + exir_ops.edge.aten.full.default, + ([bias_size], 0.0), + {"dtype": torch.float32}, + meta, + ) + + # Replace mm with addmm + new_args = (zero_bias, X, mat2) + return super().call_operator( + exir_ops.edge.aten.addmm.default, new_args, kwargs, meta + ) + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class ReplaceAddMMWithLinearPass(ExportPass): + """ + This pass replaces addmm with linear op. + """ + + def __init__(self): + super().__init__() + self.counter = 0 + + def replace_addmm_with_linear(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + for node in graph.nodes: + # We are only interested in admm nodes + if node.target != exir_ops.edge.aten.addmm.default: + continue + + # The addmm op has three concrete args: input, mat1, mat2 + assert len(node.args) >= 3 + (bias, mat1, mat2) = node.args[0:3] + # The other two args are optional scale args + beta = node.kwargs.get("beta", 1.0) + alpha = node.kwargs.get("alpha", 1.0) + + # AddMM performs beta*bias + alpha*mm(mat1, mat2). We can convert + # it to linear op by multiplying beta to bias, and alpha to mat2.t(). + # However, the following two conditions must hold: + # a. If bias is not a param, then beta must be 1.0 + # b. If mat2 is not a param, then mat2 must be a transpose op. Also, + # the input to the transpose must be a param, or alpha must be 1.0. 
+ fit_bias = is_node_with_op(bias, "get_attr") or beta == 1.0 + fit_mat2 = is_node_with_op(mat2, "get_attr") + transposed_mat2 = False + if ( + not fit_mat2 + and is_node_with_op(mat2, "call_function") + and mat2.target == exir_ops.edge.aten.transpose_copy.int + ): + mat2, transposed_mat2 = mat2.args[0], True + fit_mat2 = is_node_with_op(mat2, "get_attr") or alpha == 1.0 + + if not fit_bias or not fit_mat2: + continue + + # Multiply bias by beta + if beta != 1.0: + assert is_node_with_op(bias, "get_attr") + bias_tensor = get_tensor_from_attr(graph_module, bias) + assert isinstance(bias_tensor, torch.Tensor) + bias_tensor = beta * bias_tensor + with graph.inserting_before(node): + bias_name = f"_bias_addmm_to_linear_{self.counter}" + graph_module.register_buffer(bias_name, bias_tensor) + bias = graph.get_attr(bias_name) + + # Use associativity of scalar multiplication, and multiply alpha to mat2 + if is_node_with_op(mat2, "get_attr"): + mat2_tensor = get_tensor_from_attr(graph_module, mat2) + assert isinstance(mat2_tensor, torch.Tensor) + mat2_tensor = alpha * mat2_tensor + # transpose mat2 + mat2_tensor = mat2_tensor if transposed_mat2 else mat2_tensor.t() + with graph.inserting_before(node): + mat2_name = f"_mat2_addmm_to_linear_{self.counter}" + graph_module.register_buffer(mat2_name, mat2_tensor) + mat2 = graph.get_attr(mat2_name) + + # Construct the linear node + linear_args = (mat1, mat2, bias) + with graph.inserting_before(node): + linear_node = graph.call_function( + exir_ops.edge.aten.linear.default, args=linear_args + ) + linear_node.meta = node.meta + # Replace all the uses of the addmm op with linear op + node.replace_all_uses_with(linear_node) + self.counter += 1 + + graph_module.recompile() + graph_module.graph.eliminate_dead_code() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + self.replace_addmm_with_linear(graph_module) + result = super().call(graph_module) + return result + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class ReplacePermuteWithTransposePass(ExportPass): + """ + Replace permute op with transpose if the permutation is only along + two dimensions. + """ + + def call_operator(self, op, args, kwargs, meta): + if op != exir_ops.edge.aten.permute_copy.default: + return super().call_operator(op, args, kwargs, meta) + + # Get the old dim and new dim order + in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0] + old_dims = tuple(range(in_tensor.dim())) + new_dims = args[1] + + # Compute the number of positions in which the old and new order differ + diff = [od for od, nd in zip(old_dims, new_dims) if od != nd] + + # If the difference is in two dimensions, we can replace this permute op + # with transpose op. + if len(diff) == 2: + new_args = (args[0], diff[0], diff[1]) + return super().call_operator( + exir_ops.edge.aten.transpose_copy.int, new_args, kwargs, meta + ) + + return ( + args[0] if len(diff) == 0 else super().call_operator(op, args, kwargs, meta) + ) + + +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class ReplaceConvolutionOptionalArgsWithConcreteArgsPass(ExportPass): + """ + Replace optional tensors with concrete tensors. Currently, we + replace the optional bias tensor with a zero tensor. 
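+ For example, convolution(x, w, None, ...) becomes
+ convolution(x, w, full([out_channels], 0.0), ...).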
+ """ + + def call_operator(self, op, args, kwargs, meta): + if get_edge_overload_packet(op) != exir_ops.edge.aten.convolution: + return super().call_operator(op, args, kwargs, meta) + + # Check if the bias is already concrete + assert len(args) == 9 + if args[2] is not None: + return super().call_operator(op, args, kwargs, meta) + + # The bias length is the number of out channels. + out_shape = meta["val"].shape + bias_size = out_shape[1] + # Create a zero bias tensor (bias is not a constant tensor, + # so it needs to be the result of a graph operation). + zero_bias = super().call_operator( + exir_ops.edge.aten.full.default, + ([bias_size], 0.0), + {"dtype": torch.float32}, + meta, + ) + + # Replace bias with zero_bias + args = list(args) + args[2] = zero_bias + args = tuple(args) + + return super().call_operator(op, args, kwargs, meta) + + +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class ReplaceRepeatWithCatPass(ExportPass): + """ + Replace repeat op as successive cat ops along different dimensions. + repeat is not supported, so this is an opt_level=0 pass. + """ + + def call_operator(self, op, args, kwargs, meta): + if op != exir_ops.edge.aten.repeat.default: + return super().call_operator(op, args, kwargs, meta) + + # Extract the input tensor, and the repeats from the args + in_tensor = args[0] + repeats = args[1] + + # Glean the shapes of input tensor + in_shape = list( + in_tensor.to_tensor().shape + if isinstance(in_tensor, ProxyValue) + else in_tensor.shape + ) + + # If the size of repeats is more than the dimensionality of the tensor, + # the output of repeat will be a higher-dimensional tensor. We reshape + # the input so that it has the same dimensionality as the output tensor. + diff = len(repeats) - len(in_shape) + assert ( + diff >= 0 + ), "Repeat arg malformed: expected a repeat along each dimension of input tensor" + + if diff > 0: + # Extend the input shape with 1's along the higher dimensions + in_shape = ([1] * diff) + in_shape + # Insert a view op that reshapes the input tensor to have same + # dimensionality as the output tensor. + in_tensor = super().call_operator( + exir_ops.edge.aten.view_copy.default, + (in_tensor, in_shape), + kwargs, + meta, + ) + assert len(repeats) == len(in_shape) + + # Repeat op is nothing but successive cat ops along each dimension. + for dim, repeat in reversed(list(enumerate(repeats))): + # We do not need to do anything if repeat factor is 1 + if repeat == 1: + continue + cat_arg = [in_tensor] * repeat + in_tensor = super().call_operator( + exir_ops.edge.aten.cat.default, (cat_arg, dim), kwargs, meta + ) + + return in_tensor + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class ReplacePadWithCatPass(ExportPass): + """ + Replace constant pad nd op that does padding on outer-most dimension + with Cat(left_padding_constant_tensor, X, right_padding_constant_tensor) + """ + + def call_operator(self, op, args, kwargs, meta): + if op != exir_ops.edge.aten.constant_pad_nd.default: + return super().call_operator(op, args, kwargs, meta) + + assert len(args) >= 2 + input_node, orig_padding = args[:2] + + # if there is no padding, this op will be treated in removal pass. 
+ if not orig_padding: + return super().call_operator(op, args, kwargs, meta) + + value = 0 if len(args) == 2 else args[2] + + arg_shape = input_node.to_tensor().shape + + padding = orig_padding + ([0] * (len(orig_padding) % 2 != 0)) + assert len(padding) >= 2 + (left_padding_size, right_padding_size) = padding[-2:] + # Replace only if constant_pad_nd is along the innermost padding dimension. + if ( + any(x != 0 for x in padding[0:-2]) + or left_padding_size < 0 + or right_padding_size < 0 + ): + return super().call_operator(op, args, kwargs, meta) + + cat_tensors = [] + dim = len(arg_shape) - len(padding) // 2 + # add left_padding + if left_padding_size > 0: + left_padding_shape = ( + arg_shape[:dim] + (left_padding_size,) + arg_shape[dim + 1 :] + ) + left_padding_node = super().call_operator( + torch.ops.aten.full.default, + ( + left_padding_shape, + value, + ), + {"dtype": torch.float32}, + meta, + ) + cat_tensors.append(left_padding_node) + # input_node + cat_tensors.append(input_node) + # right_padding + if right_padding_size > 0: + right_padding_shape = ( + arg_shape[:dim] + (right_padding_size,) + arg_shape[dim + 1 :] + ) + right_padding_node = super().call_operator( + torch.ops.aten.full.default, + ( + right_padding_shape, + value, + ), + {"dtype": torch.float32}, + meta, + ) + cat_tensors.append(right_padding_node) + + assert len(cat_tensors) == 1 + (left_padding_size > 0) + ( + right_padding_size > 0 + ) + + new_args = (cat_tensors, dim) + return super().call_operator( + exir_ops.edge.aten.cat.default, + new_args, + kwargs, + meta, + ) + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class ReplaceConstantPadNdWithSlicePass(ExportPass): + """ + Replace constant pad nd op that does padding on outer-most dimension + with exir_ops slice(left_padding_constant_tensor, X, right_padding_constant_tensor) + """ + + def call_operator(self, op, args, kwargs, meta): + if op != exir_ops.edge.aten.constant_pad_nd.default: + return super().call_operator(op, args, kwargs, meta) + + assert len(args) >= 2 + input_node, orig_padding = args[:2] + + # if there is no padding, this op will be treated in removal pass. + if not orig_padding: + return super().call_operator(op, args, kwargs, meta) + + padding = orig_padding + ([0] * (len(orig_padding) % 2 != 0)) + assert len(padding) >= 2 + (start, diff) = map(neg, padding[-2:]) + # Replace only if constant_pad_nd is along the innermost padding dimension. + if any(x != 0 for x in padding[0:-2]) or start < 0 or diff < 0: + return super().call_operator(op, args, kwargs, meta) + + arg_shape = input_node.to_tensor().shape + dim = len(arg_shape) - len(padding) // 2 + stop = arg_shape[dim] - diff + assert start <= stop + new_args = (input_node, dim, start, stop) + return super().call_operator( + exir_ops.edge.aten.slice.Tensor, + new_args, + kwargs, + meta, + ) + + +# Make that pass runnable standalone at opt level 0. +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class ReplaceAtenConvolutionWithJarvisConvolutionPass(ExportPass): + """ + Replace aten convolution op with jarvis-specific convolution op, since the + aten version is not supported by jarvis. + Also remove convolution stride if the output size along the strided dimension + is 1. We can enable more transformations (e.g., conv -> linear replacement) + for unit-stride convolutions. 
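+ For example, a convolution whose output has size 1 along a dimension with
+ stride 2 behaves identically with stride 1, since only a single window is
+ ever taken along that dimension.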
+ """ + + def call_operator(self, op, args, kwargs, meta): + if get_edge_overload_packet(op) != exir_ops.edge.aten.convolution: + return super().call_operator(op, args, kwargs, meta) + # There must be 9 total args. + assert len(args) == 9 + + # Unpack the args + ( + in_tensor, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) = args + # Currently we only handle conversion to conv1d and conv2d, therefore + # verify that the stride, padding, dilation, and output_padding have + # len <=2. + assert ( + len(stride) == len(padding) == len(dilation) == len(output_padding) == 1 + ) or ( + len(stride) == len(padding) == len(dilation) == len(output_padding) == 2 + ), "Can only map convolution to conv1d and conv2d at present" + + target = ( + exir_ops.edge.cadence.transposed_convolution.default + if transposed + else exir_ops.edge.cadence.convolution.default + ) + + if transposed: + # Flip the height and width dimensions of weight, since we apply a + # gather stencil. Also, the first two dimensions of weight must be + # transposed/interchanged. + # If weight is a ProxyValue, new_weight needs to be the output of a + # graph operation (in this case a transpose_copy op) to be an explicit + # ProxyValue as well. If not, the view op can be done directly on the + # tensor. + transposed_weight = ( + super().call_operator( + exir_ops.edge.aten.transpose_copy.int, + ( + weight, + 0, + 1, + ), + kwargs, + meta, + ) + if isinstance(weight, ProxyValue) + else weight.transpose(0, 1) + ) + + flipped_weight = ( + super().call_operator( + torch.ops.aten.flip.default, + ( + transposed_weight, + [-1] if transposed_weight.to_tensor().dim() == 3 else [-1, -2], + ), + kwargs, + meta, + ) + if isinstance(transposed_weight, ProxyValue) + else ( + transposed_weight.flip(-1) + if transposed_weight.dim() == 3 + else transposed_weight.flip(-1, -2) + ) + ) + + # From the previous checks, if flipped_weight is a FakeTensor, it has to be + # a constant (if not, it would be a ProxyValue). Mark it as such. + if isinstance(flipped_weight, FakeTensor): + flipped_weight.constant = flipped_weight + new_args = ( + in_tensor, + flipped_weight, + bias, + stride, + padding, + dilation, + output_padding, + groups, + False, + ) + else: + # Verify that output_padding is 0. + assert all( + x == 0 for x in output_padding + ), "Cannot handle padded output in convolution" + + # If the innermost dim of output tensor is 1, then the stride + # should be 1. Note that the first dimension of output tensor is + # channel + new_stride = stride.copy() + out_shape = meta["val"].shape + assert out_shape is not None + for i, e in enumerate(out_shape[2:]): + new_stride[i] = 1 if e == 1 else stride[i] + + new_args = ( + in_tensor, + weight, + bias, + new_stride, + padding, + dilation, + groups, + False, + ) + + return super().call_operator(target, new_args, kwargs, meta) + + +# TODO(matthiascremon): this is a fuse op, not a replace op +class ReplaceConvWithChannelLastConv: + """ + Convolution op in pytorch expects NCHW layout for input, weight, and output + tensors. However, if the input and output to the convolution op are originally + in NWHC layout, and are then permuted to conform to NCHW layout, we can fuse + the two permute ops with the convolution op, and call the NHWC layout + convolution op in Jarvis. 
+ """ + + def __init__(self): + self.counter = 0 + self.graph_module = None + + def __call__(self, graph_module: torch.fx.GraphModule): + self.replace_conv_with_nhwc_conv(graph_module) + + def conv_layout_is_nhwc(self, node: torch.fx.Node) -> bool: + """ + Return true if the convolution input and output are connected to permute + ops, and the input/output to/from the permute ops is NHWC layout tensor. + """ + # There must only be a single user of the output node (which must be a + # permute/tranpsose op). The input of the convolution must be connected + # to a permute op, and that permute op should have a single user. + conv_inp = node.args[0] + assert isinstance(conv_inp, torch.fx.Node) + if len(node.users) != 1 or len(conv_inp.users) != 1: + return False + + # Get the input and output (permute/transpose) nodes of the convolution + conv_user = list(node.users.keys())[0] + assert isinstance(conv_user, torch.fx.Node) + pt_nodes: Set[torch.fx.Node] = {conv_inp, conv_user} + + # Any node in pt_nodes must not be a placeholder. + if contains_placeholder_or_param(pt_nodes): + return False + + # Determine if the convolution is 1d or 2d. The output tensor must be + # 3- or 4-dimensional + out_shape = get_shape(self.graph_module, node) + assert out_shape is not None + out_dims = len(out_shape) + assert out_dims in {3, 4}, "Jarvis only supports conv1d and conv2d" + conv1d = out_dims == 3 + + # Get the possible targets for the nodes in pt_nodes. Since conv1d has + # 3-dimensional input and output tensors, the nodes in pt_nodes could + # be either permute or transpose op. For conv2d, the nodes in pt_nodes + # must be permute ops. + p_target = exir_ops.edge.aten.permute_copy.default + t_target = exir_ops.edge.aten.transpose_copy.int + pt_targets = [p_target] + ([t_target] if conv1d else []) + + # If any node in pt_nodes is not permute op (or tranpose op for conv1d), + # bail. + if any(x.target not in pt_targets for x in pt_nodes): + return False + + # Now we need to determine the dimension permutations: + # If the input had NHWC layout, which was then permuted/transposed + # by a permute/transpose op to NCHW layout, the permutation must be + # [0, 3, 2, 1] (or [0, 2, 1] for conv1d). + # If the output had NCHW layout, and was then permuted to NHWC layout, + # the permutation must be [0, 2, 3, 1] (or [0, 2, 1] for conv1d). + nhwc_permute_order = { + node.args[0]: [0, 2, 1] if conv1d else [0, 3, 1, 2], + list(node.users.keys())[0]: [0, 2, 1] if conv1d else [0, 2, 3, 1], + } + for x in pt_nodes: + order = ( + x.args[1] + if x.target == p_target + else get_transposed_dims(x, list(range(out_dims))) + ) + if order != nhwc_permute_order[x]: + return False + + return True + + def replace_conv_with_nhwc_conv(self, graph_module: torch.fx.GraphModule): + self.graph_module = graph_module + graph = graph_module.graph + for node in graph.nodes: + # We are only interested in convolution nodes that have NHWC layout + if node.target not in { + exir_ops.edge.cadence.quantized_conv.default, + exir_ops.edge.cadence.convolution.default, + exir_ops.edge.cadence.quantized_transposed_conv.default, + exir_ops.edge.cadence.transposed_convolution.default, + } or not self.conv_layout_is_nhwc(node): + continue + + # Get the args of convolution op + args = list(node.args) + # The input is connected to a permute/transpose op that converts the + # NHWC layout to NCHW layout. The input of the permute op will become + # this convolution op's input. + in_tp = args[0] + args[0] = in_tp.args[0] + # The weight is in NHWC layout. 
Permute it to NHWC layout. + weight_tensor = get_tensor_from_attr(graph_module, args[1]) + assert isinstance(weight_tensor, torch.Tensor) + # We cannot directly permute a per-channel quantized tensor. We will + # dequantize it, permute the fp32 tensor, and then requantize the + # permuted tensor. + if ( + is_quantized_tensor(weight_tensor) + and weight_tensor.qscheme() == torch.per_channel_affine + ): + # We have already asserted during quantizing conv op that the + # quantization axis is 0. + dequant_weight = weight_tensor.dequantize() + dequant_weight = ( + dequant_weight.permute([0, 2, 1]) + if dequant_weight.dim() == 3 + else dequant_weight.permute([0, 2, 3, 1]) + ) + weight_tensor = torch.quantize_per_channel( + dequant_weight.contiguous(), + weight_tensor.q_per_channel_scales(), + weight_tensor.q_per_channel_zero_points(), + 0, + weight_tensor.dtype, + ) + else: + weight_tensor = ( + weight_tensor.permute([0, 2, 1]) + if weight_tensor.dim() == 3 + else weight_tensor.permute([0, 2, 3, 1]) + ) + # Make the weight tensor contiguous, since we have permuted it. + weight_tensor = weight_tensor.contiguous() + # Add the permuted weight into the graph, and update the weight in + # args. + with graph.inserting_before(node): + weight_name = f"_weight_nhwc_{self.counter}" + graph_module.register_buffer(weight_name, weight_tensor) + weight = graph.get_attr(weight_name) + args[1] = weight + + # The 'channel_last' arg is True. It is the last arg. + args[-1] = True + # Now update the convolution node args to mark it as NHWC convolution + node.args = tuple(args) + + # Replace all the uses of the permute op connected to the output op + # with this convolution. + out_tp = list(node.users.keys())[0] + out_tp.replace_all_uses_with(node) + node.meta = out_tp.meta + + # Erase the permute ops connected to the input and output of the + # convolution op. + graph.erase_node(in_tp) + graph.erase_node(out_tp) + self.counter += 1 + + graph_module.recompile() + + +# This pass needs to be reworked to be compatible with PT2. It is an optimization +# pass anyway, so move it to opt level 2. +# TODO(matthiascremon): update and improve this pass. +@register_cadence_pass(CadencePassAttribute(opt_level=2)) +class ReplaceConvWithChannelLastConvPass(ExportPass): + """ + Replace the ATen convolution op with custom conv op with NCHW or NHWC layout + input tensors, depending on the presence of permute/transpose ops connected + to the input tensor. + """ + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + result = ReplaceAtenConvolutionWithJarvisConvolutionPass()(graph_module) + assert result is not None + ReplaceConvWithChannelLastConv()(result.graph_module) + return result + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class ReplaceTrivialConvWithLinear(ExportPass): + """ + In nn.Conv1d, the operand shapes are: + input - [batch, in_channels, in_length] + weight - [out_channels, in_channels, weight_length] + output - [batch, out_channels, out_length] + When in_length == weight_length, out_length = 1. In this scenario, we can + view the input as a tensor shaped [batch, K], and weight as a tensor + shaped [out_channels, K], and replace nn.Conv1d with nn.Linear. This + optimization can be extended to nn.Conv2d as well, where in_length is a 2d + image, and weight_length can be replaced with a 2d filter the same shape as + the image. 
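+ For example, input [B, C, L] convolved with weight [O, C, L] (groups=1,
+ no padding, unit stride/dilation) produces output [B, O, 1], which equals
+ linear(input.view(B, C*L), weight.view(O, C*L), bias).view(B, O, 1).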
+ """ + + trivial_conv_op_to_linear_op: Dict[EdgeOpOverload, EdgeOpOverload] = { + exir_ops.edge.cadence.convolution.default: exir_ops.edge.aten.linear.default, + exir_ops.edge.cadence.quantized_conv.default: exir_ops.edge.cadence.quantized_linear.default, + } + + def call_operator(self, op, args, kwargs, meta): + if op not in self.trivial_conv_op_to_linear_op: + return super().call_operator(op, args, kwargs, meta) + + # Parse the necessary args of the convolution node. Both convolution + # and quantized_conv have the same first 8 args. The quantized op has + # extra args holding at least the zero point and scale of input, weight, bias, + # and output tensor. + quantized_op = op == exir_ops.edge.cadence.quantized_conv.default + assert (len(args) == 8 and not quantized_op) or ( + len(args) >= 12 and quantized_op + ), "Inconsistent args for convolution" + (in_tensor, weight, bias, stride, padding, dilation, groups) = args[0:7] + + # Glean the shapes of input, weight, and output + in_shape = ( + in_tensor.to_tensor().shape + if isinstance(in_tensor, ProxyValue) + else in_tensor.shape + ) + + weight_shape = ( + weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape + ) + out_shape = meta["val"].shape + assert None not in {in_shape, weight_shape, out_shape} + + # Check the condition under which conv can be replaced by linear: (1) this + # should not be a depthwise convolution; (2) the padding, stride, and dilation + # should be standard; (3) The [channels, height, width] of input must match the + # [channel, kernel_height, kernel_width] of the weight. These conditions would + # ensure that output height and width are 1, and the convolution can be replaced + # by linear. + if ( + groups != 1 + or any(x != 0 for x in padding) + or any(x != 1 for x in stride) + or any(x != 1 for x in dilation) + or (list(in_shape[1:]) != list(weight_shape[1:])) + ): + return super().call_operator(op, args, kwargs, meta) + + # Reshape the weight to [out_channels, in_channels * X] + K = math.prod(weight_shape[1:]) + + # If weight is a ProxyValue, linear_weight needs to be the output of a + # graph operation (in this case a view_copy op) to be an explicit ProxyValue + # as well. If not, the view op can be done directly on the tensor. + linear_weight = ( + super().call_operator( + exir_ops.edge.aten.view_copy.default, + ( + weight, + [weight_shape[0], K], + ), + kwargs, + meta, + ) + if isinstance(weight, ProxyValue) + else weight.contiguous().view(weight_shape[0], K) + ) + # From the previous check, if linear_weight is a FakeTensor, it has to be + # a constant (if not, it would be a ProxyValue). Mark it as such. + if isinstance(linear_weight, FakeTensor): + linear_weight.constant = linear_weight + + # Reshape the input from 3d to 2d tensor + in_view = super().call_operator( + exir_ops.edge.aten.view_copy.default, + ( + in_tensor, + [in_shape[0], K], + ), + kwargs, + meta, + ) + # Create the linear node, which multiplies the 2d input and weight + # tensors, and adds the 1d bias to produce a 2d output. + if quantized_op: + ( + in_zero_point, + weight_zero_point, + bias_scale, + out_scale, + out_zero_point, + ) = args[7:12] + # If the multiplier and shift tensors are provided, use them. + if ( + len(args) >= 14 + and isinstance(args[12], ProxyValue) + and isinstance(args[13], ProxyValue) + ): + out_multiplier = args[12] + out_shift = args[13] + # If not, compute them. 
+ else: + requantize_scale = bias_scale / out_scale + (out_multiplier, out_shift) = quantize_tensor_multiplier( + requantize_scale + ) + linear_args = ( + in_view, + linear_weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + None, + ) + else: + linear_args = (in_view, linear_weight, bias) + + linear_res = super().call_operator( + self.trivial_conv_op_to_linear_op[op], + linear_args, + kwargs, + meta, + ) + # Reshape the output of linear from 2d to 3d tensor + out_res = super().call_operator( + exir_ops.edge.aten.view_copy.default, + (linear_res, list(out_shape)), + kwargs, + meta, + ) + return out_res + + +def canonicalize_transposed_dim(dim: int, shape: Sequence[int]) -> int: + """Canonicalize transpose ops so it gets easier to pattern-match and fuse transpose ops.""" + if dim < 0: + # Keep transpose dimensions positive. + dim += len(shape) + return dim + + +class ExportPassWithTransposeHelper(ExportPass): + def transpose_dims( + self: ExportPass, proxy: ProxyValue, meta: NodeMetadata, dim0: int, dim1: int + ) -> ProxyValue: + """Helper function to transpose dims of a `proxy` with given `meta`.""" + shape = proxy.data.shape + dim0, dim1 = ( + canonicalize_transposed_dim(dim0, shape), + canonicalize_transposed_dim(dim1, shape), + ) + dim0, dim1 = min(dim0, dim1), max(dim0, dim1) + return super().call_operator( + exir_ops.edge.aten.transpose_copy.int, (proxy, dim0, dim1), {}, meta + ) + + +@register_cadence_pass(CadencePassAttribute(opt_level=3)) +class ForceChannelLastForConvPass(ExportPassWithTransposeHelper): + def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue: + shape = proxy.to_tensor().shape + if len(shape) == 3: + return self.transpose_dims(proxy, meta, 1, -1) + indices = list(range(len(shape))) + permute_indices = [indices[0]] + indices[2:] + [indices[1]] + return super().call_operator( + exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta + ) + + def change_nhwc_to_nchw(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue: + shape = proxy.to_tensor().shape + if len(shape) == 3: + return self.transpose_dims(proxy, meta, 1, -1) + indices = list(range(len(shape))) + permute_indices = [indices[0], indices[-1]] + indices[1:-1] + return super().call_operator( + exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta + ) + + def call_operator( + self, + op, + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op not in { + exir_ops.edge.cadence.convolution.default, + exir_ops.edge.cadence.quantized_conv.default, + }: + return super().call_operator(op, args, kwargs, meta) + + quantized_op = op == exir_ops.edge.cadence.quantized_conv.default + channel_last_arg_index = 14 if quantized_op else 7 + channel_last = ( + args[channel_last_arg_index] + if len(args) > channel_last_arg_index + # Default is false (NCHW). + else False + ) + if channel_last: + return super().call_operator(op, args, kwargs, meta) + + input_proxy = cast(ProxyValue, args[0]) + weight_proxy = cast(ProxyValue, args[1]) + input_proxy = self.change_nchw_to_nhwc(input_proxy, meta) + weight_proxy = self.change_nchw_to_nhwc(weight_proxy, meta) + + new_args = ( + # Transposed input/weights. + (input_proxy, weight_proxy) + # All other args (bias, quant params, etc) + + tuple(args[2:channel_last_arg_index]) + # Channel last. 
+ + (True,) + ) + output_proxy = super().call_operator(op, new_args, kwargs, meta) + nchw_proxy = self.change_nhwc_to_nchw(output_proxy, meta) + return nchw_proxy + + +@register_cadence_pass(CadencePassAttribute(opt_level=3)) +class MakeSliceAndCatDimOutermostPass(ExportPassWithTransposeHelper): + def call_operator( + self, + op, + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op not in { + exir_ops.edge.aten.cat.default, + exir_ops.edge.aten.slice_copy.Tensor, + }: + return super().call_operator(op, args, kwargs, meta) + dim = cast(int, args[1]) if len(args) > 1 else 0 + output_shape = meta["val"].shape + if dim < 0: + # Keep dim positive. + dim += len(output_shape) + + if dim == 0 or math.prod(output_shape[:dim]) == 1: + # Not needed if dim is already outermost or all dims before it are 1. + return super().call_operator(op, (args[0], dim) + args[2:], kwargs, meta) + + if op == exir_ops.edge.aten.slice_copy.Tensor: + # Transpose -> slice. + slice_args = ( + self.transpose_dims(cast(ProxyValue, args[0]), meta, dim, 0), + 0, + ) + args[2:] + new_op = super().call_operator(op, slice_args, kwargs, meta) + else: + # (Transpose input0, Transpose input1, ...) -> cat. + cat_in_tensors = [ + self.transpose_dims(t, meta, dim, 0) + for t in cast(list[ProxyValue], args[0]) + ] + new_op = super().call_operator(op, (cat_in_tensors, 0), kwargs, meta) + # slice/cat -> transpose. + return self.transpose_dims(new_op, meta, 0, dim) + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class ReplaceConvWithIm2RowAndLinear(ExportPass): + """ + Replace convolution where groups=1 with im2row followed by a linear op. + """ + + # A map from the convolution op to the linear op that it should + # decompose to. + conv_op_to_linear_op: Dict[EdgeOpOverload, EdgeOpOverload] = { + exir_ops.edge.cadence.convolution.default: exir_ops.edge.aten.linear.default, + exir_ops.edge.cadence.quantized_conv.default: exir_ops.edge.cadence.quantized_linear.default, + } + + def call_operator(self, op, args, kwargs, meta): + if op not in self.conv_op_to_linear_op: + return super().call_operator(op, args, kwargs, meta) + + # Get the relevant args from convolution node. + quantized_op = op == exir_ops.edge.cadence.quantized_conv.default + assert (len(args) == 8 and not quantized_op) or ( + len(args) >= 12 and quantized_op + ), "Inconsistent args for convolution" + (in_tensor, weight, bias, stride, padding, dilation, groups) = args[0:7] + + # We do not replace depthwise convolution with gemm yet. + if groups != 1: + return super().call_operator(op, args, kwargs, meta) + + weight_shape = ( + weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape + ) + # If this is a pointwise convolution, im2col will start dominating the + # runtime. So we call convolution op for this case. + if ( + all(x == 1 for x in weight_shape[2:]) + and all(x == 1 for x in stride) + and all(x == 0 for x in padding) + and all(x == 1 for x in dilation) + ): + return super().call_operator(op, args, kwargs, meta) + + # Get the shapes + out_shape = meta["val"].shape + assert None not in {weight_shape, out_shape} + + # Determine if the convolution is NCHW or NHWC. The NHWC, i.e., the + # channel_last layout is specified by the channel_last arg of conv + # op, which is either the last argument (15th) or implicitely False + # if the op is quantized, or the last argument if not. 
+ channel_last = ( + (args[14] if len(args) == 15 else False) if quantized_op else args[-1] + ) + # The weight tensor is [out_channels, in_channels, X] for NCHW layout, + # and [out_channels, X, in_channels] for NHWC layout. Here, X is the + # kernel_width for conv1d, and X = kernel_height * kernel_width for + # conv2d. We extract X as the kernel_size for im2row. + kernel_size = list(weight_shape[1:-1] if channel_last else weight_shape[2:]) + # If the convolution op was quantized, we need the input tensor's + # zero_point for im2row. Otherwise in_zero_point defaults to a zero + # tensor. + in_zero_point = ( + ( + super().call_operator( + exir_ops.edge.aten.full.default, + ( + [1], + args[7], + ), + {"dtype": torch.int32}, + meta, + ) + if isinstance(in_tensor.to_tensor(), FakeTensor) + else get_zero_point(in_tensor.to_tensor()) + ) + if quantized_op + else torch.tensor(0, dtype=torch.int32) + ) + # im2row expects every kernel parameter to be 2d. So we extend the + # parameters for conv1d by prepending their default values. + stride = ([1] + stride) if len(stride) == 1 else stride + padding = ([0] + padding) if len(padding) == 1 else padding + dilation = ([1] + dilation) if len(dilation) == 1 else dilation + kernel_size = ([1] + kernel_size) if len(kernel_size) == 1 else kernel_size + # Assert that kernel size does not have a 0 + assert 0 not in kernel_size + + # Create an im2row node with the input. This will create a 2d matrix of + # shape [out_height*out_weight, X*in_channels]. X is as defined in the + # comment above. + im2row_args = ( + in_tensor, + kernel_size, + dilation, + padding, + stride, + in_zero_point, + channel_last, + ) + im2row = super().call_operator( + exir_ops.edge.cadence.im2row.default, + im2row_args, + kwargs, + meta, + ) + + # Get the product of the >2 dims of the weight + K = math.prod(weight_shape[1:]) + + # If weight is a ProxyValue, linear_weight needs to be the output of a + # graph operation (in this case a view_copy op) to be an explicit ProxyValue + # as well. If not, the view op can be done directly on the tensor. + linear_weight = ( + super().call_operator( + exir_ops.edge.aten.view_copy.default, + ( + weight, + [weight_shape[0], K], + ), + kwargs, + meta, + ) + if isinstance(weight, ProxyValue) + else weight.contiguous().view(weight_shape[0], K) + ) + # From the previous check, if linear_weight is a FakeTensor, it has to be + # a constant (if not, it would be a ProxyValue). Mark it as such. + if isinstance(linear_weight, FakeTensor): + linear_weight.constant = linear_weight + + # Create the linear node, which multiplies the 3d input with 2d weight + # tensors with bias addition. The outermost dimension of the input is + # the batch size for linear op. + if quantized_op: + ( + in_zero_point, + weight_zero_point, + bias_scale, + out_scale, + out_zero_point, + ) = args[7:12] + # If the multiplier and shift tensors are provided, use them. + if ( + len(args) >= 14 + and isinstance(args[12], ProxyValue) + and isinstance(args[13], ProxyValue) + ): + out_multiplier = args[12] + out_shift = args[13] + # If not, compute them. 
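+            # (quantize_tensor_multiplier expresses the float requantize_scale as an
+            # integer out_multiplier and out_shift pair, as used by quantized_linear.)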
+ else: + requantize_scale = bias_scale / out_scale + (out_multiplier, out_shift) = quantize_tensor_multiplier( + requantize_scale + ) + linear_args = ( + im2row, + linear_weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + None, + ) + else: + linear_args = (im2row, linear_weight, bias) + linear_res = super().call_operator( + self.conv_op_to_linear_op[op], + linear_args, + kwargs, + meta, + ) + # The output of linear is a 3D tensor. However, the output is in NHWC + # layout by default, because an input vector of size X is multiplied + # with the weight matrix, i.e., column values are contiguous. If the + # channel_last is False, we want to transpose this output. + if not channel_last: + linear_res = super().call_operator( + exir_ops.edge.aten.transpose_copy.int, + (linear_res, 1, 2), + kwargs, + meta, + ) + # And finally, we want to view the 3D output of linear op as 4D tensor + return super().call_operator( + exir_ops.edge.aten.view_copy.default, + (linear_res, list(out_shape)), + kwargs, + meta, + ) + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class ReplaceTransposedConvWithLinearPass(ExportPass): + """ + Replace transposed convolution where groups=1 with transposed_im2row + followed by a linear op. + """ + + # A map from the transposed_convolution op to the linear op that it should + # decompose to. + transposed_conv_op_to_linear_op: Dict[EdgeOpOverload, EdgeOpOverload] = { + exir_ops.edge.cadence.transposed_convolution.default: exir_ops.edge.aten.linear.default, + exir_ops.edge.cadence.quantized_transposed_conv.default: exir_ops.edge.cadence.quantized_linear.default, + } + + def call_operator(self, op, args, kwargs, meta): + if op not in self.transposed_conv_op_to_linear_op: + return super().call_operator(op, args, kwargs, meta) + + # Get the relevant args from transposed_convolution node. + quantized_op = op == exir_ops.edge.cadence.quantized_transposed_conv.default + assert len(args) == ( + 16 if quantized_op else 9 + ), "Inconsistent args for transposed_convolution" + ( + in_tensor, + weight, + bias, + stride, + padding, + dilation, + output_padding, + groups, + ) = args[0:8] + + # We do not replace depthwise transposed_convolution with gemm yet. + if groups != 1: + return super().call_operator(op, args, kwargs, meta) + + # Get the shapes + out_shape = meta["val"].shape + weight_shape = ( + weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape + ) + assert None not in {weight_shape, out_shape} + + # Determine if the transposed_convolution is NCHW or NHWC. The NHWC, + # i.e., the channel_last layout is specified by the channel_last arg + # of transposed_conv op, which is the last argument. + channel_last = args[-1] + # The weight tensor is [out_channels, in_channels, X] for NCHW layout, + # and [out_channels, X, in_channels] for NHWC layout. Here, X is the + # kernel_width for conv1d, and X = kernel_height * kernel_width for + # conv2d. We extract X as the kernel_size for im2row. + kernel_size = list(weight_shape[1:-1] if channel_last else weight_shape[2:]) + # If the transposed_convolution op was quantized, we need the input tensor's + # zero_point for im2row. Otherwise in_zero_point defaults to a zero + # tensor. + in_zero_point = ( + get_zero_point(in_tensor.to_tensor()) + if quantized_op + else torch.tensor(0, dtype=torch.int32) + ) + # transposed_im2row expects every kernel parameter to be 2d. So we extend the + # parameters for conv1d by prepending their default values. 
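+        # For example, a conv1d stride of [2] becomes [1, 2], a padding of [3]
+        # becomes [0, 3], and a dilation of [1] becomes [1, 1].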
+ stride = ([1] + stride) if len(stride) == 1 else stride + padding = ([0] + padding) if len(padding) == 1 else padding + dilation = ([1] + dilation) if len(dilation) == 1 else dilation + output_padding = ( + ([0] + output_padding) if len(output_padding) == 1 else output_padding + ) + kernel_size = ([1] + kernel_size) if len(kernel_size) == 1 else kernel_size + # Assert that kernel size does not have a 0 + assert 0 not in kernel_size + + # Create a transposed_im2row node with the input. This will create a 2d + # matrix of shape [out_height*out_weight, X*in_channels]. X is as + # defined in the comment above. + transposed_im2row_args = ( + in_tensor, + kernel_size, + dilation, + padding, + stride, + output_padding, + in_zero_point, + channel_last, + ) + transposed_im2row = super().call_operator( + exir_ops.edge.cadence.transposed_im2row.default, + transposed_im2row_args, + kwargs, + meta, + ) + # Reshape the weight to [out_channels, in_channels * X] + K = math.prod(weight_shape[1:]) + + # If weight is a ProxyValue, linear_weight needs to be the output of a + # graph operation (in this case a view_copy op) to be an explicit ProxyValue + # as well. If not, the view op can be done directly on the tensor. + linear_weight = ( + super().call_operator( + exir_ops.edge.aten.view_copy.default, + ( + weight, + [weight_shape[0], K], + ), + kwargs, + meta, + ) + if isinstance(weight, ProxyValue) + else weight.contiguous().view(weight_shape[0], K) + ) + # From the previous check, if linear_weight is a FakeTensor, it has to be + # a constant (if not, it would be a ProxyValue). Mark it as such. + if isinstance(linear_weight, FakeTensor): + linear_weight.constant = linear_weight + + # Create the linear node, which multiplies the 3d input with 2d weight + # tensors with bias addition. The outermost dimension of the input is + # the batch size for linear op. + if quantized_op: + ( + in_zero_point, + weight_zero_point, + bias_scale, + out_scale, + out_zero_point, + ) = args[8:13] + requantize_scale = bias_scale / out_scale + (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale) + linear_args = ( + transposed_im2row, + linear_weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + None, + ) + else: + linear_args = (transposed_im2row, linear_weight, bias) + linear_res = super().call_operator( + self.transposed_conv_op_to_linear_op[op], + linear_args, + kwargs, + meta, + ) + # The output of linear is a 3D tensor. However, the output is in NHWC + # layout by default, because an input vector of size X is multiplied + # with the weight matrix, i.e., column values are contiguous. If the + # channel_last is False, we want to transpose this output. + if not channel_last: + linear_res = super().call_operator( + exir_ops.edge.aten.transpose_copy.int, + (linear_res, 1, 2), + kwargs, + meta, + ) + # And finally, we want to view the 3D output of linear op as 4D tensor + return super().call_operator( + exir_ops.edge.aten.view_copy.default, + (linear_res, list(out_shape)), + kwargs, + meta, + ) + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class ReplaceNopTransposeOrPermuteWithViewPass(ExportPass): + """ + If the transpose/permute op does not change the byte order (e.g., + transpose/permute from Nx1xHxW to NxHx1xW), then it can be replaced + by view op. + """ + + def call_operator(self, op, args, kwargs, meta): + # Only proceed for transpose or permute op. 
+ if op not in { + exir_ops.edge.aten.transpose_copy.int, + exir_ops.edge.aten.permute_copy.default, + }: + return super().call_operator(op, args, kwargs, meta) + + # Get the input tensor and shape + in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0] + in_shape = in_tensor.shape + # Get the output tensor shape + out_shape = meta["val"].shape + + if op == exir_ops.edge.aten.transpose_copy.int: + # Get the two dims to be transposed + dim0 = args[1] if args[1] >= 0 else in_tensor.dim() + args[1] + dim1 = args[2] if args[2] >= 0 else in_tensor.dim() + args[2] + # We can eliminate transpose if (a) the size at dim0 and dim1 is 1; + # (b) the size at dim0 or dim1 is 1, and dim0 and dim1 are consecutive. + both_one = in_shape[dim0] == 1 and in_shape[dim1] == 1 + either_one_and_consecutive = abs(dim0 - dim1) == 1 and ( + in_shape[dim0] == 1 or in_shape[dim1] == 1 + ) + if both_one or either_one_and_consecutive: + new_args = (args[0], list(out_shape)) + return super().call_operator( + exir_ops.edge.aten.view_copy.default, new_args, kwargs, meta + ) + + elif op == exir_ops.edge.aten.permute_copy.default: + old_dims = list(range(in_tensor.dim())) + new_dims = args[1] + # If the permute does not change anything, return the input as output. + if old_dims == new_dims: + return args[0] + # Get the old dim order, and the permuted dim order for all dims that + # are not 1. + old_order = [ + dim for dim, shape_dim in zip(old_dims, in_shape) if shape_dim != 1 + ] + new_order = [ + dim for dim, shape_dim in zip(new_dims, out_shape) if shape_dim != 1 + ] + # If the byte ordering for non-unit dims is unchanged, this is a nop. + if old_order == new_order: + new_args = (args[0], list(out_shape)) + return super().call_operator( + exir_ops.edge.aten.view_copy.default, new_args, kwargs, meta + ) + + return super().call_operator(op, args, kwargs, meta) + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + result = super().call(graph_module) + result = FuseCascadedViewOps()(result.graph_module) + assert result is not None + return result + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class ReplaceLinearWithFullyConnectedOpPass(ExportPass): + """ + If the input of linear/quantized_linear op is a vector, replace it with + fully_connected op. + """ + + linear_to_fc_op: Dict[EdgeOpOverload, EdgeOpOverload] = { + exir_ops.edge.aten.linear.default: exir_ops.edge.cadence.fully_connected.default, + exir_ops.edge.cadence.quantized_linear.default: exir_ops.edge.cadence.quantized_fully_connected.default, + } + + def call_operator(self, op, args, kwargs, meta): + # Only proceed for linear or quantized_linear ops. + if op not in self.linear_to_fc_op: + return super().call_operator(op, args, kwargs, meta) + + # Extract the input tensor + in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0] + leading_dims = math.prod(in_tensor.shape[:-1]) + # If the tensor is not a vector, do nothing. + if leading_dims != 1: + return super().call_operator(op, args, kwargs, meta) + + # If the op is quantized::linear, but per-channel quantized, bail. 
+ if op == exir_ops.edge.cadence.quantized_linear.default: + weight = args[1].to_tensor() if isinstance(args[1], ProxyValue) else args[1] + if weight.shape != [1]: + return super().call_operator(op, args, kwargs, meta) + + # Replace the linear with fully connected op + return super().call_operator( + self.linear_to_fc_op[op], + args, + kwargs, + meta, + ) + + +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class ReplaceScalarWithTensorArgPass(ExportPass): + """ + For binary ops like add.Scalar, sub.Scalar mul.Scalar, and div.Scalar, + replace the scalar arg with Tensor arg. + """ + + scalar_to_tensor_ops: Dict[EdgeOpOverload, EdgeOpOverload] = { + exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor, + } + + def get_replacement(self, op, args, kwargs, meta): + return super().call_operator( + # Replace with .Tensor variant. + op=self.scalar_to_tensor_ops[op], + args=( + # Tensor arg. + args[0], + # Scalar arg - replace with aten.full tensor. + super().call_operator( + exir_ops.edge.aten.full.default, + args=( + (1,), + args[1], + ), + kwargs={"dtype": args[0].to_tensor().dtype}, + meta=meta, + ), + # Other args. + *args[2:], + ), + kwargs=kwargs, + meta=meta, + ) + + def call_operator(self, op, args, kwargs, meta): + if op not in self.scalar_to_tensor_ops: + return super().call_operator(op, args, kwargs, meta) + + # There must be exactly 2 args (3 for add and sub containing alpha) + assert len(args) == 2 or len(args) == 3 + + # If there are two args, just replace the op. + if len(args) == 2: + return self.get_replacement(op, args, kwargs, meta) + + # In case the op has three args, it must be scalar add/sub op. + if ( + op not in {exir_ops.edge.aten.add.Scalar, exir_ops.edge.aten.sub.Scalar} + or "alpha" in kwargs + ): + return super().call_operator(op, args, kwargs, meta) + + return self.get_replacement(op, args, kwargs, meta) + + +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class ReplaceScalarTensorWithFullPass(ExportPass): + """ + aten.scalar_tensor can be replaced by aten.full with a shape of [1]. + scalar_tensor is not supported, so this is an opt_level=0 pass. + """ + + def call_operator( + self, + op, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op not in { + exir_ops.edge.aten.scalar_tensor.default, + torch.ops.aten.scalar_tensor.default, + }: + return super().call_operator(op, args, kwargs, meta) + + return super().call_operator( + exir_ops.edge.aten.full.default, + ( + [1], + args[0], + ), + {"dtype": torch.float32}, + meta, + ) + + +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class ReplaceFullLikeWithFullPass(ExportPass): + """ + aten.full_like can be replaced by aten.full with the shape of the arg tensor. + full_like is not supported, so this is an opt_level=0 pass. + """ + + def call_operator(self, op, args, kwargs, meta): + if op not in { + exir_ops.edge.aten.full_like.default, + }: + return super().call_operator(op, args, kwargs, meta) + + # Get the shape of the "like" tensor, and pass that in to the full op. 
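+        # For example, full_like(x, 2.0) with x of shape [2, 3] becomes full([2, 3], 2.0).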
+        return super().call_operator(
+            exir_ops.edge.aten.full.default,
+            (
+                (
+                    args[0].to_tensor().shape
+                    if isinstance(args[0], ProxyValue)
+                    else args[0].shape
+                ),
+                args[1],
+            ),
+            {},
+            meta,
+        )
+
+
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class ReplaceInfArgInFullWithValuePass(ExportPass):
+    """
+    aten.full allows "-inf" and "inf" as fill values. The profiler cannot
+    handle those, so replace them with the minimum/maximum finite value
+    of the type.
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in {
+            exir_ops.edge.aten.full.default,
+        }:
+            return super().call_operator(op, args, kwargs, meta)
+
+        new_args = list(args)
+
+        if args[1] == float("-inf"):
+            new_args[1] = torch.finfo(torch.float32).min
+        elif args[1] == float("inf"):
+            new_args[1] = torch.finfo(torch.float32).max
+
+        return super().call_operator(op, tuple(new_args), kwargs, meta)
+
+
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass(ExportPass):
+    """
+    Replace the aten.linalg_vector_norm op with a custom op.
+    aten.linalg_vector_norm is not supported by Jarvis, so we
+    need to replace it with the custom cadence.linalg_vector_norm op
+    at all optimization levels.
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op != exir_ops.edge.aten.linalg_vector_norm.default:
+            return super().call_operator(op, args, kwargs, meta)
+
+        assert (
+            len(args) == 1
+        ), "aten.linalg_vector_norm should have 1 argument (a tensor); we do not support any custom variants"
+
+        return super().call_operator(
+            exir_ops.edge.cadence.linalg_vector_norm.default,
+            args,
+            kwargs,
+            meta,
+        )
+
+
+@register_cadence_pass(CadencePassAttribute(opt_level=1))
+class ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass(ExportPass):
+    """
+    Replace ops that take single-element tensor arguments (size = [1]) with
+    overloads that accept scalar ints/floats.
+    """
+
+    # Keep track of which operators and arguments are being replaced.
+    replaced_scalar_args: dict[
+        EdgeOpOverloadPacket, tuple[EdgeOpOverload, Sequence[int]]
+    ] = {
+        exir_ops.edge.cadence.quantized_conv: (
+            exir_ops.edge.cadence.quantized_conv.per_tensor,
+            [8, 9, 12, 13],
+        ),
+        exir_ops.edge.cadence.quantized_layer_norm: (
+            exir_ops.edge.cadence.quantized_layer_norm.per_tensor,
+            [1, 2],
+        ),
+        exir_ops.edge.cadence.quantized_linear: (
+            exir_ops.edge.cadence.quantized_linear.per_tensor,
+            [4, 5, 6],
+        ),
+        exir_ops.edge.cadence.quantized_relu: (
+            exir_ops.edge.cadence.quantized_relu.per_tensor,
+            [1, 3, 4],
+        ),
+    }
+
+    def call_operator(self, op, args, kwargs, meta):
+        op_edge_overload_packet = get_edge_overload_packet(op)
+
+        if op_edge_overload_packet not in self.replaced_scalar_args:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Get all the args that need to be replaced.
+        new_op, args_to_be_replaced = self.replaced_scalar_args[op_edge_overload_packet]
+
+        updated_args = list(args)
+        for op_arg_index in args_to_be_replaced:
+            arg = args[op_arg_index]
+            if not isinstance(arg, ProxyValue):
+                return super().call_operator(op, args, kwargs, meta)
+
+            if not arg.is_tensor():
+                return super().call_operator(op, args, kwargs, meta)
+
+            if get_edge_overload_packet(arg.node.target) != exir_ops.edge.aten.full:
+                # Only replace if the arg is generated by a full op.
+                return super().call_operator(op, args, kwargs, meta)
+
+            if tuple(arg.node.args[0]) != (1,):
+                # Only replace if the size of the full op is [1].
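+                # (args[0] of the full op is its size argument; anything other than a
+                # single-element size means the value is not a scalar.)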
+ return super().call_operator(op, args, kwargs, meta) + + updated_args[op_arg_index] = arg.node.args[1] + + return super().call_operator( + new_op, + tuple(updated_args), + kwargs, + meta, + ) + + +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class ReplaceAtenAvgPoolWithJarvisAvgPoolPass(ExportPass): + """ + Replace the aten avg_pool op with the jarvis custom avg_pool2d op. + """ + + def call_operator(self, op, args, kwargs, meta): + # Only continue for avg_pool op + if op not in { + exir_ops.edge.aten.avg_pool1d.default, + exir_ops.edge.aten.avg_pool2d.default, + }: + return super().call_operator(op, args, kwargs, meta) + + # Determine if the op is avg_pool1d or avg_pool2d + avg_pool1d: bool = op == exir_ops.edge.aten.avg_pool1d.default + # Get the input tensor + in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0] + + # Replace avg_pool2d with custom avg_pool2d, and if the input tensor is + # quantized, pass its zero_point tensor as arg to the custom avg_pool2d. + # stride, padding, ceil_mode, count_include_pad, divisor_override, are + # the native avg_pool2d args. 'channel_last' denotes NCHW vs NHWC layout, + # and is False by default. + kernel_size = args[1] + stride = args[2] if len(args) >= 3 else [1, 1] + padding = args[3] if len(args) >= 4 else [0, 0] + ceil_mode = args[4] if len(args) >= 5 else False + count_include_pad = args[5] if len(args) >= 6 else True + divisor_override = args[6] if len(args) >= 7 else None + zero_point = torch.tensor(0, dtype=torch.int32) + + # If the op is avg_pool1d, then we need to reshape the 3d input to a 4d + # tensor. + if avg_pool1d: + in_shape = list(in_tensor.shape) + assert len(in_shape) == 3, "Expected 3d input for avg_pool1d" + in_shape.insert(2, 1) + out_shape = meta["val"].shape + in_view_op = super().call_operator( + exir_ops.edge.aten.view_copy.default, + (in_tensor, in_shape), + kwargs, + meta, + ) + # Extend the kernel_size, stride and padding to 2d + kernel_size = [1] + kernel_size if len(kernel_size) == 1 else kernel_size + stride = [1] + stride if len(stride) == 1 else stride + padding = [0] + padding if len(padding) == 1 else padding + + # Create a new avg_pool node with the updated args + new_args = ( + in_view_op if avg_pool1d else args[0], + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + zero_point, + False, + ) + avg_pool2d_op = super().call_operator( + exir_ops.edge.cadence.avg_pool2d.default, + new_args, + kwargs, + meta, + ) + + # If the node was avg_pool1d, we again reshape the 4d output to 3d output + return ( + super().call_operator( + exir_ops.edge.aten.view_copy.default, + (avg_pool2d_op, list(out_shape)), + kwargs, + meta, + ) + if avg_pool1d + else avg_pool2d_op + ) + + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class ReplaceIm2RowWithViewPass(ExportPass): + def can_replace(self, op, args, kwargs, meta) -> bool: + if op != exir_ops.edge.cadence.im2row.default: + return False + + # Check if im2row applies padding. If yes, we cannot replace it with view. + pad = cast(tuple[int, ...], args[3]) + if any(p != 0 for p in pad): + return False + + # Check if im2row has dilation. If yes, we cannot replace it with view. + dilation = cast(tuple[int, ...], args[2]) + if any(d != 1 for d in dilation): + return False + + # im2row works on 3D or 4D tensors. + # Output shape[1:-1] will be unit if input spatial dimensions are the same as kernel spatial dimensions. 
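+        # In that case each batch entry produces a single im2row row, so the op is
+        # just a reshape of its input and can be lowered to a view_copy.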
+ output_shape = meta["val"].shape + if math.prod(output_shape[1:-1]) == 1: + return True + + return False + + def call_operator( + self, + op, + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op != exir_ops.edge.cadence.im2row.default: + return super().call_operator(op, args, kwargs, meta) + + if not self.can_replace(op, args, kwargs, meta): + return super().call_operator(op, args, kwargs, meta) + + output_shape = meta["val"].shape + return super().call_operator( + exir_ops.edge.aten.view_copy.default, + (args[0], tuple(output_shape)), + kwargs, + meta, + ) + + +# This class encapsulates all the functions that replace/switch one op in the +# graph with another. +class CadenceReplaceOpsInGraph: + passes = [ + ReplaceFunctionallyEquivalentOpTargets, + ReplaceTCopyWithTransposePass, + ReplacePermuteWithTransposePass, + ReplaceScalarWithTensorArgPass, + ReplaceConvolutionOptionalArgsWithConcreteArgsPass, + ReplaceMMWithAddMMPass, + ReplaceSqueezeAndUnsqueezeWithViewPass, + ReplaceAddMMWithLinearPass, + RemoveNopSelectOpPass, + ReplaceSelectWithViewOpPass, + ReplaceRepeatWithCatPass, + ReplacePadWithCatPass, + ReplaceConstantPadNdWithSlicePass, + ReplaceConvWithChannelLastConvPass, + ReplaceAtenConvolutionWithJarvisConvolutionPass, + ForceChannelLastForConvPass, + ReplaceTrivialConvWithLinear, + ReplaceConvWithIm2RowAndLinear, + ReplaceTransposedConvWithLinearPass, + # This pass should be after passes that replace conv -> im2row + linear. + ReplaceIm2RowWithViewPass, + MakeSliceAndCatDimOutermostPass, + ReplaceNopTransposeOrPermuteWithViewPass, + ReplaceLinearWithFullyConnectedOpPass, + ReplaceScalarTensorWithFullPass, + ReplaceFullLikeWithFullPass, + ReplaceInfArgInFullWithValuePass, + ReplaceLogicalNotBooleanWhereWithWherePass, + ReplacePT2QuantWithCadenceQuantPass, + ReplacePT2DequantWithCadenceDequantPass, + ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass, + ReplaceAtenAvgPoolWithJarvisAvgPoolPass, + ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass, + ] diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py new file mode 100644 index 00000000000..e0f90ed46fc --- /dev/null +++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py @@ -0,0 +1,594 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
+ + +import unittest + +import executorch.backends.cadence.aot.ops_registrations # noqa +import torch +from executorch.backends.cadence.aot import compiler +from executorch.backends.cadence.aot.compiler import export_to_edge, quantize_pt2 +from executorch.backends.cadence.aot.fuse_ops import ( + FuseFullThenReshapePass, + FuseMulIntoDequantPass, + FuseQuantDequantToRequantizePass, + FuseTransposeOpPairsPass, +) +from executorch.backends.cadence.aot.graph_builder import GraphBuilder +from executorch.backends.cadence.aot.pass_utils import count_node +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from torch import nn + + +class TestFusionPassesBase(unittest.TestCase): + def check_op_counts( + self, + graph_module: torch.fx.GraphModule, + expected_op_counts: dict[EdgeOpOverload, int], + ) -> None: + for op, count in expected_op_counts.items(): + self.assertEqual(count_node(graph_module, op), count) + + +class TestFusionPasses(TestFusionPassesBase): + def test_addmm_fusion(self): + class AddmmFeasible1(torch.nn.Module): + def forward(self, x, y, z): + t1 = torch.mm(x, y) + return torch.add(t1, z) + + x = torch.randn(3, 5) + y = torch.randn(5, 6) + z = torch.randn(6) + + graph_module = ( + compiler.export_to_cadence(AddmmFeasible1(), (x, y, z)) + .exported_program() + .graph_module + ) + graph_module.graph.eliminate_dead_code() + + # Assert that mm and add were fused to addmm + self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1) + self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0) + self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0) + + class AddmmFeasible2(torch.nn.Module): + def forward(self, x, y, z): + t1 = y.view((8, 6)) + t2 = torch.mm(x, t1) + t3 = t2.view((2, 2, 6)) + return torch.add(t3, z) + + x = torch.randn(4, 8) + y = torch.randn(2, 4, 6) + z = torch.randn(6) + + graph_module = ( + compiler.export_to_cadence(AddmmFeasible2(), (x, y, z)) + .exported_program() + .graph_module + ) + graph_module.graph.eliminate_dead_code() + # Assert that mm and add were fused to addmm + self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1) + self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0) + self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0) + + # Bias is a singleton value, broadcastable to output of mm + class AddmmFeasible3(torch.nn.Module): + def forward(self, x, y): + t1 = torch.mm(x, y) + return torch.add(t1, torch.ones(1)) + + x = torch.randn(3, 5) + y = torch.randn(5, 6) + + graph_module = ( + compiler.export_to_cadence(AddmmFeasible3(), (x, y)) + .exported_program() + .graph_module + ) + graph_module.graph.eliminate_dead_code() + # Assert that mm and add were fused to addmm + self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1) + self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0) + self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0) + + # Bias is not broadcastable to output of mm + class AddmmInfeasible1(torch.nn.Module): + def forward(self, x, y, z): + t1 = y.view((8, 6)) + t2 = torch.mm(x, t1) + t3 = t2.view((2, 2, 6)) + return torch.add(t3, z) + + x = torch.randn(4, 8) + y = torch.randn(2, 4, 6) + z = torch.randn(2, 2, 1) + + graph_module = ( + compiler.export_to_cadence(AddmmInfeasible1(), (x, y, z)) + .exported_program() + .graph_module + ) + 
graph_module.graph.eliminate_dead_code() + # Assert that mm and add were not fused to addmm, since z cannot be + # broadcasted to the out of mm. + self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 1) + + # The add consuming the output of mm has more than one users. + class AddmmInfeasible2(torch.nn.Module): + def forward(self, x, y, z): + t1 = torch.mm(x, y) + t2 = torch.add(t1, z) + t3 = torch.add(t2, z) + return torch.add(t2, t3) + + x = torch.randn(3, 5) + y = torch.randn(5, 6) + z = torch.randn(6) + + graph_module = ( + compiler.export_to_cadence(AddmmInfeasible2(), (x, y, z)) + .exported_program() + .graph_module + ) + graph_module.graph.eliminate_dead_code() + # Assert that mm and add were not fused to addmm, since add has multiple + # users. + self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 3) + + # TODO(matthiascremon): enable that pass with new flow + @torch.no_grad() + @unittest.expectedFailure + def test_legacy_conv_bn_fusion(self): + class ModelConvBN(torch.nn.Module): + def __init__(self, in_features: int, out_features: int, kernel_size: int): + super().__init__() + self.conv1d = nn.Conv1d(in_features, out_features, kernel_size) + self.bn = nn.BatchNorm1d(out_features) + + def forward(self, x): + y = self.conv1d(x) + return self.bn(y) + + model = ModelConvBN(64, 1, 2) + x = torch.randn(1, 64, 4) + + graph_module = ( + compiler.export_to_executorch(model.eval(), (x,)) + .exported_program() + .exported_program() + .graph_module + ) + # Assert that after running the fusion passes, batchnorm was fused with conv1d + self.assertEqual( + count_node(graph_module, torch.ops.aten.linear.out) + + count_node(graph_module, torch.ops.cadence.convolution.out), + 1, + ) + self.assertEqual( + count_node( + graph_module, torch.ops.aten._native_batch_norm_legit_no_training.out + ), + 0, + ) + + def test_permute_transpose_fusion(self): + class PermuteTranspose(torch.nn.Module): + def forward(self, x): + y = x.permute((0, 2, 4, 1, 3)) + return y.transpose(0, 1) + + x = torch.randn(3, 1, 3, 1, 4) + graph_module = ( + compiler.export_to_cadence(PermuteTranspose(), (x,)) + .exported_program() + .graph_module + ) + graph_module.graph.eliminate_dead_code() + # Assert that permute op was fused with transpose op + self.assertEqual( + count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 1 + ) + self.assertEqual( + count_node(graph_module, exir_ops.edge.aten.transpose_copy.int), 0 + ) + + def test_view_fusion(self): + class ViewFusion(torch.nn.Module): + def forward(self, x): + x = x.view([1, 8, 15]) + x = x.view([1, 1, 120]) + return x.view([1, 12, 10]) + + x = torch.randn(8, 5, 3) + graph_module = ( + compiler.export_to_cadence(ViewFusion(), (x,)) + .exported_program() + .graph_module + ) + graph_module.graph.eliminate_dead_code() + # Assert that only one view op remains + self.assertEqual( + count_node(graph_module, exir_ops.edge.aten.view_copy.default), 1 + ) + + def test_force_quant_dequant_fusion(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.ops.quantized_decomposed.quantize_per_tensor( + x, 1.2, 3, 0, 127, torch.int8 + ) + x = torch.permute(x, [2, 0, 1, 3]) + x = torch.ops.quantized_decomposed.dequantize_per_tensor( + x, 4.5, 6, 0, 127, torch.int8 + ) + return x + + inputs = torch.randn(2, 12, 1, 6) + model = M() + graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module + + graph_module = FuseQuantDequantToRequantizePass( + 
force_quant_dequant_fusion=True
+        )(graph_module).graph_module
+        self.check_op_counts(
+            graph_module,
+            expected_op_counts={
+                # Verify that the quant -> permute -> dequant chain was fused into a
+                # requantize, since force_quant_dequant_fusion is enabled.
+                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
+                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
+                exir_ops.edge.cadence.requantize.default: 1,
+            },
+        )
+
+    def test_no_replace_quant_permute_dequant_with_requantize(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                x = torch.ops.quantized_decomposed.quantize_per_tensor(
+                    x, 1.2, 3, 0, 127, torch.int8
+                )
+                x = torch.permute(x, [2, 0, 1, 3])
+                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
+                    x, 4.5, 6, 0, 127, torch.int8
+                )
+                return x
+
+        inputs = torch.randn(2, 12, 1, 6)
+        model = M()
+        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
+
+        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
+        self.check_op_counts(
+            graph_module,
+            expected_op_counts={
+                # Verify that no dequant/quant pair was replaced with requantize.
+                # quantize -> permute -> dequantize should not be replaced with requantize.
+                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
+                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
+                exir_ops.edge.cadence.requantize.default: 0,
+            },
+        )
+
+    def test_replace_quant_view_dequant_with_requantize(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                x = torch.ops.quantized_decomposed.quantize_per_tensor(
+                    x, 1.2, 3, 0, 127, torch.int8
+                )
+                x = x.view(-1)
+                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
+                    x, 4.5, 6, 0, 127, torch.int8
+                )
+                return x
+
+        inputs = torch.randn(2, 12, 1, 6)
+        model = M()
+        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
+        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
+        graph_module.print_readable()
+
+        self.check_op_counts(
+            graph_module,
+            expected_op_counts={
+                # Verify that the quant -> view -> dequant chain was fused into a
+                # requantize.
+                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
+                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
+                exir_ops.edge.cadence.requantize.default: 1,
+            },
+        )
+
+    def test_replace_dequant_quant_with_requantize(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
+                    x, 1.2, 3, 0, 127, torch.int8
+                )
+                x = torch.permute(x, [2, 0, 1, 3])
+                x = torch.ops.quantized_decomposed.quantize_per_tensor(
+                    x, 4.5, 6, 0, 127, torch.int8
+                )
+                return x
+
+        inputs = torch.randn(2, 12, 1, 6).to(torch.int8)
+        model = M()
+        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
+        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
+
+        self.check_op_counts(
+            graph_module,
+            expected_op_counts={
+                # Verify that dequant -> permute -> quant was replaced with permute -> requantize.
+ exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0, + exir_ops.edge.cadence.requantize.default: 1, + }, + ) + + def test_replace_dequant_permute_quant_with_requantize(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.ops.quantized_decomposed.dequantize_per_tensor( + x, 1.2, 3, 0, 127, torch.int8 + ) + x = torch.permute(x, [2, 0, 1, 3]) + x = torch.ops.quantized_decomposed.quantize_per_tensor( + x, 4.5, 6, 0, 127, torch.int8 + ) + return x + + inputs = torch.randn(2, 12, 1, 6).to(torch.int8) + model = M() + graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module + graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module + + self.check_op_counts( + graph_module, + expected_op_counts={ + # Verify that dequant -> permute -> quant was replaced with permute -> requantize. + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0, + exir_ops.edge.cadence.requantize.default: 1, + }, + ) + + def test_remove_nop_dequant_quant(self): + class M(torch.nn.Module): + def __init__(self): + super(M, self).__init__() + self.lin1 = torch.nn.Linear(6, 12, bias=False) + self.lin2 = torch.nn.Linear(12, 24, bias=False) + + def forward(self, x): + x = self.lin1(x) + # redundant dequant+quant will be created around this permute + x = torch.permute(x, [0, 2, 1, 3]) + x = self.lin2(x) + return x + + inputs = torch.randn(2, 12, 1, 6) + model = M() + quantized_model = quantize_pt2(model, (inputs,)) + graph_module = ( + export_to_edge(quantized_model, (inputs,)).exported_program().graph_module + ) + graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module + self.check_op_counts( + graph_module, + expected_op_counts={ + # Verify that one dequant/quant pair was removed + # Expect 1 quantize ops: 1 input + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1, + # Expect 1 dequant op at the end (output of second linear) + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1, + }, + ) + + def test_fuse_mul_into_dequant(self): + class M(torch.nn.Module): + def forward(self, x): + x0 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x, 1.5, 0, 0, 255, torch.uint8 + ) + x1 = torch.full([4, 32], 3, dtype=torch.float32) + x2 = x0 * x1 + return x2 + + inputs = (torch.randint(0, 255, [4, 32], dtype=torch.uint8),) + graph_module = export_to_edge(M(), inputs).exported_program().graph_module + graph_module = FuseMulIntoDequantPass()(graph_module).graph_module + + # verify that the mul and full ops were removed + self.check_op_counts( + graph_module, + expected_op_counts={ + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1, + exir_ops.edge.aten.full.default: 0, + exir_ops.edge.aten.mul.Tensor: 0, + }, + ) + + # verify that the dequant scale value was updated correctly + for node in graph_module.graph.nodes: + if ( + node.target + == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default + ): + deq_scale = node.args[1] + self.assertEqual(deq_scale, 4.5) + + def test_fuse_then_transpose_pass(self): + # Create a graph with full -> transpose. 
+ builder = GraphBuilder() + full_node = builder.call_operator( + op=exir_ops.edge.aten.full.default, args=((2, 3), 1) + ) + transpose_node = builder.call_operator( + op=exir_ops.edge.aten.transpose_copy.int, + args=(full_node, 0, 1), + ) + permute_node = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(transpose_node, (1, 0)), + ) + view_node = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, + args=(permute_node, (1, 6, 1)), + ) + builder.output(view_node) + gm = builder.get_graph_module() + self.check_op_counts( + gm, + expected_op_counts={ + exir_ops.edge.aten.full.default: 1, + exir_ops.edge.aten.transpose_copy.int: 1, + exir_ops.edge.aten.permute_copy.default: 1, + exir_ops.edge.aten.view_copy.default: 1, + }, + ) + + # Check that the pass fuses the full with all other ops (transpose, permute, view). + gm_after_pass = FuseFullThenReshapePass()(gm).graph_module + self.check_op_counts( + gm_after_pass, + expected_op_counts={ + exir_ops.edge.aten.full.default: 1, + exir_ops.edge.aten.transpose_copy.int: 0, + exir_ops.edge.aten.permute_copy.default: 0, + exir_ops.edge.aten.view_copy.default: 0, + }, + ) + + +class TestFuseTransposeOpPairsPass(TestFusionPassesBase): + def test_fuse_transpose_pairs(self): + # Create a graph with transpose -> quant -> transpose. + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3)) + transpose_node = builder.call_operator( + op=exir_ops.edge.aten.transpose_copy.int, + args=(x, 0, 1), + ) + quant_node = builder.call_operator( + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + args=(transpose_node, 1.2, 3, 0, 127, torch.int8), + ) + transpose_node = builder.call_operator( + op=exir_ops.edge.aten.transpose_copy.int, + args=(quant_node, 0, 1), + ) + builder.output(transpose_node) + gm = builder.get_graph_module() + self.check_op_counts( + gm, + expected_op_counts={ + exir_ops.edge.aten.transpose_copy.int: 2, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1, + }, + ) + + # Check that the pass fuses the two transpose ops. + gm_after_pass = FuseTransposeOpPairsPass()(gm).graph_module + self.check_op_counts( + gm_after_pass, + expected_op_counts={ + exir_ops.edge.aten.transpose_copy.int: 0, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1, + }, + ) + + def test_no_fusion_for_transpose_pairs(self): + # Create a graph with transpose -> quant -> transpose. + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4)) + transpose_node = builder.call_operator( + op=exir_ops.edge.aten.transpose_copy.int, + args=(x, 0, 1), + ) + quant_node = builder.call_operator( + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + args=(transpose_node, 1.2, 3, 0, 127, torch.int8), + ) + transpose_node = builder.call_operator( + op=exir_ops.edge.aten.transpose_copy.int, + args=(quant_node, 1, 2), + ) + builder.output(transpose_node) + gm = builder.get_graph_module() + self.check_op_counts( + gm, + expected_op_counts={ + exir_ops.edge.aten.transpose_copy.int: 2, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1, + }, + ) + + # No fusion. + gm_after_pass = FuseTransposeOpPairsPass()(gm).graph_module + self.check_op_counts( + gm_after_pass, + expected_op_counts={ + exir_ops.edge.aten.transpose_copy.int: 2, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1, + }, + ) + + def test_fusion_for_forked_transposes(self): + # Create a graph with transpose -> quant -> transpose. 
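+        # Here a single upstream transpose feeds num_forks quantize branches, each
+        # followed by its own inverse transpose; the pass should fuse every transpose pair.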
+ builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, dtype=torch.float32)) + transpose_node = builder.call_operator( + op=exir_ops.edge.aten.transpose_copy.int, + args=(x, 0, 1), + ) + num_forks = 3 + outputs = [] + for _ in range(num_forks): + quant_node = builder.call_operator( + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + args=(transpose_node, 1.2, 3, 0, 127, torch.int8), + ) + outputs.append( + builder.call_operator( + op=exir_ops.edge.aten.transpose_copy.int, + args=(quant_node, 0, 1), + ) + ) + builder.output(outputs) + gm = builder.get_graph_module() + self.check_op_counts( + gm, + expected_op_counts={ + exir_ops.edge.aten.transpose_copy.int: num_forks + 1, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: num_forks, + }, + ) + + # Fuse the all the transpose ops. + gm_after_pass = FuseTransposeOpPairsPass()(gm).graph_module + self.check_op_counts( + gm_after_pass, + expected_op_counts={ + exir_ops.edge.aten.transpose_copy.int: 0, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: num_forks, + }, + ) diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py new file mode 100644 index 00000000000..f465b55c8d6 --- /dev/null +++ b/backends/cadence/aot/tests/test_remove_ops_passes.py @@ -0,0 +1,674 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + + +import unittest +from typing import cast, Tuple + +import executorch.backends.cadence.aot.ops_registrations # noqa +import torch +import torch.nn as nn +import torch.nn.functional as F +from executorch.backends.cadence.aot import compiler +from executorch.backends.cadence.aot.compiler import export_to_edge + +from executorch.backends.cadence.aot.pass_utils import count_node +from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer +from executorch.backends.cadence.aot.remove_ops import ( + RemoveAliasCopyOpPass, + RemoveCloneOpPass, + RemoveContiguousOpPass, + RemoveDetachCopyPass, + RemoveNopAddOpPass, + RemoveNopExpandOpPass, + RemoveNopLinalgVectorNormOpPass, + RemoveNopMulOpPass, + RemoveNopSelectOpPass, + RemoveNopSliceOrViewOpPass, + RemovePermutesAroundElementwiseOps, + RemoveToOpsPass, + RemoveZeroSizedCatArgsPass, + RemoveZeroSizedConstantPadNd, +) +from executorch.exir.dialects._ops import ops as exir_ops +from parameterized.parameterized import parameterized +from pyre_extensions import none_throws + +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e + +from torch.export import export_for_training +from torch.fx.passes.infra.pass_base import PassResult + + +class TestRemoveOpsPasses(unittest.TestCase): + @parameterized.expand( + [ + [(1, 2, 3)], + ] + ) + @torch.no_grad() + def test_remove_to_ops(self, shape: Tuple[int]): + class M(torch.nn.Module): + def forward(self, x: torch.Tensor): + return exir_ops.edge.aten.to(x, dtype=torch.float32) + + model = M() + x = torch.randn(shape) + graph_module = export_to_edge(model, (x,)).exported_program().graph_module + p = RemoveToOpsPass() + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.to.dtype), + 0, + ) + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.to.dtype_layout), + 0, + ) + + @parameterized.expand( + [ + [(7, 6, 5)], + [(7, 6)], + [(7,)], + ] + ) + @torch.no_grad() + def test_remove_nop_add_op_pass(self, shape: Tuple[int]): + class 
FullX(torch.nn.Module): + def forward(self, t: torch.Tensor): + return torch.add(torch.full(shape, 0), t) + + class FullY(torch.nn.Module): + def forward(self, t: torch.Tensor): + return torch.add(t, torch.full(shape, 0)) + + model = FullX() + t = torch.full(shape, 3) + graph_module = export_to_edge(model, (t,)).exported_program().graph_module + + p = RemoveNopAddOpPass() + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + graph_module.print_readable() + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor), + 0, + ) + + model = FullY() + graph_module = export_to_edge(model, (t,)).exported_program().graph_module + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor), + 0, + ) + + @parameterized.expand( + [ + [(7, 6, 5)], + [(7, 6)], + [(7,)], + ] + ) + @torch.no_grad() + def test_remove_nop_mul_op_pass(self, shape: Tuple[int]): + class FullX(torch.nn.Module): + def forward(self, t: torch.Tensor): + return torch.mul(torch.full(shape, 0), t) + + class FullY(torch.nn.Module): + def forward(self, t: torch.Tensor): + return torch.mul(t, torch.full(shape, 0)) + + model = FullX() + t = torch.full(shape, 3) + graph_module = export_to_edge(model, (t,)).exported_program().graph_module + + p = RemoveNopMulOpPass() + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + graph_module.print_readable() + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor), + 0, + ) + + model = FullY() + graph_module = export_to_edge(model, (t,)).exported_program().graph_module + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor), + 0, + ) + + @parameterized.expand( + [ + [(1, 2, 3)], + ] + ) + @torch.no_grad() + def test_remove_alias_copy(self, shape: Tuple[int]): + class M(torch.nn.Module): + def forward(self, x: torch.Tensor): + return exir_ops.edge.aten.alias_copy(x) + + model = M() + x = torch.randn(shape) + graph_module = export_to_edge(model, (x,)).exported_program().graph_module + + p = RemoveAliasCopyOpPass() + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.alias_copy.default), + 0, + ) + + @parameterized.expand( + [ + [(1, 2, 3)], + ] + ) + @torch.no_grad() + def test_remove_detach_copy(self, shape: Tuple[int]): + # aten::detach is converted to aten::alias_copy after functionalization & decomposition. + class M(torch.nn.Module): + def forward(self, x: torch.Tensor): + return exir_ops.edge.aten.detach_copy(x) + + model = M() + x = torch.randn(shape) + graph_module = export_to_edge(model, (x,)).exported_program().graph_module + + p = RemoveDetachCopyPass() + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.detach_copy.default), + 0, + ) + + @parameterized.expand( + [ + [(1, 2, 3), (0, 0)], + ] + ) + @torch.no_grad() + def test_remove_zero_sized_constant_pad_nd( + self, shape: Tuple[int], padding: Tuple[int] + ): + # F.pad is converted to aten::constant_pad_nd after functionalization & decomposition. 
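+        # A (0, 0) padding adds no elements, so the resulting constant_pad_nd is a
+        # nop and should be removed.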
+ class Padding(torch.nn.Module): + def __init__(self): + super().__init__() + self.padding = padding + + def forward(self, x: torch.Tensor): + return F.pad(x, self.padding) + + model = Padding() + x = torch.randn(shape) + graph_module = export_to_edge(model, (x,)).exported_program().graph_module + + p = RemoveZeroSizedConstantPadNd() + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.constant_pad_nd.default), + 0, + ) + + def test_remove_expand(self): + class Expand(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.expand_copy(x, [2, 3, 5]) + + x = torch.ones(2, 3, 5) + p = RemoveNopExpandOpPass() + graph_module = export_to_edge(Expand(), (x,)).exported_program().graph_module + graph_module = p(graph_module).graph_module + # Assert that expand op is optimized away, since it is a nop + self.assertEqual( + count_node(graph_module, exir_ops.edge.aten.expand_copy.default), 0 + ) + + def test_remove_zero_arg_cat(self): + class Cat(torch.nn.Module): + def forward(self, x, y): + return torch.ops.aten.cat((x, y), 0) + + x = torch.ones(1, 0, 3, 5) + y = torch.ones(2, 0, 3, 5) + graph_module = ( + compiler.export_to_cadence(Cat(), (x, y)).exported_program().graph_module + ) + # Assert that cat op is optimized away, since it concatenates + # two zero-sized tensors + self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0) + + def test_remove_single_arg_cat(self): + class Cat(torch.nn.Module): + def forward(self, x, y): + z = torch.ones(0, 5) + # z is an empty tensor, and concatenation of x with z will + # be x. So we can safely eliminate the following cat op. + x1 = torch.ops.aten.cat((x, z)) + x2 = torch.add(x1, 2.4, 3.1) + y1 = torch.add(y, 1, 2) + return torch.add(x2, y1) + + x = torch.ones(3, 5) + y = torch.ones(3, 5) + graph_module = export_to_edge(Cat(), (x, y)).exported_program().graph_module + new_graph_module = RemoveZeroSizedCatArgsPass()(graph_module).graph_module + new_graph_module.graph.eliminate_dead_code() + # Assert that x1 is optimized away + self.assertEqual(count_node(new_graph_module, torch.ops.aten.cat.out), 0) + + def test_remove_zero_sized_cat(self): + class Cat(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, tensors): + return torch.cat(tensors, self.dim) + + shapes, dim, dtype, _max = [(1, 0, 3), (2, 0, 3)], 0, torch.float32, 127 + + in_tensors = [(torch.rand(shape) * _max).to(dtype=dtype) for shape in shapes] + + model = Cat(dim) + graph_module = ( + export_to_edge(model, (in_tensors,)).exported_program().graph_module + ) + new_graph_module = RemoveZeroSizedCatArgsPass()(graph_module).graph_module + new_graph_module.graph.eliminate_dead_code() + self.assertEqual(count_node(new_graph_module, torch.ops.aten.cat.out), 0) + + def test_remove_clone(self): + class Clone(torch.nn.Module): + def forward(self, x, y): + t1 = x.clone() + t2 = y.clone() + return t1 + t2 + + x = torch.ones(3, 5) + y = torch.ones(3, 5) + graph_module = export_to_edge(Clone(), (x, y)).exported_program().graph_module + new_graph_module = RemoveCloneOpPass()(graph_module).graph_module + new_graph_module.graph.eliminate_dead_code() + # Assert that t1 and t2 are optimized away + self.assertEqual(count_node(new_graph_module, torch.ops.aten.clone.out), 0) + + def test_remove_contiguous(self): + class Contiguous(torch.nn.Module): + def forward(self, x, y): + t1 = x.contiguous() + t2 = y.contiguous() + return t1 + t2 + + x 
= torch.ones(3, 5) + y = torch.ones(3, 5) + graph_module = ( + export_to_edge(Contiguous(), (x, y)).exported_program().graph_module + ) + new_graph_module = RemoveContiguousOpPass()(graph_module).graph_module + new_graph_module.graph.eliminate_dead_code() + # Assert that t1 and t2 are optimized away + self.assertEqual(count_node(new_graph_module, torch.ops.aten.contiguous.out), 0) + + @parameterized.expand( + [ + [(3, 5), [3, 5]], + [(1,), [-1]], + ] + ) + @torch.no_grad() + def test_remove_nop_view(self, shape, new_shape): + class View(torch.nn.Module): + def __init__(self, new_shape): + super().__init__() + self.new_shape = new_shape + + def forward(self, x: torch.Tensor): + return x.view(self.new_shape) + + model = View(new_shape) + x = torch.randn(shape) + graph_module = export_to_edge(model, (x,)).exported_program().graph_module + p = RemoveNopSliceOrViewOpPass() + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + graph_after_passes.graph.eliminate_dead_code() + # Assert that view op was removed + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 0 + ) + + def test_remove_nop_slice(self): + class Slice(torch.nn.Module): + def forward(self, x): + return torch.slice_copy(x, dim=0, start=0, step=1) + + x = torch.ones(3, 5) + model = Slice() + graph_module = export_to_edge(model, (x,)).exported_program().graph_module + p = RemoveNopSliceOrViewOpPass() + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + graph_after_passes.graph.eliminate_dead_code() + # Assert that slice op was removed + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0 + ) + + def test_remove_nop_select(self): + class SelectFeasible1(torch.nn.Module): + def forward(self, x): + y = x.select(0, 0) + z = y.view([1, 5, 6]) + return z + + x = torch.ones(1, 5, 6) + graph_module = ( + export_to_edge(SelectFeasible1(), (x,)).exported_program().graph_module + ) + self.assertEqual( + count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1 + ) + graph_module = RemoveNopSelectOpPass()(graph_module).graph_module + # Assert that select op was removed + self.assertEqual( + count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0 + ) + + class SelectFeasible2(torch.nn.Module): + def forward(self, x, y): + x = x.select(0, 0) + z = x + y + return z + + x = torch.ones(1, 5, 6) + y = torch.ones(1, 5, 6) + graph_module = ( + export_to_edge(SelectFeasible2(), (x, y)).exported_program().graph_module + ) + self.assertEqual( + count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1 + ) + graph_module = RemoveNopSelectOpPass()(graph_module).graph_module + # Assert that select op was removed + self.assertEqual( + count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0 + ) + + class SelectFeasible3(torch.nn.Module): + def forward(self, x, y): + x = x.select(0, 0) + z = x * y + return z + + x = torch.ones(1, 5, 6) + y = torch.ones(1, 5, 6) + graph_module = ( + export_to_edge(SelectFeasible3(), (x, y)).exported_program().graph_module + ) + self.assertEqual( + count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1 + ) + graph_module = RemoveNopSelectOpPass()(graph_module).graph_module + # Assert that select op was removed + self.assertEqual( + count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0 + ) + + class SelectFeasible4(torch.nn.Module): + def forward(self, x, y): + x = x.select(0, 0) + z = x / y + return z + + x = torch.ones(1, 5, 6) + y = torch.ones(1, 5, 6) + graph_module = 
( + export_to_edge(SelectFeasible4(), (x, y)).exported_program().graph_module + ) + self.assertEqual( + count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1 + ) + graph_module = RemoveNopSelectOpPass()(graph_module).graph_module + # Assert that select op was removed + self.assertEqual( + count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0 + ) + + def test_remove_nop_quant_dequant(self): + class M(torch.nn.Module): + def __init__(self): + super(M, self).__init__() + self.linear = torch.nn.Linear(6, 12, bias=False) + + def forward(self, x): + x = self.linear(x) + return x + + inp = torch.randn(2, 8, 1, 6) + + # Run the standard quant/convert steps, but without fusing + # this leaves two redundant quant/dequant pairs to test with + quantizer = CadenceQuantizer() + model_exp = export_for_training(M(), (inp,)).module() + prepared_model = prepare_pt2e(model_exp, quantizer) + prepared_model(inp) + converted_model = convert_pt2e(prepared_model) + + graph_module = ( + compiler.export_to_cadence( + converted_model, + (inp,), + ) + .exported_program() + .graph_module + ) + + # Expect all quantize ops to be removed by the pass + self.assertEqual( + count_node(graph_module, exir_ops.edge.cadence.quantize_per_tensor.default), + 0, + ) + + # Expect 1 dequantize op for the weights + self.assertEqual( + count_node( + graph_module, exir_ops.edge.cadence.dequantize_per_tensor.default + ), + 1, + ) + + def test_remove_nop_aten_linalg_vector_norm(self): + class LinalgVectorNorm(torch.nn.Module): + def forward(self, x: torch.Tensor): + return torch.linalg.vector_norm(x, 2, [0, 1], True) + + model = LinalgVectorNorm() + x = torch.randn([1, 1, 128]) + inputs = (x,) + + graph_module = ( + compiler.export_to_edge( + model, + inputs, + ) + .exported_program() + .graph_module + ) + + graph_module = none_throws( + RemoveNopLinalgVectorNormOpPass()(graph_module) + ).graph_module + + # Expect the linalg_vector_norm op to be removed by the pass + self.assertEqual( + count_node(graph_module, exir_ops.edge.aten.linalg_vector_norm.default) + + count_node( + graph_module, exir_ops.edge.cadence.linalg_vector_norm.default + ), + 0, + ) + + def test_remove_permutes_around_elemwise_ops_add(self) -> None: + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(8, 8, 1, bias=False) + + def forward(self, x): + x = self.conv(x) + x = torch.permute(x, [0, 3, 1, 2]) + x = torch.add(x, x) + x = torch.permute(x, [0, 2, 3, 1]) + x = self.conv(x) + return x + + inputs = (torch.randn(1, 8, 4, 4),) + graph_module = export_to_edge(M(), inputs).exported_program().graph_module + p = RemovePermutesAroundElementwiseOps() + graph_module = cast(PassResult, p(graph_module)).graph_module + + self.assertEqual( + count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 0 + ) + + def test_remove_permutes_around_elemwise_ops_add_mean(self) -> None: + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv2d = nn.Conv2d(8, 8, 1) + + def forward(self, x, y): + x = self.conv2d(x) + y = self.conv2d(y) + x = torch.permute(x, [0, 3, 1, 2]) + y = torch.permute(y, [0, 3, 1, 2]) + z = torch.add(x, y) + z = torch.mean(z, dim=[-1, -3], keepdim=True) + z = torch.permute(z, [0, 2, 3, 1]) + z = self.conv2d(z) + return z + + inputs = (torch.randn(1, 8, 4, 4), torch.randn(1, 8, 4, 4)) + graph_module = export_to_edge(M(), inputs).exported_program().graph_module + p = RemovePermutesAroundElementwiseOps() + graph_module = cast(PassResult, p(graph_module)).graph_module + + 
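# Rough sketch of what the pass is expected to have done here: with the permutes gone, + # the add and mean run directly on the conv output layout, so dim arguments that + # referred to the permuted layout must be remapped. The mean's dims [-1, -3] (i.e. + # [3, 1] on the permuted tensor) should become [2, 3] on the unpermuted one, which + # the checks below verify. +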
self.assertEqual( + count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 0 + ) + + # verify that mean was updated correctly + mean = [ + n + for n in graph_module.graph.nodes + if n.target == exir_ops.edge.aten.mean.dim + ][0] + self.assertEqual(mean.args[1], [2, 3]) + + def test_remove_permutes_around_elemwise_ops_mul(self) -> None: + class M(torch.nn.Module): + def forward(self, x, y): + x = torch.slice_copy(x, 0, 0, 1) + x = torch.permute(x, [0, 3, 1, 2]) + y = torch.permute(y, [0, 3, 1, 2]) + x = torch.ops.quantized_decomposed.dequantize_per_tensor( + x, 1.5, 0, 0, 255, torch.uint8 + ) + z = x * y + z = torch.ops.quantized_decomposed.quantize_per_tensor( + z, 2.5, 0, 0, 255, torch.uint8 + ) + z = torch.permute(z, [0, 2, 3, 1]) + z = torch.unsqueeze_copy(z, 0) + return z + + inputs = (torch.randn(2, 4, 4, 8), torch.randn(2, 4, 4, 8)) + graph_module = export_to_edge(M(), inputs).exported_program().graph_module + + p = RemovePermutesAroundElementwiseOps() + graph_module = cast(PassResult, p(graph_module)).graph_module + + self.assertEqual( + count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 0 + ) + + def test_remove_permutes_around_elemwise_ops_double_permutes(self) -> None: + class M(torch.nn.Module): + def forward(self, x, y): + x = torch.slice_copy(x, 0, 0, 1) + x = torch.permute(x, [0, 3, 1, 2]) + x = torch.permute(x, [0, 3, 1, 2]) + x = torch.ops.quantized_decomposed.dequantize_per_tensor( + x, 1.5, 0, 0, 255, torch.uint8 + ) + y = torch.permute(y, [0, 3, 1, 2]) + y = torch.ops.quantized_decomposed.dequantize_per_tensor( + y, 1.5, 0, 0, 255, torch.uint8 + ) + z = torch.cat((x, y), 1) + z = torch.ops.quantized_decomposed.quantize_per_tensor( + z, 2.5, 0, 0, 255, torch.uint8 + ) + z = torch.permute(z, [0, 2, 3, 1]) + z = torch.permute(z, [0, 2, 3, 1]) + z = torch.unsqueeze_copy(z, 0) + return z + + inputs = (torch.randn(2, 4, 4, 8), torch.randn(1, 8, 4, 4)) + graph_module = export_to_edge(M(), inputs).exported_program().graph_module + p = RemovePermutesAroundElementwiseOps() + graph_module = cast(PassResult, p(graph_module)).graph_module + + # Expect 2 permutes to remain, one on input x and one on output z + self.assertEqual( + count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 2 + ) + + # verify that cat was updated correctly + cat = [ + n + for n in graph_module.graph.nodes + if n.target == exir_ops.edge.aten.cat.default + ][0] + self.assertEqual(cat.args[1], 3) + + def test_remove_permutes_around_elemwise_ops_noop(self) -> None: + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(8, 8, 1, bias=False) + + def forward(self, x): + x = self.conv(x) + x = torch.permute(x, [0, 2, 3, 1]) + x = torch.add(x, x) + x = torch.permute(x, [0, 3, 1, 2]) + x = self.conv(x) + return x + + inputs = (torch.randn(1, 8, 4, 4),) + graph_module = export_to_edge(M(), inputs).exported_program().graph_module + p = RemovePermutesAroundElementwiseOps() + graph_module = cast(PassResult, p(graph_module)).graph_module + + # Ensure no permutes were removed, since the dimensions don't fit the expected pattern + self.assertEqual( + count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 2 + ) diff --git a/backends/cadence/aot/tests/test_reorder_ops_passes.py b/backends/cadence/aot/tests/test_reorder_ops_passes.py new file mode 100644 index 00000000000..fa5f13917a3 --- /dev/null +++ b/backends/cadence/aot/tests/test_reorder_ops_passes.py @@ -0,0 +1,355 @@ +# (c) Meta Platforms, Inc. and affiliates. 
Confidential and proprietary. + + +import unittest + +import executorch.backends.cadence.aot.ops_registrations # noqa +import torch +from executorch.backends.cadence.aot.compiler import ( + export_to_edge, + quantize_and_export_to_cadence, +) +from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass +from executorch.backends.cadence.aot.pass_utils import ( + count_node, + get_compute_nodes_in_gm, + nodes_not_adjacent_in_gm, + nodes_not_connected_in_gm, +) +from executorch.backends.cadence.aot.reorder_ops import ( + AdvanceQuantizeOpAboveDefInBranchPass, + PostponeDequantizeOpBelowUseChainPass, + PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView, +) +from executorch.exir.dialects._ops import ops as exir_ops + + +class TestReorderPasses(unittest.TestCase): + def test_sink_dequantize(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(6, 12, bias=False) + + def forward(self, x, y): + x1 = self.linear(x) + y1 = self.linear(y) + x2 = torch.ops.aten.abs(x1) + return torch.ops.aten.cat((x2, y1)) + + inputs = (torch.randn(32, 6), torch.randn(32, 6)) + graph_module = ( + quantize_and_export_to_cadence(M(), inputs).exported_program().graph_module + ) + # Expect the SinkDequant pass to move dequant(y) from above the relu to just below it + self.assertTrue( + nodes_not_adjacent_in_gm( + graph_module, + exir_ops.edge.aten.abs.default, + exir_ops.edge.aten.cat.default, + ), + ) + self.assertTrue( + nodes_not_adjacent_in_gm( + graph_module, + exir_ops.edge.cadence.dequantize_per_tensor.default, + exir_ops.edge.cadence.dequantize_per_tensor.default, + ), + ) + + def test_advance_branched_quantize(self): + class ReorderOpsBranch(torch.nn.Module): + def forward(self, x): + x = x.view((32, 6)) + x1 = torch.slice_copy(x, dim=0, start=0, end=6, step=1) + x1 = exir_ops.edge.quantized_decomposed.quantize_per_tensor( + x1, 0.1, 10, 0, 255, torch.uint8 + ) + x2 = torch.slice_copy(x, dim=0, start=6, end=12, step=1) + x2 = exir_ops.edge.quantized_decomposed.quantize_per_tensor( + x2, 0.1, 10, 0, 255, torch.uint8 + ) + x3 = torch.slice_copy(x, dim=0, start=12, end=18, step=1) + x3 = exir_ops.edge.quantized_decomposed.quantize_per_tensor( + x3, 0.1, 10, 0, 255, torch.uint8 + ) + x4 = torch.slice_copy(x, dim=0, start=18, end=24, step=1) + x4 = exir_ops.edge.quantized_decomposed.quantize_per_tensor( + x4, 0.2, 4, 0, 255, torch.uint8 + ) + return (x1, x2, x3, x4) + + model = ReorderOpsBranch() + X = torch.randn(64, 3) + graph_module = export_to_edge(model, (X,)).exported_program().graph_module + graph_module = AdvanceQuantizeOpAboveDefInBranchPass()( + graph_module + ).graph_module + graph_module.graph.eliminate_dead_code() + nodes = get_compute_nodes_in_gm(graph_module) + # The quantize op should be hoisted to dominate the branch + self.assertTrue( + nodes[0] == exir_ops.edge.quantized_decomposed.quantize_per_tensor + ) + # There should be 5 quantize ops: the 4 originally present in the model, + # and the one that was hoisted above the slices + self.assertEqual( + count_node( + graph_module, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + ), + 5, + ) + # Ensure none of the slice nodes were erroneously removed + self.assertEqual( + count_node( + graph_module, + exir_ops.edge.aten.slice_copy.Tensor, + ), + 4, + ) + # Each of the 4 original quant ops should now be paired with a dequant op + self.assertEqual( + count_node( + graph_module, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + ), + 4, + ) 
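+ # Sketch of the expected graph at this point: view -> quantize (hoisted) -> 4x slice, + # with each slice followed by a dequantize and then one of the original quantize ops. + # Three of those dequant/quant pairs share params (scale 0.1, zero point 10) and should + # cancel outright; the remaining pair, whose quantize uses different params (scale 0.2, + # zero point 4), should be folded into a single requantize by the fusion pass applied next.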
+ graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module + # We expect 3 dequant/quant pairs to be removed because they have matching params, + # leaving a single dequant/quant pair that is then merged into a requantize op + self.assertEqual( + count_node( + graph_module, + exir_ops.edge.cadence.requantize.default, + ), + 1, + ) + + @torch.no_grad() + def test_advance_quantize(self): + class ReorderOpsChain(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(6, 12, bias=False) + + def forward(self, x): + x = x.permute([1, 0, 3, 2]) + x = self.linear(x) + return x + + model = ReorderOpsChain() + X = torch.randn(16, 1, 6, 32) + + graph_module = ( + quantize_and_export_to_cadence(model, (X,)).exported_program().graph_module + ) + # Assert that the quant node is no longer the successor of + # permute node. + self.assertTrue( + nodes_not_connected_in_gm( + graph_module, + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.cadence.quantize_per_tensor.default, + ), + ) + # Assert that permute node is the successor of quant node + self.assertFalse( + nodes_not_connected_in_gm( + graph_module, + exir_ops.edge.cadence.quantize_per_tensor.default, + exir_ops.edge.aten.permute_copy.default, + ), + ) + + def test_postpone_dequantize(self): + class ReorderOpsChain(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(6, 12, bias=False) + + def forward(self, x): + x = self.linear(x) + x = x.permute([1, 0, 3, 2]) + return x + + model = ReorderOpsChain() + X = torch.randn(1, 16, 32, 6) + + graph_module = ( + quantize_and_export_to_cadence(model, (X,)).exported_program().graph_module + ) + # Assert that the dequant node is no longer the predecessor of the permute node + self.assertTrue( + nodes_not_connected_in_gm( + graph_module, + exir_ops.edge.cadence.dequantize_per_tensor.default, + exir_ops.edge.aten.permute_copy.default, + ), + ) + # Assert that dequant node is the successor of permute node + self.assertFalse( + nodes_not_connected_in_gm( + graph_module, + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.cadence.dequantize_per_tensor.default, + ), + ) + + def test_postpone_dequantize_branched(self): + class ReorderOpsBranch(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 12, bias=False) + + def forward(self, x): + x0 = exir_ops.edge.quantized_decomposed.dequantize_per_tensor( + x, 0.1, 10, 0, 255, torch.uint8 + ) + x0 = torch.squeeze(x0, 0) + x1 = torch.slice_copy(x0, dim=0, start=0, end=6, step=1) + x1 = self.linear(x1) + + x2 = torch.slice_copy(x0, dim=0, start=6, end=12, step=1) + x2 = self.linear(x2) + + x3 = torch.slice_copy(x0, dim=0, start=12, end=18, step=1) + x3 = self.linear(x3) + + return (x1, x2, x3) + + model = ReorderOpsBranch() + X = torch.randint(0, 255, [1, 18, 3], dtype=torch.uint8) + graph_module = export_to_edge(model, (X,)).exported_program().graph_module + graph_module = PostponeDequantizeOpBelowUseChainPass()( + graph_module + ).graph_module + graph_module.graph.eliminate_dead_code() + + # Assert that the dequant node was split into 3, one per branch + self.assertEqual( + count_node( + graph_module, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + ), + 3, + ) + + # Assert that the dequant node is no longer the predecessor of the squeeze node + self.assertTrue( + nodes_not_connected_in_gm( + graph_module, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, +
exir_ops.edge.aten.squeeze_copy.dims, + ), + ) + # Assert that dequant node is not predecessor of slice (it should've been moved below slice) + self.assertTrue( + nodes_not_connected_in_gm( + graph_module, + exir_ops.edge.cadence.dequantize_per_tensor.default, + exir_ops.edge.aten.slice_copy.Tensor, + ), + ) + + # 4d -> permute -> 4d -> view -> 3d + def test_permute3_view4_chains(self): + class PermuteViewChain(torch.nn.Module): + def forward(self, x): + # x is [3, 1, 768] + x = x.view((3, 12, 64)) + # x is [3, 12, 64] + x = x.permute([1, 0, 2]) + # x is [12, 3, 64] + x = x.view((1, 12, 3, 64)) + # x is [1, 12, 3, 64] + x = x.permute([0, 1, 3, 2]) + # x is [1, 12, 64, 3] + return x + + model = PermuteViewChain() + X = torch.randn(3, 1, 768) + graph_module = export_to_edge(model, (X,)).exported_program().graph_module + + # Performing transform + graph_module = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()( + graph_module + ).graph_module + graph_module.graph.eliminate_dead_code() + + # Assert the order becomes view, view, permute, permute + nodes = get_compute_nodes_in_gm(graph_module) + self.assertEqual(len(nodes), 4) + self.assertTrue(nodes[0] == exir_ops.edge.aten.view_copy) + self.assertTrue(nodes[1] == exir_ops.edge.aten.view_copy) + self.assertTrue(nodes[2] == exir_ops.edge.aten.permute_copy) + self.assertTrue(nodes[3] == exir_ops.edge.aten.permute_copy) + + # 3d -> permute -> 3d -> view -> 4d + def test_permute4_view3_chains(self): + class PermuteViewChain(torch.nn.Module): + def forward(self, x): + # x is [3, 1, 768] + x = x.view((1, 3, 12, 64)) + # x is [1, 3, 12, 64] + x = x.permute([3, 1, 0, 2]) + # x is [64, 3, 1, 12] + x = x.view((64, 3, 12)) + # x is [64, 3, 12] + x = x.permute([2, 1, 0]) + # x is [12, 3, 64] + return x + + model = PermuteViewChain() + X = torch.randn(3, 1, 768) + graph_module = export_to_edge(model, (X,)).exported_program().graph_module + + # Performing transform + graph_module = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()( + graph_module + ).graph_module + graph_module.graph.eliminate_dead_code() + + # Assert the order becomes view, view, permute, permute + nodes = get_compute_nodes_in_gm(graph_module) + self.assertEqual(len(nodes), 4) + self.assertTrue(nodes[0] == exir_ops.edge.aten.view_copy) + self.assertTrue(nodes[1] == exir_ops.edge.aten.view_copy) + self.assertTrue(nodes[2] == exir_ops.edge.aten.permute_copy) + self.assertTrue(nodes[3] == exir_ops.edge.aten.permute_copy) + + # Negative test case where the transform should not happen. + # permute->4d->view->3d where the view not only removes the dimension whose + # size is 1 (this is ok), but also changes the size of the dimensions (not ok). 
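+ # Concretely, for the shapes used in the test below (a sketch): after permute([3, 1, 0, 2]) + # the tensor is [64, 3, 1, 12], and view((64, 6, 6)) regroups those 3 x 1 x 12 elements + # into a 6 x 6 block, so the later permute can no longer be commuted across the view and + # the pass should leave the graph unchanged.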
+ def test_permute_view_chains_neg(self): + class PermuteViewChain(torch.nn.Module): + def forward(self, x): + # x is [3, 1, 768] + x = x.view((1, 3, 12, 64)) + # x is [1, 3, 12, 64] + x = x.permute([3, 1, 0, 2]) + # x is [64, 3, 1, 12] + x = x.view((64, 6, 6)) + # x is [64, 6, 6] + x = x.permute([2, 1, 0]) + # x is [6, 6, 64] + return x + + model = PermuteViewChain() + X = torch.randn(3, 1, 768) + graph_module = export_to_edge(model, (X,)).exported_program().graph_module + + # Performing transform (nothing should happen) + graph_module = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()( + graph_module + ).graph_module + graph_module.graph.eliminate_dead_code() + + # Assert the order is still view, permute, view, permute + nodes = get_compute_nodes_in_gm(graph_module) + self.assertEqual(len(nodes), 4) + self.assertTrue(nodes[0] == exir_ops.edge.aten.view_copy) + self.assertTrue(nodes[1] == exir_ops.edge.aten.permute_copy) + self.assertTrue(nodes[2] == exir_ops.edge.aten.view_copy) + self.assertTrue(nodes[3] == exir_ops.edge.aten.permute_copy) diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py new file mode 100644 index 00000000000..fb6f134fd95 --- /dev/null +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -0,0 +1,1683 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +import unittest +from typing import Any, Callable, cast, List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn.functional as F +from executorch.backends.cadence.aot import compiler +from executorch.backends.cadence.aot.compiler import export_to_edge, quantize_pt2 +from executorch.backends.cadence.aot.graph_builder import single_op_builder +from executorch.backends.cadence.aot.pass_utils import count_node +from executorch.backends.cadence.aot.replace_ops import ( + ForceChannelLastForConvPass, + MakeSliceAndCatDimOutermostPass, + ReplaceAddMMWithLinearPass, + ReplaceAtenConvolutionWithJarvisConvolutionPass, + ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass, + ReplaceConstantPadNdWithSlicePass, + ReplaceConvolutionOptionalArgsWithConcreteArgsPass, + ReplaceConvWithIm2RowAndLinear, + ReplaceFunctionallyEquivalentOpTargets, + ReplaceIm2RowWithViewPass, + ReplaceLinearWithFullyConnectedOpPass, + ReplaceMMWithAddMMPass, + ReplaceNopTransposeOrPermuteWithViewPass, + ReplacePadWithCatPass, + ReplacePermuteWithTransposePass, + ReplaceRepeatWithCatPass, + ReplaceScalarTensorWithFullPass, + ReplaceScalarWithTensorArgPass, + ReplaceSelectWithViewOpPass, + ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass, + ReplaceSqueezeAndUnsqueezeWithViewPass, + ReplaceTCopyWithTransposePass, + ReplaceTransposedConvWithLinearPass, + ReplaceTrivialConvWithLinear, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass +from executorch.exir.passes import dead_code_elimination_pass + +from parameterized.parameterized import parameterized +from torch._ops import OpOverload +from torch.fx.passes.infra.pass_base import PassResult + + +class TestReplaceOpsPasses(unittest.TestCase): + def assertTargetCountEqual( + self, + graph_module: torch.fx.GraphModule, + target: Union[Callable[..., Any], str], + expected_count: int, + ): + """Helper function to check the number of nodes with a given target.""" + actual_count = count_node(graph_module, target) + self.assertEqual( + actual_count, + expected_count, + f"{target} count mismatch for graph 
{graph_module}", + ) + + def assertTargetCountsEqual( + self, + graph_module: torch.fx.GraphModule, + targets_and_counts: List[Tuple[Union[Callable[..., Any], str], int]], + ): + """Helper function to check the number of nodes of all types for a given target.""" + for target, expected_count in targets_and_counts: + self.assertTargetCountEqual(graph_module, target, expected_count) + + @parameterized.expand( + [ + [(3, 5), (0, 0)], + [ + (20, 1, 80), + (0, 0), + ], + ] + ) + @torch.no_grad() + def test_replace_constant_pad_nd_with_slice( + self, shape: Tuple[int], padding: Tuple[int] + ): + # F.pad is converted to aten::constant_pad_nd after functionalization & decomposition. + class Padding(torch.nn.Module): + def __init__(self): + super().__init__() + self.padding = padding + + def forward(self, x: torch.Tensor): + return F.pad(x, self.padding) + + model = Padding() + x = torch.randn(shape) + graph_module = export_to_edge(model, (x,)).exported_program().graph_module + + p = ReplaceConstantPadNdWithSlicePass() + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.slice.Tensor), + 1, + ) + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.constant_pad_nd.default), + 0, + ) + + @parameterized.expand( + [ + [(7, 5, 6), 1.23], + [(7, 5), 2], + ] + ) + @torch.no_grad() + def test_add_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float): + class Add(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.add.Scalar(x, other) + + model = Add() + x = torch.randn(shape) + graph_module = export_to_edge(model, (x,)).exported_program().graph_module + + p = ReplaceScalarWithTensorArgPass() + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor), + 1, + ) + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.add.Scalar), + 0, + ) + + @parameterized.expand( + [ + [(7, 5, 6), 1.23], + [(7, 5), 2], + [(10), 42949], + ] + ) + @torch.no_grad() + def test_sub_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float): + class Sub(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.sub.Scalar(x, other) + + model = Sub() + x = torch.randn(shape) + graph_module = export_to_edge(model, (x,)).exported_program().graph_module + + p = ReplaceScalarWithTensorArgPass() + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.sub.Tensor), + 1, + ) + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.sub.Scalar), + 0, + ) + + @parameterized.expand( + [ + [(7, 5, 6), 1.23], + [(7, 5), 2], + [(513), 3], + ] + ) + @torch.no_grad() + def test_mul_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float): + class Mul(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.mul.Scalar(x, other) + + model = Mul() + x = torch.randn(shape) + graph_module = export_to_edge(model, (x,)).exported_program().graph_module + + p = ReplaceScalarWithTensorArgPass() + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor), + 1, + ) + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.mul.Scalar), + 0, + ) + + @parameterized.expand( + [ + [(7, 5, 6), 1.23], + [(7, 5), 2], + ] + ) + @torch.no_grad() + def 
test_div_replace_scalar_with_tensor_arg( + self, + shape: Tuple[int], + other: float, + ): + class Div(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.div.Scalar(x, other) + + model = Div() + x = torch.randn(shape) + graph_module = export_to_edge(model, (x,)).exported_program().graph_module + + p = ReplaceScalarWithTensorArgPass() + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.div.Tensor), + 1, + ) + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.div.Scalar), + 0, + ) + + @parameterized.expand( + [ + [(2, 3, 5, 6)], + [(7, 6, 5)], + [(4, 4)], + [(316)], + ] + ) + @torch.no_grad() + def test_replace_functionally_equivalent_op_targets_relu(self, shape: Tuple[int]): + model = torch.nn.ReLU() + x = torch.randn(shape) + graph_module = export_to_edge(model, (x,)).exported_program().graph_module + p = ReplaceFunctionallyEquivalentOpTargets() + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.relu.default), + 1, + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.relu_.default), + 0, + ) + + @parameterized.expand( + [ + # split the only dimension + [(50,), i, 0] + for i in range(2, 7) + ] + + [ + # split the leading dim + [(10, 2, 3), i, 0] + for i in range(2, 7) + ] + + [ + # split the trailing dim + [(3, 3, 6), i, 2] + for i in range(2, 6) + ] + + [ + # split the dim in the middle + [(3, 5, 14, 2, 3), i, 2] + for i in range(2, 7) + ] + ) + @torch.no_grad() + def test_replace_functionally_equivalent_op_targets_unsafe_split( + self, shape: Tuple[int], split_size: int, dim: int + ): + class TensorSplitWithSizes(torch.nn.Module): + def __init__(self, split_size: int, dim: int, op: OpOverload): + super().__init__() + self.split_size = split_size + self.dim = dim + self.op = op + + def forward(self, x: torch.Tensor): + return self.op(x, self.split_size, self.dim) + + x = torch.randn(shape) + model = TensorSplitWithSizes(split_size, dim, torch.unsafe_split) + graph_module = export_to_edge(model, (x,)).exported_program().graph_module + p = ReplaceFunctionallyEquivalentOpTargets() + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + self.assertEqual( + count_node( + graph_after_passes, exir_ops.edge.aten.split_with_sizes_copy.default + ), + 1, + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.unsafe_split.Tensor), + 0, + ) + + @parameterized.expand( + [ + [(16, 32)], + [(1, 240)], + [(4, 16)], + ] + ) + @torch.no_grad() + def test_replace_t_copy_with_transpose(self, shape: Tuple[int]): + class TCopy(torch.nn.Module): + def forward(self, x: torch.Tensor): + return exir_ops.edge.aten.t_copy(x) + + w = torch.randn(shape) + inputs = (w,) + p1 = ReplaceTCopyWithTransposePass() + p2 = ReplacePermuteWithTransposePass() + model = TCopy() + graph_module = export_to_edge(model, inputs).exported_program().graph_module + graph_after_passes = cast( + PassResult, p2(cast(PassResult, p1(graph_module)).graph_module) + ).graph_module + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), + 1, + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.t_copy), + 0, + ) + + @parameterized.expand( + [ + [(1, 8, 33), 8, 16, 3], + [(1, 8, 33), 8, 16, 5, 2], + [(1, 8, 33), 8, 16, 3, 2, 4, 3, False, False, False], + # channel last + [(1, 33, 8), 8, 16, 3, 1, 0, 1, False, 
False, True], + [(1, 33, 8), 8, 16, 5, 2, 0, 1, False, True, True], + ] + ) + @torch.no_grad() + def test_replace_transposed_conv_with_linear( + self, + shape: Tuple[int], + in_channels: int, + out_channels: int, + kernel: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + depthwise: bool = False, + bias: bool = True, + channel_last: bool = False, + ): + class TConv(torch.nn.Module): + def __init__(self): + super().__init__() + self.tconv1d = torch.nn.ConvTranspose1d( + in_channels, + out_channels, + kernel, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels if depthwise else 1, + bias=bias, + ) + + def forward(self, x: torch.Tensor): + if channel_last: + x = x.permute([0, 2, 1]) + x = self.tconv1d(x) + if channel_last: + x = x.permute([0, 2, 1]) + return x + + x = torch.randn(shape) + model = TConv() + graph_module = export_to_edge(model, (x,)).exported_program().graph_module + p1 = ReplaceAtenConvolutionWithJarvisConvolutionPass() + p2 = ReplaceTransposedConvWithLinearPass() + graph_after_passes = cast( + PassResult, p2(cast(PassResult, p1(graph_module)).graph_module) + ).graph_module + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.linear.default), + 1, + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.convolution.default), + 0, + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), + 0, + ) + + @parameterized.expand( + [ + [(1, 8, 33), 8, 16, 3, 2, 4, 3, False, False, False], + # # depthwise + [(1, 8, 33), 8, 16, 3, 1, 0, 1, True, False, False], + [(1, 8, 33), 8, 16, 3, 2, 4, 3, True, False, False], + # channel last (uses a permute op before calling conv1d) + [(1, 33, 8), 8, 16, 3, 1, 0, 1, False, False, True], + [(1, 33, 8), 8, 16, 3, 2, 4, 3, True, False, True], + ] + ) + @torch.no_grad() + def test_replace_convolution_optional_args_with_concrete_args( + self, + shape: Tuple[int], + in_channels: int, + out_channels: int, + kernel: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + depthwise: bool = False, + bias: bool = True, + channel_last: bool = False, + ): + class Conv(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1d = torch.nn.Conv1d( + in_channels, + out_channels, + kernel, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels if depthwise else 1, + bias=bias, + ) + + def forward(self, x: torch.Tensor): + if channel_last: + x = x.permute([0, 2, 1]) + x = self.conv1d(x) + if channel_last: + x = x.permute([0, 2, 1]) + return x + + x = torch.randn(shape) + model = Conv() + + graph_module = export_to_edge(model, (x,)).exported_program().graph_module + + p = ReplaceConvolutionOptionalArgsWithConcreteArgsPass() + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.full.default), + 1, + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.convolution.default), + 1, + ) + + @parameterized.expand( + [ + [(1, 2, 3), (1, 1)], + [ + (20, 1, 80), + (1, 4), + ], + ] + ) + @torch.no_grad() + def test_replace_pad_with_cat(self, shape: Tuple[int], padding: Tuple[int]): + # F.pad is converted to aten::constant_pad_nd after functionalization & decomposition. 
+ class Padding(torch.nn.Module): + def __init__(self): + super().__init__() + self.padding = padding + + def forward(self, x: torch.Tensor): + return F.pad(x, self.padding) + + model = Padding() + x = torch.randn(shape) + graph_module = export_to_edge(model, (x,)).exported_program().graph_module + + p = ReplacePadWithCatPass() + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.cat.default), + 1, + ) + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.pad.default), + 0, + ) + + @torch.no_grad() + def test_replace_repeat_with_cat(self): + class Repeat(torch.nn.Module): + def forward(self, x): + x1 = torch.add(x, 2.4, 3.1) + return torch.ops.aten.repeat(x1, [1, 2]) + + x = torch.ones(3, 5) + graph_module = export_to_edge(Repeat(), (x,)).exported_program().graph_module + + p = ReplaceRepeatWithCatPass() + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.cat.default), + 1, + ) + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.repeat.default), + 0, + ) + + @parameterized.expand( + [ + # x, mask + [(1,)], + [(3, 4)], + [(7, 8, 3)], + [(3, 3, 2, 4)], + [(36, 1, 2, 80), (1)], + # tests where mask will be broadcasted + [(36, 1, 2, 80), (1, 1, 2, 1)], + [(36, 2, 8, 4), (36, 1, 1, 4)], + [(36, 2, 8, 4), (2, 1, 4)], + ] + ) + @torch.no_grad() + def test_replace_masked_scalar_tensor_with_full( + self, + shape: Tuple[int], + mask_shape: Union[Tuple[int, ...], None] = None, + ): + class MaskedFill(torch.nn.Module): + def __init__(self, value: float): + super().__init__() + self.value = value + + def forward(self, x: torch.Tensor, mask: torch.Tensor): + return torch.masked_fill(x, mask, self.value) + + x = torch.randn(shape) + mask = torch.randn(mask_shape if mask_shape else shape) > 0 + value = 0.5 * torch.mean(x).item() + model = MaskedFill(value) + graph_module = export_to_edge(model, (x, mask)).exported_program().graph_module + + p = ReplaceScalarTensorWithFullPass() + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.full.default), + 1, + ) + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.where.self), + 1, + ) + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.masked_fill), + 0, + ) + + @parameterized.expand( + [ + [(1), 1.5], + [(1), 0.0], + ] + ) + @torch.no_grad() + def test_replace_scalar_tensor_with_full(self, shape: Tuple[int], value: float): + class ScalarTensor(torch.nn.Module): + def __init__(self, shape: Tuple[int], value: float): + super().__init__() + self.shape = shape + self.value = value + + def forward(self, x: torch.Tensor): + return torch.scalar_tensor(value) + + model = ScalarTensor(shape, value) + x = torch.randn(shape) + graph_module = export_to_edge(model, (x,)).exported_program().graph_module + + p = ReplaceScalarTensorWithFullPass() + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.full.default), + 1, + ) + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.scalar_tensor.default), + 0, + ) + + @torch.no_grad() + def test_replace_linear_with_fully_connected(self): + shape, in_features, out_features, bias = (1, 14), 14, 128, False + model = torch.nn.Linear(in_features, out_features, bias=bias) + x = 
torch.randn(shape) + + graph_module = export_to_edge(model, (x,)).exported_program().graph_module + permute_to_trans_pass = ReplacePermuteWithTransposePass() + mm_to_addmm_pass = ReplaceMMWithAddMMPass() + add_to_linear_pass = ReplaceAddMMWithLinearPass() + linear_to_fullyconnected_pass = ReplaceLinearWithFullyConnectedOpPass() + graph_after_passes = linear_to_fullyconnected_pass( + add_to_linear_pass( + mm_to_addmm_pass( + permute_to_trans_pass(graph_module).graph_module + ).graph_module + ).graph_module + ).graph_module + self.assertIsNotNone(graph_after_passes) + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.full.default), + 1, + ) + + self.assertEqual( + count_node( + graph_after_passes, exir_ops.edge.cadence.fully_connected.default + ), + 1, + ) + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.linear), + 0, + ) + + @parameterized.expand( + [ + [(4, 16, 256), 256, 512, True], + [(7, 17, 12), 12, 34, False], + ] + ) + @torch.no_grad() + def test_replace_addmm_with_linear( + self, shape: Tuple[int], in_features: int, out_features: int, bias: bool + ): + class AddMM(torch.nn.Module): + def __init__(self, alpha: float = 1, beta: float = 1): + super().__init__() + self.alpha = alpha + self.beta = beta + + def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): + return torch.addmm( + x, y, z.transpose(1, 0), alpha=self.alpha, beta=self.beta + ) + + # alpha, beta must be 1 to enable ReplaceAddMMWithLinearPass + # get_attr will always turn into placeholders and mutable outputs in PT2 + M, K, N, alpha, beta = 14, 48, 24, 1.0, 1.0 + x = torch.randn(N) + y = torch.randn(M, K) + z = torch.randn(N, K) + + # test addmm + model = AddMM(alpha=alpha, beta=beta) + graph_module = export_to_edge(model, (x, y, z)).exported_program().graph_module + + tp = ReplacePermuteWithTransposePass() + ap = ReplaceAddMMWithLinearPass() + graph_after_passes = cast( + PassResult, ap(cast(PassResult, tp(graph_module)).graph_module) + ).graph_module + self.assertIsNotNone(graph_after_passes) + + self.assertEqual( + count_node(graph_module, exir_ops.edge.aten.addmm.default), + 1, + ) + + # Assert that all the aten.addmm nodes are removed. + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.linear.default), + 1, + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.addmm.default), + 0, + ) + + @torch.no_grad() + def test_replace_mm_with_addmm(self): + # The mm ops will be converted to addmm ops by Jarvis + class MM(torch.nn.Module): + def __init__(self, K, N): + super().__init__() + self.K = K + self.N = N + + def forward(self, y: torch.Tensor, z: torch.Tensor): + return torch.ops.aten.mm(y, z) + + M, K, N = 14, 48, 24 + y = torch.randn(M, K) + z = torch.randn(K, N) + + # test mm + model = MM(K, N) + graph_module = export_to_edge(model, (y, z)).exported_program().graph_module + + # First, replace the aten.mm with an aten.addmm op + p = ReplaceMMWithAddMMPass() + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + self.assertIsNotNone(graph_after_passes) + + # Assert that all the aten.mm nodes are removed.
+ self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.addmm.default), + 1, + ) + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.mm), + 0, + ) + + @parameterized.expand( + [ + # shape + [(5, 1, 6, 7)], + [(1)], + [(4, 3, 2)], + # shape, dim to squeeze + [(2, 1), 0], + [(2, 7, 1, 3), 1], + [(2, 1, 3), 2], + ] + ) + @torch.no_grad() + def test_replace_squeeze_with_view(self, shape: Tuple[int], dim=None): + # The squeeze ops will be converted to view ops by Jarvis + class Squeeze(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor): + if self.dim is None: + return torch.squeeze(x) + return torch.squeeze(x, self.dim) + + model = Squeeze(dim) + x = torch.randn(shape) + graph_module = export_to_edge(model, (x,)).exported_program().graph_module + + # First, replace the aten.squeeze_copy with an aten.view_copy op + p = ReplaceSqueezeAndUnsqueezeWithViewPass() + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + self.assertIsNotNone(graph_after_passes) + + # Assert that all the aten.squeeze_copy nodes are removed. + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), + 1, + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.aten.squeeze_copy), + 0, + ) + + @parameterized.expand( + [ + # shape, dim to unsqueeze + [(5, 6, 7), 0], + [(5, 6, 7), -1], + [(5, 6, 7), 3], + [(5, 6, 7), 2], + ] + ) + @torch.no_grad() + def test_replace_unsqueeze_with_view(self, shape: Tuple[int], dim: int): + class Unsqueeze(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor): + return torch.unsqueeze(x, self.dim) + + # Test that the pass works for all dims. + model = Unsqueeze(dim) + x = torch.randn(5, 6, 7) + graph_module = export_to_edge(model, (x,)).exported_program().graph_module + + # First, replace the aten.unsqueeze_copy with an aten.view_copy op + p = ReplaceSqueezeAndUnsqueezeWithViewPass() + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + self.assertIsNotNone(graph_after_passes) + + # Assert that all the aten.unsqueeze_copy nodes are removed. + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), + 1, + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.aten.unsqueeze_copy), + 0, + ) + + @torch.no_grad() + def test_replace_single_element_tensor_arguments_from_full_op_with_scalar( + self, + in_features: int = 16, + out_features: int = 16, + ): + # Tensors - these will be inputs to the graph. + x = torch.randn([1, in_features]) + + inputs = (x,) + model = torch.nn.Linear(in_features=in_features, out_features=out_features) + quantized_model = quantize_pt2(model, inputs) + + exported_program = export_to_edge(quantized_model, inputs).exported_program() + + # By default, the quantized linear op should have constant scalar attributes. + self.assertTargetCountsEqual( + exported_program.graph_module, + [ + # One quantized linear op. + (exir_ops.edge.cadence.quantized_linear.default, 1), + # No per tensor quantized linear ops. + (exir_ops.edge.cadence.quantized_linear.per_tensor, 0), + # Three aten.full ops. + (exir_ops.edge.aten.full.default, 3), + ], + ) + + # Apply replacement pass.
+ p = ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass() + graph_after_passes = p(exported_program.graph_module) + self.assertIsNotNone(graph_after_passes) + gm = dead_code_elimination_pass(graph_after_passes.graph_module).graph_module + + # By default, the quantized linear op should have constant scalar attributes. + self.assertTargetCountsEqual( + gm, + [ + # No default quantized linear op. + (exir_ops.edge.cadence.quantized_linear.default, 0), + # The default quantized linear op will be replaced with quantized_linear.per_tensor. + (exir_ops.edge.cadence.quantized_linear.per_tensor, 1), + # No aten.full ops. + (exir_ops.edge.aten.full.default, 0), + ], + ) + + @torch.no_grad() + def test_replace_single_element_tensor_arguments_from_full_op_with_scalar_tuple_args( + self, + in_features: int = 16, + out_features: int = 16, + ): + # Tensors - these will be inputs to graph. + x = torch.randn([1, in_features]) + + inputs = (x,) + model = torch.nn.Linear(in_features=in_features, out_features=out_features) + quantized_model = quantize_pt2(model, inputs) + + exported_program = export_to_edge(quantized_model, inputs).exported_program() + + # By default, the quantized linear op should have constant scalar attributes. + self.assertTargetCountsEqual( + exported_program.graph_module, + [ + # One quantized linear op. + (exir_ops.edge.cadence.quantized_linear.default, 1), + # No per tensor quantized linear ops. + (exir_ops.edge.cadence.quantized_linear.per_tensor, 0), + # Three aten.full ops. + (exir_ops.edge.aten.full.default, 3), + ], + ) + + for node in exported_program.graph_module.graph.nodes: + # Replace the `shape` argument for aten.full op with a tuple. + if node.target == exir_ops.edge.aten.full.default: + node.args = (tuple(node.args[0]), node.args[1]) + + # Apply replacement pass. + p = ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass() + graph_after_passes = p(exported_program.graph_module) + self.assertIsNotNone(graph_after_passes) + gm = dead_code_elimination_pass(graph_after_passes.graph_module).graph_module + + # By default, the quantized linear op should have constant scalar attributes. + self.assertTargetCountsEqual( + gm, + [ + # No default quantized linear op. + (exir_ops.edge.cadence.quantized_linear.default, 0), + # The default quantized linear op will be replaced with quantized_linear.per_tensor. + (exir_ops.edge.cadence.quantized_linear.per_tensor, 1), + # No aten.full ops. 
+ (exir_ops.edge.aten.full.default, 0), + ], + ) + + @torch.no_grad() + def test_replace_conv1d_with_linear(self): + class Conv(torch.nn.Module): + def __init__(self, in_features: int, out_features: int, kernel_size: int): + super().__init__() + self.conv1d = torch.nn.Conv1d(in_features, out_features, kernel_size) + + def forward(self, x): + return self.conv1d(x) + + model_conv1d = Conv(96, 192, 7) + x = torch.randn(1, 96, 7) + graph_module = ( + export_to_edge(model_conv1d, (x,)).exported_program().graph_module + ) + + # First, replace the aten convolution with a cadence.convolution op + p1 = ReplaceAtenConvolutionWithJarvisConvolutionPass() + temp_graph = p1(graph_module).graph_module + self.assertIsNotNone(temp_graph) + + p2 = ReplaceTrivialConvWithLinear() + graph_after_passes = p2(temp_graph).graph_module + + # Assert that conv1d is trivially converted to linear + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), 0 + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.cadence.im2row.default), 0 + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.linear.default) + + count_node( + graph_after_passes, exir_ops.edge.cadence.fully_connected.default + ), + 1, + ) + + @torch.no_grad() + def test_replace_conv2d_with_linear(self): + class Conv(torch.nn.Module): + def __init__(self, in_features: int, out_features: int, kernel_size: int): + super().__init__() + self.conv2d = torch.nn.Conv2d(in_features, out_features, kernel_size) + + def forward(self, x): + return self.conv2d(x) + + model_conv2d = Conv(96, 192, 7) + x = torch.randn(1, 96, 7, 7) + graph_module = ( + export_to_edge(model_conv2d, (x,)).exported_program().graph_module + ) + + # First, replace the aten convolution with a cadence.convolution op + p1 = ReplaceAtenConvolutionWithJarvisConvolutionPass() + temp_graph = p1(graph_module).graph_module + self.assertIsNotNone(temp_graph) + + p2 = ReplaceTrivialConvWithLinear() + graph_after_passes = p2(temp_graph).graph_module + + # Assert that conv2d is trivially converted to linear + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), 0 + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.cadence.im2row.default), 0 + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.linear.default) + + count_node( + graph_after_passes, exir_ops.edge.cadence.fully_connected.default + ), + 1, + ) + + @torch.no_grad() + def test_replace_conv2d_with_im2row_and_linear(self): + class Conv(torch.nn.Module): + def __init__(self, in_features: int, out_features: int, kernel_size: int): + super().__init__() + self.conv2d = torch.nn.Conv2d(in_features, out_features, kernel_size) + + def forward(self, x): + return self.conv2d(x) + + model_conv2d = Conv(96, 192, 7) + x = torch.randn(1, 96, 47, 37) + graph_module = ( + compiler.export_to_cadence(model_conv2d, (x,)) + .exported_program() + .graph_module + ) + + p = ReplaceConvWithIm2RowAndLinear() + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + + # Assert that the convolution is converted to im2row + linear + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), 0 + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.cadence.im2row.default), 1 + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.linear.default), 1 + ) + + @parameterized.expand( + [ + [(3, 1, 5), 1, 0], + [(3, 4, 1), 2, -1], + ] + ) + 
@torch.no_grad() + def test_replace_select_with_view(self, shape: Tuple[int], dim: int, index: int): + class Select(torch.nn.Module): + def forward(self, x): + return x.select(dim, index) + + x = torch.randn(shape) + graph_module = export_to_edge(Select(), (x,)).exported_program().graph_module + + p = ReplaceSelectWithViewOpPass() + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + + # Assert that select op was replaced with view op + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0 + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 1 + ) + + @parameterized.expand( + [ + [(2, 1, 3, 1), 1, 3, torch.float32], + [(2, 1, 5), 1, 0, torch.int64], + [(3, 1, 5), 0, 1, torch.int64], + ] + ) + @torch.no_grad() + def test_replace_nop_transpose_with_view( + self, + shape: Tuple[int], + dim0: int, + dim1: int, + dtype: torch.dtype = torch.float32, + ): + class Transpose(torch.nn.Module): + def forward(self, x): + return x.transpose(dim0, dim1) + + _max_value = 127 + x = (torch.rand(shape) * _max_value).to(dtype=dtype) + graph_module = export_to_edge(Transpose(), (x,)).exported_program().graph_module + + p = ReplaceNopTransposeOrPermuteWithViewPass() + + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + + # Assert that transpose op was removed, and a view op was placed instead + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 0 + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 1 + ) + + @parameterized.expand( + [ + # permutations that can be replaced by view + [(3, 1, 3, 1, 4), (0, 2, 4, 1, 3), torch.float32], + [(1, 3, 4), (1, 2, 0), torch.float32], + ] + ) + @torch.no_grad() + def test_replace_nop_permute_with_view(self, input_shape, dims, dtype): + class Permute(torch.nn.Module): + def forward(self, x): + return torch.permute(x, dims) + + x = torch.randn(input_shape).to(dtype=dtype) + graph_module = export_to_edge(Permute(), (x,)).exported_program().graph_module + + p = ReplaceNopTransposeOrPermuteWithViewPass() + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + + # Assert that permute op was removed, and a view op was placed instead + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 0 + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 1 + ) + + @parameterized.expand( + [ + # permutations replaced by transpose + [(3, 4), [1, 0], torch.float32], + [(3, 4, 6), (0, 2, 1), torch.float32], + ] + ) + @torch.no_grad() + def test_replace_permute_with_transpose(self, input_shape, dims, dtype): + class Permute(torch.nn.Module): + def forward(self, x): + return torch.permute(x, dims) + + x = torch.randn(input_shape).to(dtype=dtype) + graph_module = export_to_edge(Permute(), (x,)).exported_program().graph_module + + p = ReplacePermuteWithTransposePass() + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + + # Assert that permute op was replaced by a transpose op + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 0 + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 1 + ) + + @parameterized.expand( + [ + # permutations replaced by transpose + [(3, 4), [0, 1], torch.float32], + ] + ) + @torch.no_grad() + def test_replace_permute_with_transpose_nop(self, input_shape, dims, dtype): + class 
Permute(torch.nn.Module): + def forward(self, x): + return torch.permute(x, dims) + + x = torch.randn(input_shape).to(dtype=dtype) + graph_module = export_to_edge(Permute(), (x,)).exported_program().graph_module + + p = ReplacePermuteWithTransposePass() + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + + # Assert that permute op was replaced by a transpose op + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 0 + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 0 + ) + + def test_replace_aten_linalg_vector_norm_with_cadence_linalg_vector_norm(self): + class LinalgVectorNorm(torch.nn.Module): + def forward(self, x: torch.Tensor): + return torch.linalg.vector_norm(x) + + x = torch.randn(32) + + graph_module = ( + export_to_edge(LinalgVectorNorm(), (x,)).exported_program().graph_module + ) + + p = ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass() + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + + # Assert that aten.linalg_vector_norm op was replaced by a + # cadence.linalg_vector_norm op + self.assertEqual( + count_node( + graph_after_passes, + exir_ops.edge.aten.linalg_vector_norm.default, + ), + 0, + ) + self.assertEqual( + count_node( + graph_after_passes, exir_ops.edge.cadence.linalg_vector_norm.default + ), + 1, + ) + + +class TestReplaceIm2rowWithViewPass(unittest.TestCase): + def test_no_replacement_for_conv(self): + # Create a graph with a single im2row node. + x = torch.randn(1, 3, 224, 224) + pad_value = torch.randn(1) + channels_last = False + gm = single_op_builder( + placeholders=(x, pad_value), + op=exir_ops.edge.cadence.im2row.default, + args=(x, (2, 2), (1, 1), (0, 0), (1, 1), pad_value, channels_last), + ) + # Check if graph module is valid by running exportpass on it. + gm = ExportPass().call(gm).graph_module + self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1) + self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0) + + # Apply replacement pass. + p = ReplaceIm2RowWithViewPass() + gm_after_replacement = p.call(gm).graph_module + # Check that no replacement was made. + self.assertEqual( + count_node(gm_after_replacement, exir_ops.edge.cadence.im2row.default), 1 + ) + self.assertEqual( + count_node(gm_after_replacement, exir_ops.edge.aten.view_copy.default), 0 + ) + + def test_no_replace_for_dilation(self): + # Create a graph with a single im2row node. + x = torch.randn(1, 3, 5, 7) + pad_value = torch.randn(1) + channels_last = False + gm = single_op_builder( + placeholders=(x, pad_value), + op=exir_ops.edge.cadence.im2row.default, + args=(x, (3, 4), (2, 2), (0, 0), (1, 1), pad_value, channels_last), + ) + # Check if graph module is valid by running exportpass on it. + gm = ExportPass().call(gm).graph_module + self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1) + self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0) + + # Apply replacement pass. + p = ReplaceIm2RowWithViewPass() + gm_after_replacement = p.call(gm).graph_module + self.assertEqual( + count_node(gm_after_replacement, exir_ops.edge.cadence.im2row.default), 1 + ) + self.assertEqual( + count_node(gm_after_replacement, exir_ops.edge.aten.view_copy.default), 0 + ) + + def test_replace_linear_like_conv(self): + # Create a graph with a single im2row node. 
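+ # Sketch of why a view suffices here: with the kernel spanning the full in_h x in_w input, + # no padding, unit stride and unit dilation, im2row has exactly one window position, so its + # output is just a flattened copy of the input per image.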
+ in_h, in_w = 13, 15 + x = torch.randn(1, 3, in_h, in_w) + pad_value = torch.randn(1) + channels_last = False + gm = single_op_builder( + placeholders=(x, pad_value), + op=exir_ops.edge.cadence.im2row.default, + args=(x, (in_h, in_w), (1, 1), (0, 0), (1, 1), pad_value, channels_last), + ) + # Check if graph module is valid by running exportpass on it. + gm = ExportPass().call(gm).graph_module + self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1) + self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0) + + # Apply replacement pass. + p = ReplaceIm2RowWithViewPass() + gm_after_replacement = p.call(gm).graph_module + # In this test, the kernel width/height is the same as the input width/height. + self.assertEqual( + count_node(gm_after_replacement, exir_ops.edge.cadence.im2row.default), 0 + ) + self.assertEqual( + count_node(gm_after_replacement, exir_ops.edge.aten.view_copy.default), 1 + ) + + +class TestForceChannelLastForConvPass(unittest.TestCase): + def create_conv1d_graphmodule( + self, channels_last: Optional[bool] = None + ) -> torch.fx.GraphModule: + """Helper to create a convolution node. + + convolution( + Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding," + int[] dilation, int groups, bool channel_last=False) -> (Tensor Y)" + """ + if channels_last: + x = torch.randn(1, 224, 3) + w = torch.randn(16, 16, 3) + else: + x = torch.randn(1, 3, 224) + w = torch.randn(16, 3, 16) + b = torch.randn(16) + args = (x, w, b, (2, 2), (1, 1), (0, 0), 1) + if channels_last is not None: + args = args + (channels_last,) + return single_op_builder( + placeholders=(x, w, b), + op=exir_ops.edge.cadence.convolution.default, + args=args, + ) + + def test_conv1d_default_channel_last(self): + # Create a graph with a single convolution node. + # Check if graph module is valid by running exportpass on it. + gm = self.create_conv1d_graphmodule() + gm = ExportPass().call(gm).graph_module + self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1) + self.assertEqual(count_node(gm, exir_ops.edge.aten.transpose_copy.int), 0) + + # Apply replacement pass. + p = ForceChannelLastForConvPass() + gm_after_replacement = p.call(gm).graph_module + # Check that no replacement was made. + self.assertEqual( + count_node(gm_after_replacement, exir_ops.edge.cadence.convolution.default), + 1, + ) + self.assertEqual( + count_node(gm_after_replacement, exir_ops.edge.aten.transpose_copy.int), + # Three transposes are added, two for the input/weight and one for the output. + 3, + ) + for node in gm_after_replacement.graph.nodes: + if node.target != exir_ops.edge.cadence.convolution.default: + continue + # Check that the channel_last argument is set to True. + self.assertEqual(len(node.args), 8, f"{node=}") + self.assertTrue(node.args[7]) + + def test_conv1d_no_transpose_if_already_channel_last(self): + gm = self.create_conv1d_graphmodule(channels_last=True) + gm = ExportPass().call(gm).graph_module + self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1) + + # Apply replacement pass. + p = ForceChannelLastForConvPass() + gm_after_replacement = p.call(gm).graph_module + # Check that no replacement was made.
+        self.assertEqual(
+            count_node(gm_after_replacement, exir_ops.edge.cadence.convolution.default),
+            1,
+        )
+        self.assertEqual(
+            count_node(gm_after_replacement, exir_ops.edge.aten.transpose_copy.int),
+            0,
+        )
+        for node in gm_after_replacement.graph.nodes:
+            if node.target != exir_ops.edge.cadence.convolution.default:
+                continue
+            # Check that the channel_last argument is set to True.
+            self.assertEqual(len(node.args), 8, f"{node=}")
+            self.assertTrue(node.args[7])
+
+    def create_convolution_graph_module(
+        self, channels_last: Optional[bool] = None
+    ) -> torch.fx.GraphModule:
+        """Helper to create a convolution node.
+
+        convolution(
+            Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding,
+            int[] dilation, int groups, bool channel_last=False) -> (Tensor Y)
+        """
+        if channels_last:
+            x = torch.randn(1, 224, 224, 3)
+            w = torch.randn(16, 16, 16, 3)
+        else:
+            x = torch.randn(1, 3, 224, 224)
+            w = torch.randn(16, 3, 16, 16)
+        b = torch.randn(16)
+        args = (x, w, b, (2, 2), (1, 1), (0, 0), 1)
+        if channels_last is not None:
+            args = args + (channels_last,)
+        return single_op_builder(
+            placeholders=(x, w, b),
+            op=exir_ops.edge.cadence.convolution.default,
+            args=args,
+        )
+
+    def test_convolution_default_channel_last(self):
+        # Create a graph with a single convolution node.
+        # Check if graph module is valid by running exportpass on it.
+        gm = self.create_convolution_graph_module()
+        gm = ExportPass().call(gm).graph_module
+        self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1)
+        self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)
+
+        # Apply replacement pass.
+        p = ForceChannelLastForConvPass()
+        gm_after_replacement = p.call(gm).graph_module
+        # Check that the convolution op is preserved.
+        self.assertEqual(
+            count_node(gm_after_replacement, exir_ops.edge.cadence.convolution.default),
+            1,
+        )
+        self.assertEqual(
+            count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default),
+            # Three permutes are added, two for the input/weights and one for the output.
+            3,
+        )
+        for node in gm_after_replacement.graph.nodes:
+            if node.target != exir_ops.edge.cadence.convolution.default:
+                continue
+            # Check that the channel_last argument is set to True.
+            self.assertEqual(len(node.args), 8, f"{node=}")
+            self.assertTrue(node.args[7])
+
+    def test_no_transpose_if_already_channel_last(self):
+        gm = self.create_convolution_graph_module(channels_last=True)
+        gm = ExportPass().call(gm).graph_module
+        self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1)
+
+        # Apply replacement pass.
+        p = ForceChannelLastForConvPass()
+        gm_after_replacement = p.call(gm).graph_module
+        # Check that no replacement was made.
+        self.assertEqual(
+            count_node(gm_after_replacement, exir_ops.edge.cadence.convolution.default),
+            1,
+        )
+        self.assertEqual(
+            count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default),
+            0,
+        )
+        for node in gm_after_replacement.graph.nodes:
+            if node.target != exir_ops.edge.cadence.convolution.default:
+                continue
+            # Check that the channel_last argument is set to True.
+            self.assertEqual(len(node.args), 8, f"{node=}")
+            self.assertTrue(node.args[7])
+
+    def create_quantized_convolution_graph_module(
+        self, channels_last: Optional[bool] = None
+    ) -> torch.fx.GraphModule:
+        """Helper to create a quantized conv node.
+
+        quantized_conv(
+            Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding,
+            int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point,
+            Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier,
+            Tensor out_shift, bool channel_last=False) -> (Tensor Z)
+        """
+        if channels_last:
+            x = torch.randn(1, 224, 56, 3)
+            w = torch.randn(16, 16, 16, 3)
+        else:
+            x = torch.randn(1, 3, 224, 56)
+            w = torch.randn(16, 3, 16, 16)
+        b = torch.randn(16)
+        stride = (2, 2)
+        padding = (0, 0)
+        dilation = (1, 1)
+        groups = 1
+        input_zero_point = 0
+        w_zero_point = torch.randn(1)
+        b_scale = torch.randn(1)
+        out_scale = 1
+        out_zero_point = 0
+        out_multiplier = torch.randn(1)
+        out_shift = torch.randn(1)
+        args = (
+            x,
+            w,
+            b,
+            stride,
+            padding,
+            dilation,
+            groups,
+            input_zero_point,
+            w_zero_point,
+            b_scale,
+            out_scale,
+            out_zero_point,
+            out_multiplier,
+            out_shift,
+        )
+        if channels_last is not None:
+            args = args + (channels_last,)
+        return single_op_builder(
+            placeholders=(x, w, b, w_zero_point, b_scale, out_multiplier, out_shift),
+            op=exir_ops.edge.cadence.quantized_conv.default,
+            args=args,
+        )
+
+    def test_quantized_convolution_default_channel_last(self):
+        # Create a graph with a single quantized convolution node.
+        gm = self.create_quantized_convolution_graph_module()
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_conv.default), 1
+        )
+        self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)
+
+        # Apply replacement pass.
+        p = ForceChannelLastForConvPass()
+        gm_after_replacement = p.call(gm).graph_module
+        # Check that the quantized conv op is preserved.
+        self.assertEqual(
+            count_node(
+                gm_after_replacement, exir_ops.edge.cadence.quantized_conv.default
+            ),
+            1,
+        )
+        # Three permutes are added, two for the input/weights and one for the output.
+        self.assertEqual(
+            count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default),
+            3,
+        )
+        for node in gm_after_replacement.graph.nodes:
+            if node.target != exir_ops.edge.cadence.quantized_conv.default:
+                continue
+            # Check that the channel_last argument is set to True.
+            self.assertEqual(len(node.args), 15, f"{node=}")
+            self.assertTrue(node.args[14])
+
+    def test_no_transpose_if_already_quantized_conv_channel_last(self):
+        # Create a graph with a single quantized convolution node.
+        gm = self.create_quantized_convolution_graph_module(channels_last=True)
+        # Check if graph module is valid by running exportpass on it.
+        gm = ExportPass().call(gm).graph_module
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_conv.default), 1
+        )
+
+        # Apply replacement pass.
+        p = ForceChannelLastForConvPass()
+        gm_after_replacement = p.call(gm).graph_module
+        # Check that no replacement was made.
+        self.assertEqual(
+            count_node(
+                gm_after_replacement, exir_ops.edge.cadence.quantized_conv.default
+            ),
+            1,
+        )
+        self.assertEqual(
+            count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default), 0
+        )
+        for node in gm_after_replacement.graph.nodes:
+            if node.target != exir_ops.edge.cadence.quantized_conv.default:
+                continue
+            # Check that the channel_last argument is set to True.
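+            # quantized_conv takes 14 positional args plus the trailing
+            # channel_last flag, so channel_last is args[14] here.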
+            self.assertEqual(len(node.args), 15, f"{node=}")
+            self.assertTrue(node.args[14])
+
+
+class TestMakeSliceAndCatDimOutermostPass(unittest.TestCase):
+    def create_slice_graph(
+        self,
+        input_shape: Sequence[int],
+        slice_dim: int,
+        slice_begin: Optional[int] = None,
+        slice_end: Optional[int] = None,
+    ) -> torch.fx.GraphModule:
+        x = torch.randn(*input_shape)
+        return single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(x, slice_dim, slice_begin, slice_end),
+        )
+
+    def test_slice_no_transpose_if_already_outermost(self):
+        # Create a graph with a single slice node.
+        gm = self.create_slice_graph((3, 224, 224), 0, 1, 2)
+        # Check if graph module is valid by running exportpass on it.
+        gm = ExportPass().call(gm).graph_module
+        self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1)
+
+        # Apply replacement pass.
+        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module
+
+        # Assert that no transpose ops were added.
+        self.assertEqual(
+            count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
+            0,
+        )
+
+    def test_slice_no_transpose_if_outermost_dimensions_are_one(self):
+        # Create a graph with a single slice node on second outermost dimension.
+        gm = self.create_slice_graph((1, 3, 4, 6), 1, 1, 2)
+        # Check if graph module is valid by running exportpass on it.
+        gm = ExportPass().call(gm).graph_module
+        self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1)
+
+        # Apply replacement pass.
+        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module
+
+        # Assert that no transpose ops were added. The slice is on the second
+        # outermost dimension, but the outermost dimension is already 1.
+        self.assertEqual(
+            count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
+            0,
+        )
+
+    def test_slice_insert_transpose(self):
+        # Create a graph with a single slice node.
+        gm = self.create_slice_graph((1, 3, 4, 6), 2, 1, 2)
+        # Check if graph module is valid by running exportpass on it.
+        gm = ExportPass().call(gm).graph_module
+        self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1)
+
+        # Apply replacement pass.
+        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module
+
+        # Assert that two transpose ops were added.
+        self.assertEqual(
+            count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
+            2,
+        )
+
+    def create_cat_graph(
+        self,
+        input_shapes: Sequence[Sequence[int]],
+        cat_dim: int = 0,
+    ) -> torch.fx.GraphModule:
+        input_tensors = tuple(torch.randn(s) for s in input_shapes)
+        return single_op_builder(
+            placeholders=input_tensors,
+            op=exir_ops.edge.aten.cat.default,
+            args=(input_tensors, cat_dim),
+        )
+
+    def test_cat_no_transpose_if_already_outermost(self):
+        # Create a graph with a single cat node on the outermost dimension.
+        gm = self.create_cat_graph(input_shapes=((1, 3, 5), (2, 3, 5)), cat_dim=0)
+        # Check if graph module is valid by running exportpass on it.
+        gm = ExportPass().call(gm).graph_module
+        self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1)
+
+        # Apply replacement pass.
+        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module
+
+        # Assert that no transpose ops were added. The cat is already on the
+        # outermost dimension.
+        self.assertEqual(
+            count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
+            0,
+        )
+
+    def test_cat_no_transpose_if_outermost_dimensions_are_one(self):
+        # Create a graph with a single cat node on the second outermost dimension.
+        gm = self.create_cat_graph(input_shapes=((1, 1, 3, 5), (1, 2, 3, 5)), cat_dim=1)
+        # Check if graph module is valid by running exportpass on it.
+        gm = ExportPass().call(gm).graph_module
+        self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1)
+
+        # Apply replacement pass.
+        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module
+
+        # Assert that no transpose ops were added. The cat is on the second
+        # outermost dimension, but the outermost dimension is already 1.
+        self.assertEqual(
+            count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
+            0,
+        )
+
+    def test_cat_insert_transpose(self):
+        # Create a graph with a single cat node on the innermost dimension.
+        gm = self.create_cat_graph(
+            input_shapes=((1, 1, 3, 5), (1, 1, 3, 3)), cat_dim=-1
+        )
+        # Check if graph module is valid by running exportpass on it.
+        gm = ExportPass().call(gm).graph_module
+        self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1)
+
+        # Apply replacement pass.
+        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module
+
+        # Assert that transpose ops were added to make the cat dimension outermost.
+        self.assertEqual(
+            count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
+            3,
+        )
diff --git a/backends/cadence/aot/tests/test_simplify_ops_passes.py b/backends/cadence/aot/tests/test_simplify_ops_passes.py
new file mode 100644
index 00000000000..347a48b9299
--- /dev/null
+++ b/backends/cadence/aot/tests/test_simplify_ops_passes.py
@@ -0,0 +1,108 @@
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
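+# Tests for the Cadence simplify-ops passes. The cases below exercise
+# SimplifySliceOpPass on slice/slice_scatter ops whose start/end select an
+# empty range, which the pass reduces to simpler ops (e.g. aten.full).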
+
+
+import unittest
+from typing import cast, Optional, Tuple
+
+import executorch.backends.cadence.aot.ops_registrations  # noqa
+import torch
+from executorch.backends.cadence.aot.compiler import export_to_edge
+from executorch.backends.cadence.aot.pass_utils import count_node
+from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
+from executorch.exir.dialects._ops import ops as exir_ops
+from parameterized.parameterized import parameterized
+from torch.fx.passes.infra.pass_base import PassResult
+
+
+class TestSimplifyOpsPasses(unittest.TestCase):
+    @parameterized.expand(
+        [
+            [(3, 16, 5), (3, 0, 5), 1, 15, 3, 3],
+        ]
+    )
+    @torch.no_grad()
+    def test_simplify_slice_scatter_op(
+        self,
+        in_shape: Tuple[int],
+        src_shape: Tuple[int],
+        dim: int,
+        start: Optional[int] = None,
+        end: Optional[int] = None,
+        step: int = 1,
+    ):
+        class SliceScatter(torch.nn.Module):
+            def __init__(
+                self, dim: int, start: Optional[int], end: Optional[int], step: int
+            ):
+                super().__init__()
+                self.dim = dim
+                self.start = start
+                self.end = end
+                self.step = step
+
+            def forward(self, x: torch.Tensor, y: torch.Tensor):
+                return torch.slice_scatter(
+                    x, y, self.dim, self.start, self.end, self.step
+                )
+
+        model = SliceScatter(dim, start, end, step)
+        x = torch.randn(in_shape)
+        y = torch.randn(src_shape)
+        graph_module = export_to_edge(model, (x, y)).exported_program().graph_module
+
+        p = SimplifySliceOpPass()
+
+        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.slice_scatter.default), 0
+        )
+
+    @parameterized.expand(
+        [
+            [(3, 16, 5), (3, 0, 5), 1, 15, 3, 3],
+        ]
+    )
+    @torch.no_grad()
+    def test_simplify_slice_op(
+        self,
+        in_shape: Tuple[int],
+        src_shape: Tuple[int],
+        dim: int,
+        start: Optional[int] = None,
+        end: Optional[int] = None,
+        step: int = 1,
+    ):
+        class SliceCopy(torch.nn.Module):
+            def __init__(
+                self, dim: int, start: Optional[int], end: Optional[int], step: int
+            ):
+                super().__init__()
+                self.dim = dim
+                self.start = start
+                self.end = end
+                self.step = step
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return torch.slice_copy(
+                    x, dim=self.dim, start=self.start, end=self.end, step=self.step
+                )
+
+        # Create a model with single slice copy op.
+        model = SliceCopy(dim, start, end, step)
+        x = torch.randn(in_shape)
+        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
+        self.assertEqual(
+            count_node(graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1
+        )
+
+        p = SimplifySliceOpPass()
+
+        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0
+        )
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.full.default), 1
+        )
diff --git a/backends/cadence/aot/utils.py b/backends/cadence/aot/utils.py
index d0d77bbfb60..e8b64ef5671 100644
--- a/backends/cadence/aot/utils.py
+++ b/backends/cadence/aot/utils.py
@@ -124,29 +124,29 @@ def get_ops_count(graph_module: torch.fx.GraphModule) -> Dict[str, int]:
 # Print the ops and how many times they occur multiple graph modules:
-# from export, from to_edge, and from Jarvis. Print the available
+# from export, from to_edge, and from final. Print the available
 # implementations for each op, and error out if the op is not supported.
 def print_ops_info(
     to_edge_gm: torch.fx.GraphModule,
-    jarvis_gm: torch.fx.GraphModule,
+    final_gm: torch.fx.GraphModule,
 ) -> None:
     to_edge_ops_count = get_ops_count(to_edge_gm)
-    jarvis_ops_count = get_ops_count(jarvis_gm)
+    final_ops_count = get_ops_count(final_gm)
     removed_ops = []
     # Get the counts of the ops that are removed from the final graph
     for k in to_edge_ops_count:
-        if k not in jarvis_ops_count:
+        if k not in final_ops_count:
             removed_ops.append(k)
     # Create a dict of ops and their counts to pass to tabulate
     ops_count = [
         [
             op,
-            jarvis_ops_count[op],
+            final_ops_count[op],
             to_edge_ops_count[op] if op in to_edge_ops_count else 0,
         ]
-        for op in jarvis_ops_count
+        for op in final_ops_count
     ]
     sorted_ops_count = sorted(ops_count, key=lambda x: x[1], reverse=True)
@@ -166,7 +166,7 @@ def print_ops_info(
         sorted_ops_count,
         headers=[
             "Final Operators ",  # one character longer than the longest op name
-            "Jarvis (Final) Graph",
+            "Final Graph",
             "To_edge Graph",
             "Export Graph",
         ],
@@ -181,7 +181,7 @@ def print_ops_info(
         removed_ops_count,
         headers=[
             "Deleted Operators ",  # one character longer than the longest op name
-            "Jarvis (Final) Graph",
+            "Final Graph",
             "To_edge Graph",
             "Export Graph",
         ],
diff --git a/backends/cadence/runtime/runtime.py b/backends/cadence/runtime/runtime.py
index 33bb20719c8..bf2932d9c79 100644
--- a/backends/cadence/runtime/runtime.py
+++ b/backends/cadence/runtime/runtime.py
@@ -28,7 +28,7 @@
 from torch.utils._pytree import TreeSpec
-class JarvisETDump:
+class CadenceETDump:
     def __init__(self, output_dir: str) -> None:
         self.tensor_dump_dir: str = os.path.join(output_dir, "tensors")
         self.etdump_path: str = os.path.join(output_dir, "etdump.etdp")
@@ -64,28 +64,26 @@ def get_outputs(self, log_to_stdout: bool = False) -> Tuple[torch.Tensor]:
             for event_block in self.et_inspector.event_blocks
             if event_block.name == "Execute"
         ]
-        logging.debug(f"[Jarvis][ETdump] output: {output}")
+        logging.debug(f"[ETdump] output: {output}")
         return output[0]
     def print_event_block(self) -> None:
-        logging.debug("[Jarvis][ETdump] data tabular:")
+        logging.debug("[ETdump] data tabular:")
         if logging.getLogger().level <= logging.DEBUG:
             self.et_inspector.print_data_tabular()
     def print_event_data(self) -> None:
-        logging.debug("[Jarvis][ETdump] event data ")
+        logging.debug("[ETdump] event data ")
         for event_block in self.et_inspector.event_blocks:
             for event in event_block.events:
                 logging.debug(event)
     def dump_intermediate_tensors(self) -> None:
         if self.etrecord_path is None:
-            logging.info("[Jarvis][ETdump] Intermediate tensors not available")
+            logging.info("[ETdump] Intermediate tensors not available")
             return
-        logging.info(
-            f"[Jarvis][ETdump] Dumping intermediate tensors to {self.tensor_dump_dir}"
-        )
+        logging.info(f"[ETdump] Dumping intermediate tensors to {self.tensor_dump_dir}")
         os.makedirs(self.tensor_dump_dir, exist_ok=True)
         exec_blocks = [
             eb for eb in self.et_inspector.event_blocks if eb.name == "Execute"
@@ -153,13 +153,13 @@ def run(
     if working_dir is None:
         working_dir = tempfile.mkdtemp(dir="/tmp")
-    # initialize Jarvis e2e Executor with executorch_cfg.
+    # initialize e2e Executor with executorch_cfg.
     executor = Executor(working_dir)
     # run Executor
     executor()
-    etdump = JarvisETDump(output_dir=working_dir)
+    etdump = CadenceETDump(output_dir=working_dir)
     outputs = etdump.get_outputs()
     assert isinstance(out_spec, TreeSpec)