25 | 25 | import torch.fx |
26 | 26 | from executorch.backends.cadence.aot.pass_utils import ( |
27 | 27 |     CadencePassAttribute, |
| 28 | +    get_arg, |
28 | 29 |     register_cadence_pass, |
| 30 | +    set_arg, |
29 | 31 | ) |
30 | 32 |
31 | 33 | from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass |
37 | 39 | from executorch.exir.pass_manager import PassManager, PassType |
38 | 40 | from executorch.exir.passes import dead_code_elimination_pass |
39 | 41 | from executorch.exir.passes.spec_prop_pass import SpecPropPass |
40 | | -from torch.fx.node import Argument |
| 42 | +from torch.fx.node import Argument, Node |
41 | 43 |
42 | 44 |
43 | 45 | @register_cadence_pass(CadencePassAttribute(opt_level=0)) |
@@ -771,65 +773,52 @@ def remove_branched( |
771 | 773 |
772 | 774 |
773 | 775 | class RemoveCatFromSliceCopyPass(ExportPass): |
774 | | -    def _remove_unused_cat(  # noqa: C901 |
775 | | -        self, graph_module: torch.fx.GraphModule |
776 | | -    ) -> None: |
777 | | -        slice_copy_nodes = [ |
778 | | -            node |
779 | | -            for node in graph_module.graph.nodes |
780 | | -            if node.target == exir_ops.edge.aten.slice_copy.Tensor |
781 | | -        ] |
782 | | -        for slice_copy_node in slice_copy_nodes: |
783 | | -            slice_dim, start_idx, end_idx, step = 0, 0, float("inf"), 1 |
784 | | -            input_node, *other_args = slice_copy_node.args |
785 | | -            if len(other_args) >= 1: |
786 | | -                slice_dim = other_args[0] |
787 | | -            if len(other_args) >= 2: |
788 | | -                start_idx = other_args[1] |
789 | | -            if len(other_args) >= 3: |
790 | | -                end_idx = other_args[2] |
791 | | -            if len(other_args) >= 4: |
792 | | -                step = other_args[3] |
793 | | -            if step != 1: |
794 | | -                continue |
795 | | -            slice_copy_dtype = slice_copy_node.meta["val"].dtype |
796 | | -            if input_node.target != exir_ops.edge.aten.cat.default: |
797 | | -                continue |
798 | | -            cat_dtype = input_node.meta["val"].dtype |
799 | | -            if slice_copy_dtype != cat_dtype: |
| 776 | +    """ |
| 777 | +    Simplifies cat->slice_copy chains where one of the cat inputs can be directly passed |
| 778 | +    to the slice_copy. |
| 779 | +    """ |
| 780 | + |
| 781 | +    def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None: |
| 782 | +        for slice_copy_node in graph_module.graph.find_nodes( |
| 783 | +            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor |
| 784 | +        ): |
| 785 | +            cat_node = cast(Node, get_arg(slice_copy_node, 0, "input")) |
| 786 | +            slice_dim = cast(int, get_arg(slice_copy_node, 1, "dim", default=0)) |
| 787 | +            start_idx = cast(Optional[int], get_arg(slice_copy_node, 2, "start", default=None)) |
| 788 | +            end_idx = cast(Optional[int], get_arg(slice_copy_node, 3, "end", default=None)) |
| 789 | +            step = cast(int, get_arg(slice_copy_node, 4, "step", default=1)) |
| 790 | + |
| 791 | +            if cat_node.target != exir_ops.edge.aten.cat.default or step != 1: |
800 | 792 |                 continue |
801 | | -            cat_dim = input_node.args[1:] |
802 | | -            if len(cat_dim) == 0: |
803 | | -                cat_dim = 0 |
| 793 | + |
| 794 | +            # Make sure cat and slice happen on the same dimension. |
| 795 | +            cat_dim = cast(int, get_arg(cat_node, 1, "dim", default=0)) |
804 | 796 |             if cat_dim != slice_dim: |
805 | 797 |                 continue |
806 | | -            cat_output_shape = input_node.meta["val"].shape |
807 | | -            start_idx = ( |
808 | | -                cat_output_shape[cat_dim] + start_idx if start_idx < 0 else start_idx |
809 | | -            ) |
810 | | -            end_idx = ( |
811 | | -                cat_output_shape[cat_dim] |
812 | | -                if end_idx > cat_output_shape[cat_dim] |
813 | | -                else end_idx |
814 | | -            ) |
815 | | -            base_idx = 0 |
816 | | -            cat_input_to_keep = None |
817 | | -            for cat_input_node in input_node.args[0]: |
818 | | -                cat_input_dtype = cat_input_node.meta["val"].dtype |
819 | | -                if slice_copy_dtype != cat_input_dtype: |
820 | | -                    continue |
| 798 | + |
| 799 | +            # Canonicalize slice indices. |
| 800 | +            cat_output_shape = cat_node.meta["val"].shape |
| 801 | +            if start_idx is None: |
| 802 | +                start_idx = 0 |
| 803 | +            elif start_idx < 0: |
| 804 | +                start_idx += cat_output_shape[cat_dim] |
| 805 | +            if end_idx is None or end_idx > cat_output_shape[cat_dim]: |
| 806 | +                end_idx = cat_output_shape[cat_dim] |
| 807 | +            elif end_idx < 0: |
| 808 | +                end_idx += cat_output_shape[cat_dim] |
| 809 | + |
| 810 | +            offset = 0 |
| 811 | +            for cat_input_node in cast(List[Node], get_arg(cat_node, 0, "tensors")): |
821 | 812 |                 cat_input_shape = cat_input_node.meta["val"].shape |
822 | 813 |
823 | | -                # check if the slice range overlaps with the cat range |
824 | | -                if ( |
825 | | -                    base_idx <= start_idx |
826 | | -                    and end_idx <= list(cat_input_shape)[cat_dim] + base_idx |
827 | | -                ): |
828 | | -                    cat_input_to_keep = cat_input_node |
| 814 | +                # Check if the slice range falls entirely within this cat input's range. |
| 815 | +                if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]: |
| 816 | +                    slice_copy_node.replace_input_with(cat_node, cat_input_node) |
| 817 | +                    set_arg(slice_copy_node, 2, "start", start_idx - offset) |
| 818 | +                    set_arg(slice_copy_node, 3, "end", end_idx - offset) |
829 | 819 |                     break |
830 | | -                base_idx += list(cat_input_shape)[cat_dim] |
831 | | -            if cat_input_to_keep is not None: |
832 | | -                slice_copy_node.replace_input_with(input_node, cat_input_to_keep) |
| 820 | + |
| 821 | +                offset += cat_input_shape[cat_dim] |
833 | 822 |
834 | 823 |     def call(self, graph_module: torch.fx.GraphModule) -> PassResult: |
835 | 824 |         self._remove_unused_cat(graph_module) |
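For context, here is a minimal sketch of the rewrite this pass performs, written against plain tensors rather than the FX graph. The tensors, shapes, and indices below are made up for illustration and are not part of this PR: when the sliced range falls entirely inside a single `cat` input, the `cat` is bypassed and the slice is re-based onto that input.

```python
import torch

# Hypothetical inputs: x is (2, 3) and y is (2, 5), so cat([x, y], dim=1) is (2, 8).
x = torch.randn(2, 3)
y = torch.randn(2, 5)

# Before the pass: slice_copy(cat([x, y], dim=1), dim=1, start=4, end=7, step=1).
before = torch.cat([x, y], dim=1)[:, 4:7]

# Columns 4..6 lie entirely inside y, whose offset within the cat output is 3,
# so the pass reads from y directly and shifts start/end by that offset:
# slice_copy(y, dim=1, start=4 - 3, end=7 - 3).
after = y[:, 1:4]

assert torch.equal(before, after)
```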
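The index canonicalization in `_remove_unused_cat` follows ordinary Python slice semantics. A standalone sketch of the same normalization, where the `canonicalize_slice` helper is hypothetical and not part of this PR:

```python
def canonicalize_slice(start, end, size):
    """Normalize slice bounds the way the pass does before comparing offsets."""
    # A missing start means "from the beginning"; negative starts count from the end.
    if start is None:
        start = 0
    elif start < 0:
        start += size
    # A missing or out-of-range end clamps to the dimension size;
    # negative ends count from the end.
    if end is None or end > size:
        end = size
    elif end < 0:
        end += size
    return start, end

assert canonicalize_slice(None, None, 8) == (0, 8)
assert canonicalize_slice(-4, -1, 8) == (4, 7)
assert canonicalize_slice(2, 100, 8) == (2, 8)
```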