Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 12 additions & 18 deletions backends/cadence/aot/compiler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,12 @@ def get_cascaded_ops(
return nodes


# Capture the effect of transpose op on incoming dimension order
def get_transposed_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
def get_transposed_dims(
node: torch.fx.Node, dims: Optional[List[int]] = None
) -> List[int]:
"""
Given a transpose node, and the incoming dimension ordering of the input
tensor to the transpose node, return the net effect of transpose op on the
dimension order.
Applies the transposition as given by node onto the dimensions given in input
e.g (1, 2) on [5, 6, 7, 8] would return [5, 7, 6, 8]
"""
assert node.target == exir_ops.edge.aten.transpose_copy.int
# Assert that the dims is not empty
Expand All @@ -127,28 +127,22 @@ def get_transposed_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
assert isinstance(transpose_dims1, int)
dim0 = transpose_dims0 if transpose_dims0 >= 0 else transpose_dims0 + dim_len
dim1 = transpose_dims1 if transpose_dims1 >= 0 else transpose_dims1 + dim_len
# Perform transpose on dimmension ordering (dims)
dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
return dims
new_dims = list(dims)
new_dims[dim0], new_dims[dim1] = dims[dim1], dims[dim0]
return new_dims


# Capture the effect of permute op on incoming dimension order
def get_permuted_dims(node: torch.fx.Node, dims: Optional[Sequence[int]]) -> List[int]:
def get_permuted_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
"""
Given a permute node, and the incoming dimension ordering of the input
tensor to the permute node, return the net effect of permute op on the
dimension order.
Applies the permutation as given by node onto the dimensions given in input
e.g (1, 2, 0) on [5, 6, 7] would return [6, 7, 5]
"""
assert node.target == exir_ops.edge.aten.permute_copy.default
# Permute each index of the dimension ordering (dims)
# pyre-fixme[6]: This combined typecheck isn't supported yet.
permute_dims: List[int] = list(node.args[1])
assert all(isinstance(x, int) for x in permute_dims)
# If the dims is empty, we can simply return the permute order
if not dims:
return permute_dims
dims = [dims[x] for x in permute_dims]
return dims
return [dims[x] for x in permute_dims]


# Return the tensor of buffer/parameter op
Expand Down
70 changes: 27 additions & 43 deletions backends/cadence/aot/fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import operator
from collections import deque
from numbers import Number
from typing import cast, Sequence
from typing import cast

# Import these for the cadence function signatures.
import executorch.backends.cadence.aot.ops_registrations # noqa: F401
Expand Down Expand Up @@ -881,9 +881,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseTransposeOpPairsPass(FuseOpPairsAcrossBranchesPass):
class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass):
"""
Fuse transpose op pairs to a single view op.
Fuse transpose or permute op pairs to a single view op.
(transpose or permutation) -> (quant or dequant) -> (transpose or permutation)
"""

# A list of ops that can be bypassed when looking for a
Expand All @@ -907,42 +908,19 @@ def can_fuse_for_chain(
if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets):
return False

def get_dims(node: torch.fx.Node) -> tuple[int, int]:
def canonicalize(dim: int) -> int:
if dim < 0:
dim += len(node.meta["val"].shape)
return dim

return tuple(canonicalize(cast(int, d)) for d in node.args[1:3])

def is_equivalent(
shape: Sequence[int],
transpose0: tuple[int, int],
transpose1: tuple[int, int],
) -> bool:
def permute_order(
order: Sequence[int], dims: tuple[int, int]
) -> Sequence[int]:
new_order = list(order)
new_order[dims[0]], new_order[dims[1]] = (
new_order[dims[1]],
new_order[dims[0]],
)
return new_order

order = permute_order(range(len(shape)), transpose0)
order = permute_order(order, transpose1)

non_unit_dims = [dim for dim in range(len(shape)) if shape[dim] != 1]
non_unit_dims_permuted = [dim for dim in order if shape[dim] != 1]

return non_unit_dims == non_unit_dims_permuted

return is_equivalent(
cast(torch.fx.Node, producer.args[0]).meta["val"].shape,
get_dims(producer),
get_dims(consumer),
)
# checking that permut2(permut1(identify)) == identity
# pyre-fixme[16]: `None` has no attribute `meta`
ident_dims = list(range(len(producer.args[0].meta["val"].shape)))
# this mapping helps to handle both transpose and permutations
f = {
exir_ops.edge.aten.transpose_copy.int: get_transposed_dims,
exir_ops.edge.aten.permute_copy.default: get_permuted_dims,
}
# pyre-fixme[29]: VT is not a function
in_dims = f[producer.target](producer, ident_dims)
# pyre-fixme[29]: VT is not a function
out_dims = f[consumer.target](consumer, in_dims)
return out_dims == ident_dims

def get_fused_node(
self,
Expand All @@ -960,11 +938,17 @@ def get_fused_node(
return view

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
# Remove any dequantize op that has only quantize ops as its users.
# Remove any transpose/permutation op pair that cancel each other.
self.find_and_fuse(
graph_module,
producer_op_packets={exir_ops.edge.aten.transpose_copy},
consumer_op_packets={exir_ops.edge.aten.transpose_copy},
producer_op_packets={
exir_ops.edge.aten.transpose_copy,
exir_ops.edge.aten.permute_copy,
},
consumer_op_packets={
exir_ops.edge.aten.transpose_copy,
exir_ops.edge.aten.permute_copy,
},
bypass_ops=self.bypass_ops,
)
result = super().call(graph_module)
Expand Down Expand Up @@ -1028,5 +1012,5 @@ class CadenceFuseOpsInGraph:
FuseQuantDequantToRequantizePass,
FuseMulIntoDequantPass,
FuseFullThenReshapePass,
FuseTransposeOpPairsPass,
FuseTransposeOrPermuteOpPairsPass,
]
4 changes: 2 additions & 2 deletions backends/cadence/aot/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from executorch.backends.cadence.aot.fuse_ops import (
CadenceFuseOpsInGraph,
FuseFullThenReshapePass,
FuseTransposeOpPairsPass,
FuseTransposeOrPermuteOpPairsPass,
)
from executorch.backends.cadence.aot.pass_utils import (
CadencePassAttribute,
Expand Down Expand Up @@ -83,7 +83,7 @@ def get_passes_in_default_order() -> List[ExportPass]:
CadenceSimplifyOpsInGraph.passes,
FinalizePipeline,
FuseFullThenReshapePass,
FuseTransposeOpPairsPass,
FuseTransposeOrPermuteOpPairsPass,
RemoveNopSliceOrViewOpPass,
]
return pytree.tree_flatten(passes)[0]
Expand Down
Loading
Loading