diff --git a/backends/cadence/aot/compiler_utils.py b/backends/cadence/aot/compiler_utils.py index ee52ea71d81..cabfb120341 100644 --- a/backends/cadence/aot/compiler_utils.py +++ b/backends/cadence/aot/compiler_utils.py @@ -109,12 +109,12 @@ def get_cascaded_ops( return nodes -# Capture the effect of transpose op on incoming dimension order -def get_transposed_dims(node: torch.fx.Node, dims: List[int]) -> List[int]: +def get_transposed_dims( + node: torch.fx.Node, dims: Optional[List[int]] = None +) -> List[int]: """ - Given a transpose node, and the incoming dimension ordering of the input - tensor to the transpose node, return the net effect of transpose op on the - dimension order. + Applies the transposition as given by node onto the dimensions given in input + e.g (1, 2) on [a, b, c, d] would return [a, c, b, d] """ assert node.target == exir_ops.edge.aten.transpose_copy.int # Assert that the dims is not empty @@ -127,28 +127,22 @@ def get_transposed_dims(node: torch.fx.Node, dims: List[int]) -> List[int]: assert isinstance(transpose_dims1, int) dim0 = transpose_dims0 if transpose_dims0 >= 0 else transpose_dims0 + dim_len dim1 = transpose_dims1 if transpose_dims1 >= 0 else transpose_dims1 + dim_len - # Perform transpose on dimmension ordering (dims) - dims[dim0], dims[dim1] = dims[dim1], dims[dim0] - return dims + new_dims = list(dims) + new_dims[dim0], new_dims[dim1] = dims[dim1], dims[dim0] + return new_dims -# Capture the effect of permute op on incoming dimension order -def get_permuted_dims(node: torch.fx.Node, dims: Optional[Sequence[int]]) -> List[int]: +def get_permuted_dims(node: torch.fx.Node, dims: List[int]) -> List[int]: """ - Given a permute node, and the incoming dimension ordering of the input - tensor to the permute node, return the net effect of permute op on the - dimension order. 
+ Applies the permutation as given by node onto the dimensions given in input + e.g (1, 2, 0) on [a, b, c] would return [b, c, a] """ assert node.target == exir_ops.edge.aten.permute_copy.default # Permute each index of the dimension ordering (dims) # pyre-fixme[6]: This combined typecheck isn't supported yet. permute_dims: List[int] = list(node.args[1]) assert all(isinstance(x, int) for x in permute_dims) - # If the dims is empty, we can simply return the permute order - if not dims: - return permute_dims - dims = [dims[x] for x in permute_dims] - return dims + return [dims[x] for x in permute_dims] # Return the tensor of buffer/parameter op diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py index 0221c9434e1..77184c7af77 100644 --- a/backends/cadence/aot/fuse_ops.py +++ b/backends/cadence/aot/fuse_ops.py @@ -14,7 +14,7 @@ import operator from collections import deque from numbers import Number -from typing import cast, Sequence +from typing import Any, Callable, cast # Import these for the cadence function signatures. import executorch.backends.cadence.aot.ops_registrations # noqa: F401 @@ -881,9 +881,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class FuseTransposeOpPairsPass(FuseOpPairsAcrossBranchesPass): +class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass): """ - Fuse transpose op pairs to a single view op. + Fuse transpose or permute op pairs to a single view op. 
+ (transpose or permutation) -> (quant or dequant) -> (transpose or permutation) """ # A list of ops that can be bypassed when looking for a @@ -907,42 +908,17 @@ def can_fuse_for_chain( if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets): return False - def get_dims(node: torch.fx.Node) -> tuple[int, int]: - def canonicalize(dim: int) -> int: - if dim < 0: - dim += len(node.meta["val"].shape) - return dim - - return tuple(canonicalize(cast(int, d)) for d in node.args[1:3]) - - def is_equivalent( - shape: Sequence[int], - transpose0: tuple[int, int], - transpose1: tuple[int, int], - ) -> bool: - def permute_order( - order: Sequence[int], dims: tuple[int, int] - ) -> Sequence[int]: - new_order = list(order) - new_order[dims[0]], new_order[dims[1]] = ( - new_order[dims[1]], - new_order[dims[0]], - ) - return new_order - - order = permute_order(range(len(shape)), transpose0) - order = permute_order(order, transpose1) - - non_unit_dims = [dim for dim in range(len(shape)) if shape[dim] != 1] - non_unit_dims_permuted = [dim for dim in order if shape[dim] != 1] - - return non_unit_dims == non_unit_dims_permuted - - return is_equivalent( - cast(torch.fx.Node, producer.args[0]).meta["val"].shape, - get_dims(producer), - get_dims(consumer), - ) + # checking that permute2(permute1(identity)) == identity + input_shape = cast(torch.fx.Node, producer.args[0]).meta["val"].shape + ident_dims = list(range(len(input_shape))) + # this mapping helps to handle both transpose and permutations + f: dict[Any, Callable] = { + exir_ops.edge.aten.transpose_copy.int: get_transposed_dims, + exir_ops.edge.aten.permute_copy.default: get_permuted_dims, + } + in_dims = f[producer.target](producer, ident_dims) + out_dims = f[consumer.target](consumer, in_dims) + return out_dims == ident_dims def get_fused_node( self, @@ -960,11 +936,17 @@ def get_fused_node( return view def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - # Remove any dequantize op that has only
quantize ops as its users. + # Remove any transpose/permutation op pair that cancel each other. self.find_and_fuse( graph_module, - producer_op_packets={exir_ops.edge.aten.transpose_copy}, - consumer_op_packets={exir_ops.edge.aten.transpose_copy}, + producer_op_packets={ + exir_ops.edge.aten.transpose_copy, + exir_ops.edge.aten.permute_copy, + }, + consumer_op_packets={ + exir_ops.edge.aten.transpose_copy, + exir_ops.edge.aten.permute_copy, + }, bypass_ops=self.bypass_ops, ) result = super().call(graph_module) @@ -1028,5 +1010,5 @@ class CadenceFuseOpsInGraph: FuseQuantDequantToRequantizePass, FuseMulIntoDequantPass, FuseFullThenReshapePass, - FuseTransposeOpPairsPass, + FuseTransposeOrPermuteOpPairsPass, ] diff --git a/backends/cadence/aot/passes.py b/backends/cadence/aot/passes.py index 9c47eb4094f..8355f7ef432 100644 --- a/backends/cadence/aot/passes.py +++ b/backends/cadence/aot/passes.py @@ -14,7 +14,7 @@ from executorch.backends.cadence.aot.fuse_ops import ( CadenceFuseOpsInGraph, FuseFullThenReshapePass, - FuseTransposeOpPairsPass, + FuseTransposeOrPermuteOpPairsPass, ) from executorch.backends.cadence.aot.pass_utils import ( CadencePassAttribute, @@ -83,7 +83,7 @@ def get_passes_in_default_order() -> List[ExportPass]: CadenceSimplifyOpsInGraph.passes, FinalizePipeline, FuseFullThenReshapePass, - FuseTransposeOpPairsPass, + FuseTransposeOrPermuteOpPairsPass, RemoveNopSliceOrViewOpPass, ] return pytree.tree_flatten(passes)[0] diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py index 2c9d56819c0..1bb44b872d2 100644 --- a/backends/cadence/aot/tests/test_fusion_ops_passes.py +++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py @@ -8,6 +8,7 @@ import unittest +from typing import Tuple import executorch.backends.cadence.aot.ops_registrations # noqa import torch @@ -20,7 +21,7 @@ FuseFullThenReshapePass, FuseMulIntoDequantPass, FuseQuantDequantToRequantizePass, - FuseTransposeOpPairsPass, + 
FuseTransposeOrPermuteOpPairsPass, ) from executorch.backends.cadence.aot.graph_builder import GraphBuilder from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match @@ -509,6 +510,24 @@ def test_fuse_then_transpose_pass(self): ) +class TestFuseTransposeOrPermuteOpPairsPass(TestFusionPassesBase): + def _create_operator( + self, builder: GraphBuilder, op: torch._ops.OpOverload, x: ProxyValue + ) -> ProxyValue: + if op == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: + return builder.call_operator( + op=op, + args=(x, 1.2, 3, 0, 127, torch.int8), + ) + elif op == exir_ops.edge.cadence.quantized_relu.per_tensor: + return builder.call_operator( + op=op, + args=(x, 0, 0, 0, 0), + ) + else: + raise ValueError(f"Unsupported op: {op}") + + class TestFuseTransposeOpPairsPass(TestFusionPassesBase): def _create_operator( self, builder: GraphBuilder, op: torch._ops.OpOverload, x: ProxyValue @@ -528,83 +547,168 @@ def _create_operator( @parameterized.expand( [ - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.cadence.quantized_relu.per_tensor, + # transpose -> quant -> same transpose => fuse + ( + True, + [0, 1], + True, + [0, 1], + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + True, + ), + # same with different input size + ( + True, + [0, 1], + True, + [0, 1], + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + True, + [4, 4, 4], + ), + # transpose -> quant -> same transpose => fuse (same with transpose dimensions in different order, and with different skip quant op) + ( + True, + [0, 1], + True, + [1, 0], + exir_ops.edge.cadence.quantized_relu.per_tensor, + True, + ), + # transpose -> quant -> different transpose => don't fuse + ( + True, + [0, 1], + True, + [0, 2], + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + False, + ), + # permutation -> quant -> opposite permutation => fuse + ( + False, + [1, 2, 0], + False, + [2, 0, 1], + 
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + True, + ), + # same with different input size + ( + False, + [1, 2, 0], + False, + [2, 0, 1], + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + True, + [4, 4, 4], + ), + # permutation -> quant -> not the opposite permutation => don't fuse + ( + False, + [1, 2, 0], + False, + [1, 2, 0], + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + False, + ), + # same with different input size + ( + False, + [1, 2, 0], + False, + [1, 2, 0], + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + False, + [4, 4, 4], + ), + # transpose -> quant -> transpose as a permutation => fuse + ( + True, + [0, 1], + False, + [1, 0, 2], + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + True, + ), + # transpose -> quant -> not opposite permutation => don't fuse + ( + True, + [0, 1], + False, + [0, 2, 1], + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + False, + ), ], ) - def test_fuse_transpose_pairs(self, op: torch._ops.OpOverload): - # Create a graph with transpose -> quant -> transpose. + def test_fuse_transpose_permute_pairs( + self, + is_op1_transpose: bool, + perm1: list[int], + is_op2_transpose: bool, + perm2: list[int], + quant_op: torch._ops.OpOverload, + expected_is_fused: bool, + dims: Tuple[int, int, int] = (2, 3, 4), + ): + # Create a graph with transpose/permute -> quant -> transpose/permute.
builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(2, 3)) - transpose_node = builder.call_operator( - op=exir_ops.edge.aten.transpose_copy.int, - args=(x, 0, 1), - ) - quant_node = self._create_operator(builder, op, transpose_node) - transpose_node = builder.call_operator( - op=exir_ops.edge.aten.transpose_copy.int, - args=(quant_node, 0, 1), + x = builder.placeholder("x", torch.randn(dims)) + op1 = ( + exir_ops.edge.aten.transpose_copy.int + if is_op1_transpose + else exir_ops.edge.aten.permute_copy.default + ) + node1 = builder.call_operator( + op=op1, + args=(x, perm1[0], perm1[1]) if is_op1_transpose else (x, list(perm1)), + ) + quant_node = self._create_operator(builder, quant_op, node1) + op2 = ( + exir_ops.edge.aten.transpose_copy.int + if is_op2_transpose + else exir_ops.edge.aten.permute_copy.default + ) + node2 = builder.call_operator( + op=op2, + args=( + (quant_node, perm2[0], perm2[1]) + if is_op2_transpose + else (quant_node, list(perm2)) + ), ) - builder.output([transpose_node]) + builder.output([node2]) gm = builder.get_graph_module() + expected_op_counts = { + quant_op: 1, + } + expected_op_counts[op1] = 1 + expected_op_counts[op2] = expected_op_counts.get(op2, 0) + 1 self.check_op_counts( gm, - expected_op_counts={ - exir_ops.edge.aten.transpose_copy.int: 2, - op: 1, - }, + # pyre-fixme[6]: Incompatible parameter type + expected_op_counts=expected_op_counts, ) - # Check that the pass fuses the two transpose ops. - fusion_pass_result = FuseTransposeOpPairsPass()(gm) + # Check that the pass fuses the two transpose/permute ops. 
+ fusion_pass_result = FuseTransposeOrPermuteOpPairsPass()(gm) self.assertIsNotNone(fusion_pass_result) gm_after_pass = fusion_pass_result.graph_module + if expected_is_fused: + expected_op_counts[op1] = 0 + expected_op_counts[op2] = 0 self.check_op_counts( gm_after_pass, - expected_op_counts={ - exir_ops.edge.aten.transpose_copy.int: 0, - op: 1, - }, - ) - - def test_no_fusion_for_transpose_pairs(self): - # Create a graph with transpose -> quant -> transpose. - builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(2, 3, 4)) - transpose_node = builder.call_operator( - op=exir_ops.edge.aten.transpose_copy.int, - args=(x, 0, 1), - ) - quant_node = builder.call_operator( - op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - args=(transpose_node, 1.2, 3, 0, 127, torch.int8), - ) - transpose_node = builder.call_operator( - op=exir_ops.edge.aten.transpose_copy.int, - args=(quant_node, 1, 2), - ) - builder.output(transpose_node) - gm = builder.get_graph_module() - self.check_op_counts( - gm, - expected_op_counts={ - exir_ops.edge.aten.transpose_copy.int: 2, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1, - }, - ) - - # No fusion. - gm_after_pass = FuseTransposeOpPairsPass()(gm).graph_module - self.check_op_counts( - gm_after_pass, - expected_op_counts={ - exir_ops.edge.aten.transpose_copy.int: 2, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1, - }, + # pyre-fixme[6]: Incompatible parameter type + expected_op_counts=expected_op_counts, ) def test_fusion_for_forked_transposes(self): - # Create a graph with transpose -> quant -> transpose. + # Create a graph with + # transpose -> quant -> transpose. + # -> quant -> transpose. + # -> quant -> transpose. builder = GraphBuilder() x = builder.placeholder("x", torch.randn(2, 3, 4, dtype=torch.float32)) transpose_node = builder.call_operator( @@ -634,8 +738,8 @@ def test_fusion_for_forked_transposes(self): }, ) - # Fuse the all the transpose ops. 
- gm_after_pass = FuseTransposeOpPairsPass()(gm).graph_module + # Fuse all the transpose ops. + gm_after_pass = FuseTransposeOrPermuteOpPairsPass()(gm).graph_module self.check_op_counts( gm_after_pass, expected_op_counts={