from typing import cast

from executorch.backends.cadence.aot.pass_utils import (
    CadencePassAttribute,
    register_cadence_pass,
    RemoveOrReplacePassInterface,
)
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
from executorch.exir.dialects._ops import ops as exir_ops
@@ -454,7 +455,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
454455
455456
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface):
    """
    Fuse a cascaded chain of transpose and permute ops into a single
    permute op (or eliminate the chain entirely when its net effect is
    the identity permutation).
    """

    # Ops that may participate in a fusible chain.
    # NOTE(review): the opening of this set literal fell inside a diff-hunk
    # gap in the reviewed text; transpose_copy.int is assumed to be a member
    # because maybe_remove_or_replace explicitly dispatches on it — confirm
    # against the full file.
    transpose_or_permute_target = {
        exir_ops.edge.aten.transpose_copy.int,
        exir_ops.edge.aten.permute_copy.default,
    }

    @property
    def targets(self) -> list[EdgeOpOverload]:
        # The interface visits nodes whose target is in this list.
        return list(self.transpose_or_permute_target)

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        """
        Try to fuse the cascaded transpose/permute chain rooted at `node`.

        Returns True when the chain was fused (the interface then removes
        `node`, the now-unused head of the chain), False when there is
        nothing to do.
        """
        # Get the cascaded chain of transpose/permute ops starting at node.
        cascaded_transpose_or_permute_ops = get_cascaded_ops(
            [node], self.transpose_or_permute_target
        )
        # The chain must have more than 1 node for fusion to be worthwhile.
        if len(cascaded_transpose_or_permute_ops) == 1:
            return False

        # Derive the rank of the output from node metadata; bail out if the
        # shape is unknown.
        val = node.meta.get("val")
        if val is None:
            return False
        out_shape = val.shape
        out_dims = len(out_shape)

        # Start from the trivial dimension order and accumulate the effect
        # of every op in the chain on it.
        dims = list(range(out_dims))
        for tp in cascaded_transpose_or_permute_ops:
            dims = (
                get_transposed_dims(tp, dims)
                if tp.target == exir_ops.edge.aten.transpose_copy.int
                else get_permuted_dims(tp, dims)
            )

        graph = node.graph

        # If the chain cancelled itself out, the final dims equal the initial
        # order and the whole chain is a nop: route users of the last op
        # straight to the chain's input. Otherwise, create one permute op
        # that encompasses the effect of the entire chain.
        if dims == list(range(out_dims)):
            cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(
                cast(torch.fx.Node, node.args[0])
            )
        else:
            with graph.inserting_before(cascaded_transpose_or_permute_ops[-1]):
                new_permute = graph.call_function(
                    exir_ops.edge.aten.permute_copy.default,
                    args=(node.args[0], dims),
                )
            # Preserve the metadata (val/shape info) of the op being replaced.
            new_permute.meta = cascaded_transpose_or_permute_ops[-1].meta
            cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(new_permute)

        # Erase the chain tail-first so no node is erased while it still has
        # users; the first node (the chain head) is left for the interface,
        # which removes it when this method returns True.
        for tp in reversed(cascaded_transpose_or_permute_ops[1:]):
            graph.erase_node(tp)

        # True tells the interface the head node should be removed.
        return True
525524
526525@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
0 commit comments