|
17 | 17 | # in a context outside of Jarvis', so exercise caution while invoking this in a |
18 | 18 | # pass list outside of Jarvis. |
19 | 19 |
|
20 | | -import itertools |
21 | 20 | import logging |
22 | 21 | from dataclasses import dataclass, field |
23 | | -from typing import Callable, cast, Dict, Iterable, List, Optional, Sequence, Union |
| 22 | +from typing import cast, List, Optional, Sequence |
24 | 23 |
|
25 | 24 | import torch |
26 | 25 | import torch.fx |
@@ -538,211 +537,175 @@ def call_operator( |
538 | 537 | return super().call_operator(op, args, kwargs, meta) |
539 | 538 |
|
540 | 539 |
|
541 | | -@register_cadence_pass(CadencePassAttribute(opt_level=1)) |
| 540 | +@register_cadence_pass(CadencePassAttribute(opt_level=2)) |
542 | 541 | class RemovePermutesAroundElementwiseOps(ExportPass): |
543 | 542 | """ |
544 | 543 | Looks for subgraphs of elementwise ops sandwiched between permutes and removes those |
545 | | - permutes if possible. This pass is targeted at models where delegated subgraphs |
546 | | - must be in NHWC format, so there's usually a to_NHWC permute before each delegate and |
547 | | - a to_NCHW permute after it. If all the ops between two delegates are elementwise ops |
548 | | - then these permutes can be safely removed. |
549 | | - Allows special handling for certain non-elementwise ops that can be easily updated based on |
550 | | - the permute's parameter, such as mean and cat |
| 544 | + permutes if possible. |
| 545 | + Allows special handling for certain non-elementwise ops that can be easily updated |
| 546 | +     based on the permute's parameter, such as mean, cat, and slice.
551 | 547 | """ |
552 | 548 |
|
553 | 549 | @dataclass() |
554 | 550 | class Subgraph: |
555 | | - """ |
556 | | - Keeps track of nodes grouped as a subgraph between two sets of permutes |
557 | | - """ |
558 | | - |
559 | | - start_permutes: set[torch.fx.Node] = field(default_factory=set) |
560 | | - end_permutes: set[torch.fx.Node] = field(default_factory=set) |
561 | | - intermediate_nodes: set[torch.fx.Node] = field(default_factory=set) |
562 | | - is_valid: bool = True |
563 | | - |
564 | | - elementwise_ops: set[EdgeOpOverload] = { |
| 551 | + start_permute: list[int] |
| 552 | + end_permute: list[int] |
| 553 | +         # Nodes in the subgraph; does not include permutes.
| 554 | + nodes: set[torch.fx.Node] = field(default_factory=set) |
| 555 | + # Incoming edges to the subgraph from permute nodes. |
| 556 | + edges_in: set[tuple[torch.fx.Node, torch.fx.Node]] = field(default_factory=set) |
| 557 | + # Outgoing edges of the subgraph to permute nodes. |
| 558 | + edges_out: set[tuple[torch.fx.Node, torch.fx.Node]] = field(default_factory=set) |
| 559 | + |
| 560 | + permutable_ops: set[EdgeOpOverload] = { |
565 | 561 | exir_ops.edge.aten.add.Tensor, |
566 | 562 | exir_ops.edge.aten.mul.Tensor, |
567 | | - exir_ops.edge.aten.mean.dim, |
568 | | - exir_ops.edge.aten.cat.default, |
569 | 563 | exir_ops.edge.aten.hardtanh.default, |
570 | 564 | exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
571 | 565 | exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
572 | 566 | exir_ops.edge.cadence.quantize_per_tensor.default, |
573 | 567 | exir_ops.edge.cadence.dequantize_per_tensor.default, |
| 568 | + # Ops that require special handling. |
| 569 | + exir_ops.edge.aten.cat.default, |
| 570 | + exir_ops.edge.aten.mean.dim, |
| 571 | + exir_ops.edge.aten.slice_copy.Tensor, |
574 | 572 | } |
575 | 573 |
|
576 | | - # must be initialized in the constructor |
577 | | - special_handling: Dict[EdgeOpOverload, Callable[[torch.fx.Node], None]] = {} |
578 | | - |
579 | | - to_NCHW = [0, 3, 1, 2] |
580 | | - to_NHWC = [0, 2, 3, 1] |
581 | | - |
582 | | - def __init__(self) -> None: |
583 | | - super().__init__() |
584 | | - self.visited: set[object] = set() |
585 | | - self.special_handling = { |
586 | | - exir_ops.edge.aten.mean.dim: self.handle_mean_dim, |
587 | | - exir_ops.edge.aten.cat.default: self.handle_cat, |
588 | | - } |
589 | | - |
590 | 574 | def call(self, graph_module: torch.fx.GraphModule) -> PassResult: |
591 | | - self.visited = set() |
| 575 | + subgraphs_found: list[RemovePermutesAroundElementwiseOps.Subgraph] = [] |
| 576 | + processed_nodes: set[torch.fx.Node] = set() |
592 | 577 | for node in graph_module.graph.nodes: |
593 | | - sg = self.Subgraph() |
594 | | - self.start_search(node, sg) |
595 | | - if self.is_valid_subgraph(sg): |
596 | | - logging.debug(f"Found valid subgraph: {sg}") |
597 | | - self.handle_subgraph(graph_module, sg) |
| 578 | + if node.target != exir_ops.edge.aten.permute_copy.default: |
| 579 | + continue |
598 | 580 |
|
599 | | - result = super().call(graph_module) |
600 | | - return result |
| 581 | + start_permute = self.get_permutation(node) |
| 582 | + # Expected end permutation for the subgraph. |
| 583 | + end_permute = [start_permute.index(i) for i in range(len(start_permute))] |
601 | 584 |
|
602 | | - def handle_mean_dim(self, mean_dim: torch.fx.Node) -> None: |
603 | | - assert mean_dim.target == exir_ops.edge.aten.mean.dim |
604 | | - args = list(mean_dim.args) |
605 | | - args[1] = [self.to_NCHW[dim] for dim in cast(list[int], args[1])] |
606 | | - mean_dim.args = tuple(args) |
| 585 | + for user in node.users: |
| 586 | + if user.target not in self.permutable_ops: |
| 587 | + continue |
| 588 | + # Create a separate subgraph for each user since there may be cases |
| 589 | + # where only a portion of the users are permutable. |
| 590 | + subgraph = self.Subgraph(start_permute, end_permute) |
| 591 | + if self.visit(user, subgraph, processed_nodes): |
| 592 | + subgraphs_found.append(subgraph) |
| 593 | + for node in subgraph.nodes: |
| 594 | + processed_nodes.add(node) |
607 | 595 |
|
608 | | - def handle_cat(self, cat: torch.fx.Node) -> None: |
609 | | - assert cat.target == exir_ops.edge.aten.cat.default |
610 | | - args = list(cat.args) |
611 | | - args[1] = self.to_NCHW[cast(int, args[1])] |
612 | | - cat.args = tuple(args) |
| 596 | + for subgraph in subgraphs_found: |
| 597 | + self.permute_subgraph(subgraph) |
613 | 598 |
|
614 | | - def is_valid_subgraph(self, sg: Subgraph) -> bool: |
615 | | - return ( |
616 | | - sg.is_valid |
617 | | - and len(sg.start_permutes) > 0 |
618 | | - and len(sg.end_permutes) > 0 |
619 | | - and len(sg.intermediate_nodes) > 0 |
620 | | - ) |
| 599 | + graph_module.graph.eliminate_dead_code() |
| 600 | + graph_module.recompile() |
621 | 601 |
|
622 | | - def handle_subgraph(self, graph_module: torch.fx.GraphModule, sg: Subgraph) -> None: |
623 | | - for permute in itertools.chain(sg.start_permutes, sg.end_permutes): |
624 | | - permute.replace_all_uses_with(permute.args[0]) # pyre-fixme[6] |
| 602 | + return super().call(graph_module) |
625 | 603 |
|
626 | | - for node in sg.intermediate_nodes: |
627 | | - if node.target in self.special_handling: |
628 | | - self.special_handling[node.target](node) |
| 604 | + def visit( |
| 605 | + self, |
| 606 | + node: torch.fx.Node, |
| 607 | + subgraph: Subgraph, |
| 608 | + processed_nodes: set[torch.fx.Node], |
| 609 | + ) -> bool: |
| 610 | + if node in subgraph.nodes: |
| 611 | + return True |
| 612 | + if node in processed_nodes or not self.is_node_permutable(node): |
| 613 | + return False |
| 614 | + subgraph.nodes.add(node) |
| 615 | + |
| 616 | + # Traverse downstream: |
| 617 | + for user in node.users: |
| 618 | + # Output should either go to a matching permute or another permutable op. |
| 619 | + if user.target == exir_ops.edge.aten.permute_copy.default: |
| 620 | + if self.get_permutation(user) != subgraph.end_permute: |
| 621 | + return False |
| 622 | + subgraph.edges_out.add((node, user)) |
| 623 | + elif not self.visit(user, subgraph, processed_nodes): |
| 624 | + return False |
629 | 625 |
|
630 | | - graph_module.recompile() |
631 | | - graph_module.graph.eliminate_dead_code() |
| 626 | + # Traverse upstream: |
| 627 | + for inp in node.all_input_nodes: |
| 628 | + # Input should either come from a matching permute or another permutable op. |
| 629 | + if inp.target == exir_ops.edge.aten.permute_copy.default: |
| 630 | + if self.get_permutation(inp) != subgraph.start_permute: |
| 631 | + return False |
| 632 | + subgraph.edges_in.add((inp, node)) |
| 633 | + elif not self.visit(inp, subgraph, processed_nodes): |
| 634 | + return False |
632 | 635 |
|
633 | | - def start_search(self, node: torch.fx.Node, sg: Subgraph) -> None: |
634 | | - if node in self.visited: |
635 | | - return |
| 636 | + return True |
636 | 637 |
|
637 | | - if self.is_starting_permute(node): |
638 | | - sg.start_permutes.add(node) |
639 | | - self.visited.add(node) |
640 | | - for user in node.users: |
641 | | - self.search_down(user, sg) |
642 | | - |
643 | | - def search_up(self, node: object, sg: Subgraph) -> None: |
644 | | - # non-nodes can be ignored. These would be arguments like integers or lists |
645 | | - # of integers, which don't affect the subgraph validity or inclusion set. |
646 | | - if not isinstance(node, torch.fx.Node): |
647 | | - return |
648 | | - |
649 | | - if node.op == "placeholder": |
650 | | - # If we reach a placeholder or other terminal node without encountering |
651 | | - # a start permute, then the subgraph is invalid. |
652 | | - # This could be because in the add(x, y) case where x is permuted and |
653 | | - # y is a graph input, we can't remove the permute on x because it might |
654 | | - # become two different shapes that don't broadcast together. |
655 | | - # TODO: Adding a permute on y could be the more optimal solution, |
656 | | - # but perhaps not in all cases, say if x is small and y is very large. |
657 | | - # This transform prefers to be safe over optimal for now. |
658 | | - sg.is_valid = False |
659 | | - return |
660 | | - |
661 | | - if node in self.visited: |
662 | | - return |
663 | | - |
664 | | - self.visited.add(node) |
665 | | - |
666 | | - if self.is_starting_permute(node): |
667 | | - sg.start_permutes.add(node) |
668 | | - for user in node.users: |
669 | | - self.search_down(user, sg) |
670 | | - else: |
671 | | - self.traverse_intermediate_node(node, sg) |
672 | | - |
673 | | - def search_down(self, node: torch.fx.Node, sg: Subgraph) -> None: |
674 | | - if node in self.visited or self.is_starting_permute(node): |
675 | | - return |
676 | | - |
677 | | - self.visited.add(node) |
678 | | - |
679 | | - if self.is_ending_permute(node): |
680 | | - sg.end_permutes.add(node) |
681 | | - for arg in node.args: |
682 | | - if isinstance(arg, list): |
683 | | - for elem in arg: |
684 | | - self.search_up(elem, sg) |
685 | | - else: |
686 | | - self.search_up(arg, sg) |
| 638 | + def is_node_permutable(self, node: torch.fx.Node) -> bool: |
| 639 | + if node.target not in self.permutable_ops: |
| 640 | + return False |
| 641 | + if node.target == exir_ops.edge.aten.mean.dim: |
| 642 | + # keepdim should be True. |
| 643 | + if len(node.args) >= 3: |
| 644 | + if not node.args[2]: |
| 645 | + return False |
| 646 | + elif "keepdim" in node.kwargs: |
| 647 | + if not node.kwargs["keepdim"]: |
| 648 | + return False |
| 649 | + else: |
| 650 | + # Default keepdim is False. |
| 651 | + return False |
| 652 | + return True |
| 653 | + |
| 654 | + def permute_subgraph(self, subgraph: Subgraph) -> None: |
| 655 | + # Skip incoming permutes. |
| 656 | + for inp, out in subgraph.edges_in: |
| 657 | + assert inp.target == exir_ops.edge.aten.permute_copy.default |
| 658 | + if len(inp.args) >= 1: |
| 659 | + out.replace_input_with(inp, cast(torch.fx.Node, inp.args[0])) |
| 660 | + else: |
| 661 | + out.replace_input_with(inp, cast(torch.fx.Node, inp.kwargs["input"])) |
| 662 | + |
| 663 | + # Skip outgoing permutes. |
| 664 | + for inp, out in subgraph.edges_out: |
| 665 | + assert out.target == exir_ops.edge.aten.permute_copy.default |
| 666 | + out.replace_all_uses_with(inp) |
| 667 | + |
| 668 | +        # Handle dimension-related node arguments.
| 669 | + for node in subgraph.nodes: |
| 670 | + if node.target == exir_ops.edge.aten.cat.default: |
| 671 | + self.update_cat(node, subgraph.start_permute) |
| 672 | + elif node.target == exir_ops.edge.aten.mean.dim: |
| 673 | + self.update_mean_dim(node, subgraph.start_permute) |
| 674 | + elif node.target == exir_ops.edge.aten.slice_copy.Tensor: |
| 675 | + self.update_slice_copy(node, subgraph.start_permute) |
| 676 | + |
| 677 | + def update_cat(self, node: torch.fx.Node, start_permute: list[int]) -> None: |
| 678 | + if len(node.args) >= 2: |
| 679 | + node.update_arg(1, start_permute[cast(int, node.args[1])]) |
| 680 | + elif "dim" in node.kwargs: |
| 681 | + node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])]) |
687 | 682 | else: |
688 | | - self.traverse_intermediate_node(node, sg) |
689 | | - |
690 | | - def traverse_intermediate_node(self, node: torch.fx.Node, sg: Subgraph) -> None: |
691 | | - if node.target in self.elementwise_ops: |
692 | | - sg.intermediate_nodes.add(node) |
693 | | - for arg in node.args: |
694 | | - if isinstance(arg, list): |
695 | | - for elem in arg: |
696 | | - self.search_up(elem, sg) |
697 | | - else: |
698 | | - self.search_up(arg, sg) |
699 | | - |
700 | | - for user in node.users: |
701 | | - self.search_down(user, sg) |
| 683 | + # Default cat dim is 0. |
| 684 | + node.update_kwarg("dim", start_permute[0]) |
702 | 685 |
|
703 | | - else: |
704 | | - sg.is_valid = False |
705 | | - |
706 | | - def is_starting_permute(self, node: torch.fx.Node) -> bool: |
707 | | - return self.is_boundary_permute(node, self.to_NCHW) |
708 | | - |
709 | | - def is_ending_permute(self, node: torch.fx.Node) -> bool: |
710 | | - return self.is_boundary_permute(node, self.to_NHWC) |
711 | | - |
712 | | - @staticmethod |
713 | | - def is_boundary_permute(node: torch.fx.Node, permute_dims: Iterable[int]) -> bool: |
714 | | - permute_dims = list(permute_dims) |
715 | | - if node.target == exir_ops.edge.aten.permute_copy.default: |
716 | | - return cast(list[int], node.args[1]) == permute_dims |
717 | | - elif node.target == exir_ops.edge.aten.view_copy.default: |
718 | | - # If there's a view node, check if it's swapping two dimensions and |
719 | | - # not splitting any others from the input shape. |
720 | | - inp = node.args[0] |
721 | | - if not isinstance(inp, torch.fx.Node): |
722 | | - return False |
723 | | - input_shape = inp.meta["val"].shape |
724 | | - output_shape = node.args[1] |
725 | | - assert isinstance(output_shape, (tuple, list)) |
726 | | - # If the shapes are equal in length, no dimension is being split or |
727 | | - # grouped. Then check if a permute of the input shape results in the output shape. |
728 | | - return ( |
729 | | - len(input_shape) == len(output_shape) |
730 | | - and len(input_shape) == len(permute_dims) |
731 | | - and RemovePermutesAroundElementwiseOps.permute_shape( |
732 | | - input_shape, permute_dims |
733 | | - ) |
734 | | - == output_shape |
| 686 | + def update_mean_dim(self, node: torch.fx.Node, start_permute: list[int]) -> None: |
| 687 | + if len(node.args) >= 2: |
| 688 | + node.update_arg( |
| 689 | + 1, [start_permute[dim] for dim in cast(list[int], node.args[1])] |
735 | 690 | ) |
736 | 691 | else: |
737 | | - return False |
| 692 | + node.update_kwarg( |
| 693 | + "dim", |
| 694 | + [start_permute[dim] for dim in cast(list[int], node.kwargs["dim"])], |
| 695 | + ) |
738 | 696 |
|
739 | | - @staticmethod |
740 | | - def permute_shape( |
741 | | - shape: Union[List[int], torch.Size], permute_dims: Iterable[int] |
742 | | - ) -> List[int]: |
743 | | - permute_dims = list(permute_dims) |
744 | | - assert len(shape) == len(permute_dims) |
745 | | - return [shape[p] for p in permute_dims] |
| 697 | + def update_slice_copy(self, node: torch.fx.Node, start_permute: list[int]) -> None: |
| 698 | + if len(node.args) >= 2: |
| 699 | + node.update_arg(1, start_permute[cast(int, node.args[1])]) |
| 700 | + else: |
| 701 | + node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])]) |
| 702 | + |
| 703 | + def get_permutation(self, permute_node: torch.fx.Node) -> list[int]: |
| 704 | + assert permute_node.target == exir_ops.edge.aten.permute_copy.default |
| 705 | + if len(permute_node.args) >= 2: |
| 706 | + return cast(list[int], permute_node.args[1]) |
| 707 | +        assert "dims" in permute_node.kwargs
| 708 | +        return cast(list[int], permute_node.kwargs["dims"])
746 | 709 |
|
747 | 710 |
|
748 | 711 | @register_cadence_pass(CadencePassAttribute(opt_level=1)) |
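
For intuition about why dropping the permutes is safe for purely elementwise subgraphs, here is a minimal standalone PyTorch sketch (illustrative only, not part of this commit; the shapes and the NCHW-to-NHWC permutation are arbitrary): permuting the inputs, applying the elementwise op, and permuting back with the inverse permutation gives the same result as applying the op directly, so the pair of permutes is redundant.

```python
import torch

x = torch.randn(2, 3, 4, 5)
y = torch.randn(2, 3, 4, 5)

start_permute = [0, 2, 3, 1]  # e.g. NCHW -> NHWC
# Inverse permutation, derived the same way the pass computes end_permute.
end_permute = [start_permute.index(i) for i in range(len(start_permute))]

# permute -> elementwise add -> inverse permute ...
with_permutes = torch.permute(
    torch.permute(x, start_permute) + torch.permute(y, start_permute),
    end_permute,
)

# ... matches the add on the unpermuted tensors, so both permutes can be removed.
assert torch.equal(with_permutes, x + y)
```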
|
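Ops like cat, mean.dim, and slice_copy are not elementwise but only refer to dimensions by index, which is why the pass remaps their dim arguments through start_permute instead of rejecting them. A sketch of that remapping for cat (again standalone and illustrative; the index arithmetic mirrors update_cat in the diff above):

```python
import torch

x = torch.randn(2, 3, 4, 5)
y = torch.randn(2, 3, 4, 5)

start_permute = [0, 2, 3, 1]  # e.g. NCHW -> NHWC
end_permute = [start_permute.index(i) for i in range(len(start_permute))]

dim = 3  # cat dimension as seen in the permuted (NHWC) layout

# Before the pass: permute inputs, cat in the permuted layout, permute back.
with_permutes = torch.permute(
    torch.cat(
        [torch.permute(x, start_permute), torch.permute(y, start_permute)], dim=dim
    ),
    end_permute,
)

# After the pass: permutes removed and the cat dim remapped through start_permute
# (dim 3 in NHWC corresponds to dim start_permute[3] == 1 in NCHW).
without_permutes = torch.cat([x, y], dim=start_permute[dim])

assert torch.equal(with_permutes, without_permutes)
```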