Skip to content

Commit 3f34709

Browse files
authored
More passes updated to be more efficient and correctly set their modified bit
Differential Revision: D87812526
Pull Request resolved: #16044
1 parent e1fae4e commit 3f34709

File tree

4 files changed

+269
-177
lines changed

4 files changed

+269
-177
lines changed

backends/cadence/aot/fuse_ops.py

Lines changed: 50 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@
3434
from executorch.backends.cadence.aot.pass_utils import (
3535
CadencePassAttribute,
3636
register_cadence_pass,
37+
RemoveOrReplacePassInterface,
3738
)
3839
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
3940
from executorch.exir.dialects._ops import ops as exir_ops
@@ -454,7 +455,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
454455

455456

456457
@register_cadence_pass(CadencePassAttribute(opt_level=1))
457-
class FuseCascadedTransposeOrPermuteOps(ExportPass):
458+
class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface):
458459
"""
459460
Fuse a cascaded chain of transpose and permute ops
460461
"""
@@ -464,63 +465,61 @@ class FuseCascadedTransposeOrPermuteOps(ExportPass):
464465
exir_ops.edge.aten.permute_copy.default,
465466
}
466467

467-
# Find a chain of transpose or permute ops, and fuse them into a single permute op.
468+
@property
469+
def targets(self) -> list[EdgeOpOverload]:
470+
return list(self.transpose_or_permute_target)
468471

469-
def fuse_cascaded_transpose_or_permute_ops(
470-
self, graph_module: torch.fx.GraphModule
471-
):
472-
graph = graph_module.graph
473-
for node in graph.nodes:
474-
# We are only interested in permute/transpose ops
475-
if node.target not in self.transpose_or_permute_target:
476-
continue
477-
# Get the cascaded chain of transpose/permute ops starting at node
478-
cascaded_transpose_or_permute_ops = get_cascaded_ops(
479-
[node], self.transpose_or_permute_target
472+
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
473+
# Get the cascaded chain of transpose/permute ops starting at node
474+
cascaded_transpose_or_permute_ops = get_cascaded_ops(
475+
[node], self.transpose_or_permute_target
476+
)
477+
# The chain must have more than 1 node
478+
if len(cascaded_transpose_or_permute_ops) == 1:
479+
return False
480+
481+
# Get shape from node metadata
482+
val = node.meta.get("val")
483+
if val is None:
484+
return False
485+
out_shape = val.shape
486+
out_dims = len(out_shape)
487+
488+
# This is the trivial dimension order
489+
dims = list(range(out_dims))
490+
# Compute the effect of the chain on dims
491+
for tp in cascaded_transpose_or_permute_ops:
492+
dims = (
493+
get_transposed_dims(tp, dims)
494+
if tp.target == exir_ops.edge.aten.transpose_copy.int
495+
else get_permuted_dims(tp, dims)
480496
)
481-
# The chain must have more than 1 node
482-
if len(cascaded_transpose_or_permute_ops) == 1:
483-
continue
484497

485-
out_shape = get_shape(graph_module, node)
486-
assert out_shape is not None
487-
out_dims = len(out_shape)
488-
# This is the trivial dimension order
489-
dims = list(range(out_dims))
490-
# Compute the effect of the chain on dims
491-
for tp in cascaded_transpose_or_permute_ops:
492-
dims = (
493-
get_transposed_dims(tp, dims)
494-
if tp.target == exir_ops.edge.aten.transpose_copy.int
495-
else get_permuted_dims(tp, dims)
496-
)
498+
graph = node.graph
497499

498-
# In case the permute chain cancelled each other, the final dims will
499-
# be the same as the initial order. In that case, the chain was nop.
500-
# Otherwise create a new permute op that encompasses the effect of the
501-
# chain.
502-
if dims == list(range(out_dims)):
503-
cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(
504-
node.args[0]
500+
# In case the permute chain cancelled each other, the final dims will
501+
# be the same as the initial order. In that case, the chain was nop.
502+
# Otherwise create a new permute op that encompasses the effect of the
503+
# chain.
504+
if dims == list(range(out_dims)):
505+
cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(
506+
cast(torch.fx.Node, node.args[0])
507+
)
508+
else:
509+
with graph.inserting_before(cascaded_transpose_or_permute_ops[-1]):
510+
new_permute = graph.call_function(
511+
exir_ops.edge.aten.permute_copy.default,
512+
args=(node.args[0], dims),
505513
)
506-
else:
507-
with graph.inserting_before(cascaded_transpose_or_permute_ops[-1]):
508-
new_permute = graph.call_function(
509-
exir_ops.edge.aten.permute_copy.default,
510-
args=(node.args[0], dims),
511-
)
512-
cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(new_permute)
514+
new_permute.meta = cascaded_transpose_or_permute_ops[-1].meta
515+
cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(new_permute)
513516

514-
# Now erase the chain
515-
for tp in reversed(cascaded_transpose_or_permute_ops):
516-
graph.erase_node(tp)
517-
518-
graph_module.recompile()
517+
# Now erase the chain (except the first node which will be handled by the interface)
518+
for tp in reversed(cascaded_transpose_or_permute_ops[1:]):
519+
graph.erase_node(tp)
519520

520-
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
521-
self.fuse_cascaded_transpose_or_permute_ops(graph_module)
522-
result = super().call(graph_module)
523-
return result
521+
# Return True to indicate the first node in the chain should be removed
522+
return True
524523

525524

526525
@register_cadence_pass(CadencePassAttribute(opt_level=1))

0 commit comments

Comments
 (0)