Skip to content

Commit 2c34c2a

Browse files
ThomasJannaud authored and
facebook-github-bot committed
Allow removing permute pairs in addition to transpose pairs
Summary: Pull Request resolved: As titled. Gets us 27% better cycles on Activity Classification (at opt level 3). Can be improved further (when fused permutations are not an identity), task is T222295719 Differential Revision: D73619452
1 parent c5dd476 commit 2c34c2a

File tree

4 files changed

+171
-127
lines changed

4 files changed

+171
-127
lines changed

backends/cadence/aot/compiler_utils.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,10 @@ def get_cascaded_ops(
109109
return nodes
110110

111111

112-
# Capture the effect of transpose op on incoming dimension order
113-
def get_transposed_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
112+
def get_transposed_dims(node: torch.fx.Node, dims: Optional[List[int]]=None) -> List[int]:
114113
"""
115-
Given a transpose node, and the incoming dimension ordering of the input
116-
tensor to the transpose node, return the net effect of transpose op on the
117-
dimension order.
114+
Applies the transposition as given by node onto the dimensions given in input
115+
e.g. (1, 2) on [5, 6, 7, 8] would return [5, 7, 6, 8]
118116
"""
119117
assert node.target == exir_ops.edge.aten.transpose_copy.int
120118
# Assert that the dims is not empty
@@ -127,28 +125,23 @@ def get_transposed_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
127125
assert isinstance(transpose_dims1, int)
128126
dim0 = transpose_dims0 if transpose_dims0 >= 0 else transpose_dims0 + dim_len
129127
dim1 = transpose_dims1 if transpose_dims1 >= 0 else transpose_dims1 + dim_len
130-
# Perform transpose on dimmension ordering (dims)
131-
dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
132-
return dims
128+
new_dims = list(dims)
129+
new_dims[dim0], new_dims[dim1] = dims[dim1], dims[dim0]
130+
return new_dims
133131

134132

135-
# Capture the effect of permute op on incoming dimension order
136-
def get_permuted_dims(node: torch.fx.Node, dims: Optional[Sequence[int]]) -> List[int]:
133+
134+
def get_permuted_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
137135
"""
138-
Given a permute node, and the incoming dimension ordering of the input
139-
tensor to the permute node, return the net effect of permute op on the
140-
dimension order.
136+
Applies the permutation as given by node onto the dimensions given in input
137+
e.g. (1, 2, 0) on [5, 6, 7] would return [6, 7, 5]
141138
"""
142139
assert node.target == exir_ops.edge.aten.permute_copy.default
143140
# Permute each index of the dimension ordering (dims)
144141
# pyre-fixme[6]: This combined typecheck isn't supported yet.
145142
permute_dims: List[int] = list(node.args[1])
146143
assert all(isinstance(x, int) for x in permute_dims)
147-
# If the dims is empty, we can simply return the permute order
148-
if not dims:
149-
return permute_dims
150-
dims = [dims[x] for x in permute_dims]
151-
return dims
144+
return [dims[x] for x in permute_dims]
152145

153146

154147
# Return the tensor of buffer/parameter op

backends/cadence/aot/fuse_ops.py

Lines changed: 21 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import operator
1515
from collections import deque
1616
from numbers import Number
17-
from typing import cast, Sequence
17+
from typing import cast
1818

1919
# Import these for the cadence function signatures.
2020
import executorch.backends.cadence.aot.ops_registrations # noqa: F401
@@ -881,9 +881,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
881881

882882

883883
@register_cadence_pass(CadencePassAttribute(opt_level=1))
884-
class FuseTransposeOpPairsPass(FuseOpPairsAcrossBranchesPass):
884+
class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass):
885885
"""
886-
Fuse transpose op pairs to a single view op.
886+
Fuse transpose or permute op pairs to a single view op.
887+
(transpose or permutation) -> (quant or dequant) -> (transpose or permutation)
887888
"""
888889

889890
# A list of ops that can be bypassed when looking for a
@@ -907,42 +908,19 @@ def can_fuse_for_chain(
907908
if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets):
908909
return False
909910

