Skip to content

Commit 2fbf300

Browse files
ThomasJannaud authored and facebook-github-bot committed
Allow removing permute pairs in addition to transpose pairs (#10501)
Summary: Pull Request resolved: #10501 As titled. Gets us 27% better cycles on Activity Classification (at opt level 3). Can be improved further, task is T222295719 Differential Revision: D73619452
1 parent 32dffbc commit 2fbf300

File tree

4 files changed

+101
-128
lines changed

4 files changed

+101
-128
lines changed

backends/cadence/aot/compiler_utils.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,10 @@ def get_cascaded_ops(
109109
return nodes
110110

111111

112-
# Capture the effect of transpose op on incoming dimension order
113-
def get_transposed_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
112+
def get_transposed_dims(node: torch.fx.Node, dims: Optional[List[int]]=None) -> List[int]:
114113
"""
115-
Given a transpose node, and the incoming dimension ordering of the input
116-
tensor to the transpose node, return the net effect of transpose op on the
117-
dimension order.
114+
Applies the transposition as given by node onto the dimensions given in input
115+
e.g (1, 2) on [5, 6, 7, 8] would return [5, 7, 6, 8]
118116
"""
119117
assert node.target == exir_ops.edge.aten.transpose_copy.int
120118
# Assert that the dims is not empty
@@ -127,28 +125,23 @@ def get_transposed_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
127125
assert isinstance(transpose_dims1, int)
128126
dim0 = transpose_dims0 if transpose_dims0 >= 0 else transpose_dims0 + dim_len
129127
dim1 = transpose_dims1 if transpose_dims1 >= 0 else transpose_dims1 + dim_len
130-
# Perform transpose on dimmension ordering (dims)
131-
dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
132-
return dims
128+
new_dims = list(dims)
129+
new_dims[dim0], new_dims[dim1] = dims[dim1], dims[dim0]
130+
return new_dims
133131

134132

135-
# Capture the effect of permute op on incoming dimension order
136-
def get_permuted_dims(node: torch.fx.Node, dims: Optional[Sequence[int]]) -> List[int]:
133+
134+
def get_permuted_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
137135
"""
138-
Given a permute node, and the incoming dimension ordering of the input
139-
tensor to the permute node, return the net effect of permute op on the
140-
dimension order.
136+
Applies the permutation as given by node onto the dimensions given in input
137+
e.g (1, 2, 0) on [5, 6, 7] would return [6, 7, 5]
141138
"""
142139
assert node.target == exir_ops.edge.aten.permute_copy.default
143140
# Permute each index of the dimension ordering (dims)
144141
# pyre-fixme[6]: This combined typecheck isn't supported yet.
145142
permute_dims: List[int] = list(node.args[1])
146143
assert all(isinstance(x, int) for x in permute_dims)
147-
# If the dims is empty, we can simply return the permute order
148-
if not dims:
149-
return permute_dims
150-
dims = [dims[x] for x in permute_dims]
151-
return dims
144+
return [dims[x] for x in permute_dims]
152145

153146

154147
# Return the tensor of buffer/parameter op

backends/cadence/aot/fuse_ops.py

Lines changed: 19 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import operator
1515
from collections import deque
1616
from numbers import Number
17-
from typing import cast, Sequence
17+
from typing import cast
1818

1919
# Import these for the cadence function signatures.
2020
import executorch.backends.cadence.aot.ops_registrations # noqa: F401
@@ -881,9 +881,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
881881

882882

883883
@register_cadence_pass(CadencePassAttribute(opt_level=1))
884-
class FuseTransposeOpPairsPass(FuseOpPairsAcrossBranchesPass):
884+
class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass):
885885
"""
886-
Fuse transpose op pairs to a single view op.
886+
Fuse transpose or permute op pairs to a single view op.
887+
(transpose or permutation) -> (quant or dequant) -> (transpose or permutation)
887888
"""
888889

889890
# A list of ops that can be bypassed when looking for a
@@ -907,42 +908,17 @@ def can_fuse_for_chain(
907908
if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets):
908909
return False
909910

910-
def get_dims(node: torch.fx.Node) -> tuple[int, int]:
911-
def canonicalize(dim: int) -> int:
912-
if dim < 0:
913-
dim += len(node.meta["val"].shape)
914-
return dim
915-
916-
return tuple(canonicalize(cast(int, d)) for d in node.args[1:3])
917-
918-
def is_equivalent(
919-
shape: Sequence[int],
920-
transpose0: tuple[int, int],
921-
transpose1: tuple[int, int],
922-
) -> bool:
923-
def permute_order(
924-
order: Sequence[int], dims: tuple[int, int]
925-
) -> Sequence[int]:
926-
new_order = list(order)
927-
new_order[dims[0]], new_order[dims[1]] = (
928-
new_order[dims[1]],
929-
new_order[dims[0]],
930-
)
931-
return new_order
932-
933-
order = permute_order(range(len(shape)), transpose0)
934-
order = permute_order(order, transpose1)
935-
936-
non_unit_dims = [dim for dim in range(len(shape)) if shape[dim] != 1]
937-
non_unit_dims_permuted = [dim for dim in order if shape[dim] != 1]
938-
939-
return non_unit_dims == non_unit_dims_permuted
940-
941-
return is_equivalent(
942-
cast(torch.fx.Node, producer.args[0]).meta["val"].shape,
943-
get_dims(producer),
944-
get_dims(consumer),
945-
)
911+
# pyre-fixme[16]: `None` has no attribute `meta`
912+
ident_dims = list(range(len(producer.args[0].meta["val"].shape)))
913+
f = {
914+
exir_ops.edge.aten.transpose_copy.int: get_transposed_dims,
915+
exir_ops.edge.aten.permute_copy.default: get_permuted_dims,
916+
}
917+
# pyre-fixme[29]: VT is not a function
918+
in_dims = f[producer.target](producer, ident_dims)
919+
# pyre-fixme[29]: VT is not a function
920+
out_dims = f[consumer.target](consumer, in_dims)
921+
return out_dims == ident_dims
946922

947923
def get_fused_node(
948924
self,
@@ -960,11 +936,11 @@ def get_fused_node(
960936
return view
961937

962938
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
963-
# Remove any dequantize op that has only quantize ops as its users.
939+
# Remove any transpose/permutation op pair that cancels out.
964940
self.find_and_fuse(
965941
graph_module,
966-
producer_op_packets={exir_ops.edge.aten.transpose_copy},
967-
consumer_op_packets={exir_ops.edge.aten.transpose_copy},
942+
producer_op_packets={exir_ops.edge.aten.transpose_copy, exir_ops.edge.aten.permute_copy},
943+
consumer_op_packets={exir_ops.edge.aten.transpose_copy, exir_ops.edge.aten.permute_copy},
968944
bypass_ops=self.bypass_ops,
969945
)
970946
result = super().call(graph_module)
@@ -1028,5 +1004,5 @@ class CadenceFuseOpsInGraph:
10281004
FuseQuantDequantToRequantizePass,
10291005
FuseMulIntoDequantPass,
10301006
FuseFullThenReshapePass,
1031-
FuseTransposeOpPairsPass,
1007+
FuseTransposeOrPermuteOpPairsPass,
10321008
]

backends/cadence/aot/passes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from executorch.backends.cadence.aot.fuse_ops import (
1515
CadenceFuseOpsInGraph,
1616
FuseFullThenReshapePass,
17-
FuseTransposeOpPairsPass,
17+
FuseTransposeOrPermuteOpPairsPass,
1818
)
1919
from executorch.backends.cadence.aot.pass_utils import (
2020
CadencePassAttribute,
@@ -83,7 +83,7 @@ def get_passes_in_default_order() -> List[ExportPass]:
8383
CadenceSimplifyOpsInGraph.passes,
8484
FinalizePipeline,
8585
FuseFullThenReshapePass,
86-
FuseTransposeOpPairsPass,
86+
FuseTransposeOrPermuteOpPairsPass,
8787
RemoveNopSliceOrViewOpPass,
8888
]
8989
return pytree.tree_flatten(passes)[0]

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 69 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99

1010
import unittest
11+
from typing import List
1112

1213
import executorch.backends.cadence.aot.ops_registrations # noqa
1314
import torch
@@ -20,7 +21,7 @@
2021
FuseFullThenReshapePass,
2122
FuseMulIntoDequantPass,
2223
FuseQuantDequantToRequantizePass,
23-
FuseTransposeOpPairsPass,
24+
FuseTransposeOrPermuteOpPairsPass,
2425
)
2526
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
2627
from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
@@ -509,6 +510,24 @@ def test_fuse_then_transpose_pass(self):
509510
)
510511

