diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py
index 77184c7af77..7a20a3f64b4 100644
--- a/backends/cadence/aot/fuse_ops.py
+++ b/backends/cadence/aot/fuse_ops.py
@@ -885,6 +885,9 @@ class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass):
     """
     Fuse transpose or permute op pairs to a single view op.
     (transpose or permutation) -> (quant or dequant) -> (transpose or permutation)
+    This happens when op2(op1) == identity, modulo unitary dimensions.
+    'unitary dimensions' example: a tensor of shape [1, 5, 30] is equivalent (in memory) to [5, 1, 30],
+    so transpose(1, 2) then transpose(0, 2) is a pseudo-identity and should be fused.
     """
 
     # A list of ops that can be bypassed when looking for a
@@ -908,7 +911,7 @@ def can_fuse_for_chain(
         if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets):
             return False
 
-        # checking that permut2(permut1(identify)) == identity
+        # check that permute2(permute1(identity)) == identity, modulo unitary dimensions
         input_shape = cast(torch.fx.Node, producer.args[0]).meta["val"].shape
         ident_dims = list(range(len(input_shape)))
         # this mapping helps to handle both transpose and permutations
@@ -918,7 +921,10 @@ def can_fuse_for_chain(
         }
         in_dims = f[producer.target](producer, ident_dims)
         out_dims = f[consumer.target](consumer, in_dims)
-        return out_dims == ident_dims
+        # Filter out unitary dimensions; size-1 dims don't affect the memory layout
+        non_unit_ident_dims = [dim for dim in ident_dims if input_shape[dim] != 1]
+        non_unit_out_dims = [dim for dim in out_dims if input_shape[dim] != 1]
+        return non_unit_out_dims == non_unit_ident_dims
 
     def get_fused_node(
         self,
@@ -926,6 +932,9 @@ def get_fused_node(
         consumer: torch.fx.Node,
         graph_module: torch.fx.GraphModule,
     ) -> torch.fx.Node:
+        # This step matters because we may fuse transpose/permute pairs that are
+        # not exact inverses of each other when unitary dimensions are involved;
+        # the fused view must therefore use the consumer's output shape.
         output_shape = consumer.meta["val"].shape
         with graph_module.graph.inserting_after(consumer):
             view = graph_module.graph.call_function(
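Reviewer note (not part of the patch): the sketch below restates the updated `can_fuse_for_chain` logic as a standalone function so it can be checked in isolation. `is_pseudo_identity` is a hypothetical name, and the snippet assumes only PyTorch. The idea is to apply both permutations to the identity dim order, then compare only the dims whose size is not 1, since size-1 dims can move freely without changing the memory order.

```python
import torch

def is_pseudo_identity(input_shape, perm1, perm2):
    # Track where each dim index lands after both ops, mirroring the
    # in_dims/out_dims bookkeeping in can_fuse_for_chain.
    ident_dims = list(range(len(input_shape)))
    in_dims = [ident_dims[i] for i in perm1]
    out_dims = [in_dims[i] for i in perm2]
    # Compare only non-unitary dims: moving a size-1 dim around does not
    # change the underlying memory order.
    non_unit_ident = [d for d in ident_dims if input_shape[d] != 1]
    non_unit_out = [d for d in out_dims if input_shape[d] != 1]
    return non_unit_out == non_unit_ident

# The docstring example: shape [1, 5, 30], transpose(1, 2) then transpose(0, 2).
# As permutations, transpose(1, 2) == [0, 2, 1] and transpose(0, 2) == [2, 1, 0].
assert is_pseudo_identity([1, 5, 30], [0, 2, 1], [2, 1, 0])

# Sanity check that the pair really is a memory no-op:
x = torch.arange(150).reshape(1, 5, 30)
y = x.transpose(1, 2).transpose(0, 2)  # shape [5, 30, 1]
assert torch.equal(y.contiguous().reshape(-1), x.reshape(-1))
```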
diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py
index 1bb44b872d2..4e267254488 100644
--- a/backends/cadence/aot/tests/test_fusion_ops_passes.py
+++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py
@@ -584,6 +584,28 @@ def _create_operator(
             exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
             False,
         ),
+        # transpose -> quant -> transpose is not the reverse, BUT there is a UNITARY
+        # dimension, so it ends up being the same in memory => fuse
+        (
+            True,
+            [0, 1],
+            True,
+            [0, 2],
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            True,
+            [5, 40, 1],
+        ),
+        # transpose -> quant -> transpose is not the reverse, and unitary dimensions
+        # don't help => don't fuse
+        (
+            True,
+            [0, 1],
+            True,
+            [1, 3],
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            False,
+            [5, 40, 1, 4],
+        ),
         # permutation -> quant -> opposite permutation => fuse
         (
             False,
@@ -622,6 +644,28 @@ def _create_operator(
             False,
             [4, 4, 4],
         ),
+        # permutation -> quant -> a non-reverse permutation, BUT there is a UNITARY
+        # dimension, so it ends up being the same in memory => fuse
+        (
+            False,
+            [1, 3, 2, 0],
+            False,
+            [3, 2, 1, 0],
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            True,
+            [3, 1, 8, 10],
+        ),
+        # permutation -> quant -> a non-reverse permutation, and unitary dimensions
+        # don't help => don't fuse
+        (
+            False,
+            [1, 3, 2, 0],
+            False,
+            [3, 1, 2, 0],
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            False,
+            [3, 1, 8, 10],
+        ),
         # transpose -> quant -> transpose as a permutation => fuse
         (
             True,
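Likewise, the new permutation cases in the test table can be hand-checked outside the test harness. The snippet below (again a reviewer-side check, assuming only PyTorch) exercises the pair on shape [3, 1, 8, 10]: the "fuse" case preserves the memory order, so a single view to the consumer's output shape reproduces it, while the "don't fuse" counterpart does not.

```python
import torch

x = torch.arange(3 * 1 * 8 * 10).reshape(3, 1, 8, 10)

# "fuse" case: permute([1, 3, 2, 0]) then permute([3, 2, 1, 0]).
y = x.permute(1, 3, 2, 0).permute(3, 2, 1, 0)  # output shape [3, 8, 10, 1]
assert torch.equal(y.contiguous().reshape(-1), x.reshape(-1))  # order preserved
assert torch.equal(y, x.reshape(3, 8, 10, 1))  # a single view suffices

# "don't fuse" case: permute([1, 3, 2, 0]) then permute([3, 1, 2, 0]).
z = x.permute(1, 3, 2, 0).permute(3, 1, 2, 0)  # output shape [3, 10, 8, 1]
assert not torch.equal(z.contiguous().reshape(-1), x.reshape(-1))  # order changed
```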