910-
def get_dims(node: torch.fx.Node) -> tuple[int, int]:
911-
def canonicalize(dim: int) -> int:
912-
if dim < 0:
913-
dim += len(node.meta["val"].shape)
914-
return dim
915-
916-
return tuple(canonicalize(cast(int, d)) for d in node.args[1:3])
917-
918-
def is_equivalent(
919-
shape: Sequence[int],
920-
transpose0: tuple[int, int],
921-
transpose1: tuple[int, int],
922-
) -> bool:
923-
def permute_order(
924-
order: Sequence[int], dims: tuple[int, int]
925-
) -> Sequence[int]:
926-
new_order = list(order)
927-
new_order[dims[0]], new_order[dims[1]] = (
928-
new_order[dims[1]],
929-
new_order[dims[0]],
930-
)
931-
return new_order
932-
933-
order = permute_order(range(len(shape)), transpose0)
934-
order = permute_order(order, transpose1)
935-
936-
non_unit_dims = [dim for dim in range(len(shape)) if shape[dim] != 1]
937-
non_unit_dims_permuted = [dim for dim in order if shape[dim] != 1]
938-
939-
return non_unit_dims == non_unit_dims_permuted
940-
941-
return is_equivalent(
942-
cast(torch.fx.Node, producer.args[0]).meta["val"].shape,
943-
get_dims(producer),
944-
get_dims(consumer),
945-
)
911+
# checking that permute2(permute1(identity)) == identity
912+
# pyre-fixme[16]: `None` has no attribute `meta`
913+
ident_dims = list(range(len(producer.args[0].meta["val"].shape)))
914+
# this mapping helps to handle both transposes and permutations
915+
f = {
916+
exir_ops.edge.aten.transpose_copy.int: get_transposed_dims,
917+
exir_ops.edge.aten.permute_copy.default: get_permuted_dims,
918+
}
919+
# pyre-fixme[29]: VT is not a function
920+
in_dims = f[producer.target](producer, ident_dims)
921+
# pyre-fixme[29]: VT is not a function
922+
out_dims = f[consumer.target](consumer, in_dims)
923+
return out_dims == ident_dims
946924

947925
def get_fused_node(
948926
self,
@@ -960,11 +938,11 @@ def get_fused_node(
960938
return view
961939

962940
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
963-
# Remove any dequantize op that has only quantize ops as its users.
941+
# Remove any transpose/permutation op pairs that cancel each other.
964942
self.find_and_fuse(
965943
graph_module,
966-
producer_op_packets={exir_ops.edge.aten.transpose_copy},
967-
consumer_op_packets={exir_ops.edge.aten.transpose_copy},
944+
producer_op_packets={exir_ops.edge.aten.transpose_copy, exir_ops.edge.aten.permute_copy},
945+
consumer_op_packets={exir_ops.edge.aten.transpose_copy, exir_ops.edge.aten.permute_copy},
968946
bypass_ops=self.bypass_ops,
969947
)
970948
result = super().call(graph_module)
@@ -1028,5 +1006,5 @@ class CadenceFuseOpsInGraph:
10281006
FuseQuantDequantToRequantizePass,
10291007
FuseMulIntoDequantPass,
10301008
FuseFullThenReshapePass,
1031-
FuseTransposeOpPairsPass,
1009+
FuseTransposeOrPermuteOpPairsPass,
10321010
]

backends/cadence/aot/passes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from executorch.backends.cadence.aot.fuse_ops import (
1515
CadenceFuseOpsInGraph,
1616
FuseFullThenReshapePass,
17-
FuseTransposeOpPairsPass,
17+
FuseTransposeOrPermuteOpPairsPass,
1818
)
1919
from executorch.backends.cadence.aot.pass_utils import (
2020
CadencePassAttribute,
@@ -83,7 +83,7 @@ def get_passes_in_default_order() -> List[ExportPass]:
8383
CadenceSimplifyOpsInGraph.passes,
8484
FinalizePipeline,
8585
FuseFullThenReshapePass,
86-
FuseTransposeOpPairsPass,
86+
FuseTransposeOrPermuteOpPairsPass,
8787
RemoveNopSliceOrViewOpPass,
8888
]
8989
return pytree.tree_flatten(passes)[0]

0 commit comments

Comments
 (0)