511512

513+
class TestFuseTransposeOrPermuteOpPairsPass(TestFusionPassesBase):
514+
def _create_operator(
515+
self, builder: GraphBuilder, op: torch._ops.OpOverload, x: ProxyValue
516+
) -> ProxyValue:
517+
if op == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default:
518+
return builder.call_operator(
519+
op=op,
520+
args=(x, 1.2, 3, 0, 127, torch.int8),
521+
)
522+
elif op == exir_ops.edge.cadence.quantized_relu.per_tensor:
523+
return builder.call_operator(
524+
op=op,
525+
args=(x, 0, 0, 0, 0),
526+
)
527+
else:
528+
raise ValueError(f"Unsupported op: {op}")
529+
530+
512531
class TestFuseTransposeOpPairsPass(TestFusionPassesBase):
513532
def _create_operator(
514533
self, builder: GraphBuilder, op: torch._ops.OpOverload, x: ProxyValue
@@ -528,83 +547,68 @@ def _create_operator(
528547

529548
@parameterized.expand(
530549
[
531-
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
532-
exir_ops.edge.cadence.quantized_relu.per_tensor,
550+
# transpose -> quant -> same transpose => fuse
551+
(True, [0, 1], True, [0, 1], exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, True),
552+
# transpose -> quant -> same transpose => fuse (same with transpose dimensions in different order, and with different skip quant op)
553+
(True, [0, 1], True, [1, 0], exir_ops.edge.cadence.quantized_relu.per_tensor, True),
554+
# transpose -> quant -> different transpose => don't fuse
555+
(True, [0, 1], True, [0, 2], exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, False),
556+
# permutation -> quant -> opposite permutation => fuse
557+
(False, [1, 2, 0], False, [2, 0, 1], exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, True),
558+
# permutation -> quant -> not the opposite permutation => don't fuse
559+
(False, [1, 2, 0], False, [1, 2, 0], exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, False),
560+
# transpose -> quant -> transpose as a permutation => fuse
561+
(True, [0, 1], False, [1, 0, 2], exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, True),
562+
# transpose -> quant -> not opposite permutation => don't fuse
563+
(True, [0, 1], False, [0, 2, 1], exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, False),
533564
],
534565
)
535-
def test_fuse_transpose_pairs(self, op: torch._ops.OpOverload):
536-
# Create a graph with transpose -> quant -> transpose.
566+
def test_fuse_transpose_permute_pairs(self, is_op1_transpose: bool, perm1: List[int], is_op2_transpose: bool, perm2: List[int], quant_op: torch._ops.OpOverload, expected_is_fused: bool):
567+
# Create a graph with transpose/permute -> quant -> transpose/permute.
537568
builder = GraphBuilder()
538-
x = builder.placeholder("x", torch.randn(2, 3))
539-
transpose_node = builder.call_operator(
540-
op=exir_ops.edge.aten.transpose_copy.int,
541-
args=(x, 0, 1),
542-
)
543-
quant_node = self._create_operator(builder, op, transpose_node)
544-
transpose_node = builder.call_operator(
545-
op=exir_ops.edge.aten.transpose_copy.int,
546-
args=(quant_node, 0, 1),
547-
)
548-
builder.output([transpose_node])
569+
x = builder.placeholder("x", torch.randn(2, 3, 4))
570+
op1 = exir_ops.edge.aten.transpose_copy.int if is_op1_transpose else exir_ops.edge.aten.permute_copy.default
571+
node1 = builder.call_operator(
572+
op=op1,
573+
args=(x, perm1[0], perm1[1]) if is_op1_transpose else (x, list(perm1)),
574+
)
575+
quant_node = self._create_operator(builder, quant_op, node1)
576+
op2 = exir_ops.edge.aten.transpose_copy.int if is_op2_transpose else exir_ops.edge.aten.permute_copy.default
577+
node2 = builder.call_operator(
578+
op=op2,
579+
args=(quant_node, perm2[0], perm2[1]) if is_op2_transpose else (quant_node, list(perm2)),
580+
)
581+
builder.output([node2])
549582
gm = builder.get_graph_module()
583+
exp_counts = {
584+
quant_op: 1,
585+
}
586+
exp_counts[op1] = 1
587+
exp_counts[op2] = exp_counts.get(op2, 0) + 1
550588
self.check_op_counts(
551589
gm,
552-
expected_op_counts={
553-
exir_ops.edge.aten.transpose_copy.int: 2,
554-
op: 1,
555-
},
590+
# pyre-fixme[6]: Incompatible parameter type
591+
expected_op_counts=exp_counts
556592
)
557593

558-
# Check that the pass fuses the two transpose ops.
559-
fusion_pass_result = FuseTransposeOpPairsPass()(gm)
594+
# Check that the pass fuses the two transpose/permute ops.
595+
fusion_pass_result = FuseTransposeOrPermuteOpPairsPass()(gm)
560596
self.assertIsNotNone(fusion_pass_result)
561597
gm_after_pass = fusion_pass_result.graph_module
598+
if expected_is_fused:
599+
exp_counts[op1] = 0
600+
exp_counts[op2] = 0
562601
self.check_op_counts(
563602
gm_after_pass,
564-
expected_op_counts={
565-
exir_ops.edge.aten.transpose_copy.int: 0,
566-
op: 1,
567-
},
568-
)
569-
570-
def test_no_fusion_for_transpose_pairs(self):
571-
# Create a graph with transpose -> quant -> transpose.
572-
builder = GraphBuilder()
573-
x = builder.placeholder("x", torch.randn(2, 3, 4))
574-
transpose_node = builder.call_operator(
575-
op=exir_ops.edge.aten.transpose_copy.int,
576-
args=(x, 0, 1),
577-
)
578-
quant_node = builder.call_operator(
579-
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
580-
args=(transpose_node, 1.2, 3, 0, 127, torch.int8),
581-
)
582-
transpose_node = builder.call_operator(
583-
op=exir_ops.edge.aten.transpose_copy.int,
584-
args=(quant_node, 1, 2),
585-
)
586-
builder.output(transpose_node)
587-
gm = builder.get_graph_module()
588-
self.check_op_counts(
589-
gm,
590-
expected_op_counts={
591-
exir_ops.edge.aten.transpose_copy.int: 2,
592-
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
593-
},
594-
)
595-
596-
# No fusion.
597-
gm_after_pass = FuseTransposeOpPairsPass()(gm).graph_module
598-
self.check_op_counts(
599-
gm_after_pass,
600-
expected_op_counts={
601-
exir_ops.edge.aten.transpose_copy.int: 2,
602-
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
603-
},
603+
# pyre-fixme[6]: Incompatible parameter type
604+
expected_op_counts=exp_counts,
604605
)
605606

606607
def test_fusion_for_forked_transposes(self):
607-
# Create a graph with transpose -> quant -> transpose.
608+
# Create a graph with
609+
# transpose -> quant -> transpose.
610+
# -> quant -> transpose.
611+
# -> quant -> transpose.
608612
builder = GraphBuilder()
609613
x = builder.placeholder("x", torch.randn(2, 3, 4, dtype=torch.float32))
610614
transpose_node = builder.call_operator(
@@ -634,8 +638,8 @@ def test_fusion_for_forked_transposes(self):
634638
},
635639
)
636640

637-
# Fuse the all the transpose ops.
638-
gm_after_pass = FuseTransposeOpPairsPass()(gm).graph_module
641+
# Fuse all the transpose ops.
642+
gm_after_pass = FuseTransposeOrPermuteOpPairsPass()(gm).graph_module
639643
self.check_op_counts(
640644
gm_after_pass,
641645
expected_op_counts={

0 commit comments

Comments (0)