|
7 | 7 | import itertools
|
8 | 8 | import logging
|
9 | 9 | import operator
|
10 |
| -from typing import Any, Callable, cast, Dict, List, Optional, Union |
| 10 | +from typing import Any, Callable, cast, Dict, List, Optional, Set, Union |
11 | 11 |
|
12 | 12 | import torch
|
13 | 13 |
|
@@ -528,6 +528,274 @@ def get_nodes(self, src_partition: SourcePartition) -> List[torch.fx.Node]: # n
|
528 | 528 | )
|
529 | 529 |
|
530 | 530 |
|
class XnnpackPartitioner(Partitioner):
    """
    Module and opname based partitioner for FP32 modules/ops listed in
    SUPPORTED_MODULES and SUPPORTED_OPS and statically quantized modules/ops
    listed in SUPPORTED_QUANT_MODULES and SUPPORTED_QUANT_OPS.
    """

    # Quantize ops produced by the static quantization workflow.
    _Q_OPS = [
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
    ]

    # Dequantize ops produced by the static quantization workflow.
    _DQ_OPS = [
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
    ]

    # Ops that compute quantization parameters at runtime.
    _QPARAM_OPS = [
        exir_ops.edge.quantized_decomposed.choose_qparams.tensor,
    ]

    # All quantization-related ops; always considered supported in quant mode.
    _QUANT_OPS = _Q_OPS + _DQ_OPS + _QPARAM_OPS

    def __init__(
        self,
        *,
        supported_modules: List[Callable] = SUPPORTED_MODULES,
        supported_ops: Optional[List[Callable]] = SUPPORTED_OPS,
        supported_quant_modules: List[Callable] = SUPPORTED_QUANT_MODULES,
        supported_quant_ops: Optional[List[Callable]] = SUPPORTED_QUANT_OPS,
        quant: Optional[bool] = None,
    ):
        """
        Args:
            supported_modules: FP32 source modules eligible for delegation.
            supported_ops: FP32 ops eligible for delegation; ``None`` means
                no extra FP32 ops.
            supported_quant_modules: statically quantized source modules
                eligible for delegation.
            supported_quant_ops: statically quantized ops eligible for
                delegation; the quant/dequant/qparam ops in ``_QUANT_OPS``
                are always added on top.
            quant: ``True`` partitions only quantized patterns, ``False``
                only FP32 patterns, ``None`` partitions both.
        """
        super().__init__()
        self.supported_modules = set(supported_modules)
        self.supported_ops = set(supported_ops or [])
        self.supported_quant_modules = set(supported_quant_modules)
        supported_quant_ops = supported_quant_ops or []
        self.supported_quant_ops = set(supported_quant_ops + self._QUANT_OPS)

        self.quant = quant

        self.delegation_spec = DelegationSpec(XnnpackBackend.__name__, [])
        self.partition_tags: Dict[str, DelegationSpec] = {}

    def get_supported_modules(self, quant: Optional[bool]) -> Set[Callable]:
        """
        Get supported modules for the given mode.

        ``True`` -> quantized modules only, ``False`` -> FP32 modules only,
        ``None`` -> union of both.
        """
        # NOTE: annotation fixed to Optional[bool]; the else branch below
        # already handled None, and get_module_partitions passes None through.
        if quant is True:
            return self.supported_quant_modules
        elif quant is False:
            return self.supported_modules
        else:
            return self.supported_modules | self.supported_quant_modules

    def get_supported_ops(self, quant: Optional[bool]) -> Set[Callable]:
        """
        Get supported ops for the given mode.

        ``True`` -> quantized ops only, ``False`` -> FP32 ops only,
        ``None`` -> union of both.
        """
        if quant is True:
            return self.supported_quant_ops
        elif quant is False:
            return self.supported_ops
        else:
            return self.supported_ops | self.supported_quant_ops

    @staticmethod
    def check_partitions(partitions: Union[dict, list]) -> bool:
        """
        Warn users if there aren't any matches.

        Returns True iff at least one partition was found.

        TODO: convert this into a stronger validation, may need a flag in
        `to_backend()` or partitioner __init__()
        """
        pl = len(partitions)
        if pl == 0:
            log.warning("Nothing can be partitioned!")
        else:
            # Lazy %-style args avoid formatting when INFO is disabled.
            log.info("Found %d subgraphs to be partitioned.", pl)
        return pl != 0

    def get_input_deps(  # noqa
        self, input_nodes: List[torch.fx.Node]
    ) -> List[torch.fx.Node]:
        """
        For each input node, walk up and pull necessary quant/attr nodes
        into the partition.

        Returns the collected dependency nodes, deduplicated; order is not
        guaranteed (set-backed).
        """
        nodes = set()
        for inp in input_nodes:
            if inp.target in self._DQ_OPS:
                # dequant node
                nodes.add(inp)

                # possible per_channel scale/zp for the dequant node args{1, 2}
                for i in [1, 2]:
                    node = inp.args[i]
                    if isinstance(node, torch.fx.Node) and node.op == "get_attr":
                        nodes.add(node)

                # quant node: a dq input is expected to be fed by a q node
                q_prod = inp.args[0]
                assert (
                    isinstance(q_prod, torch.fx.Node) and q_prod.target in self._Q_OPS
                )
                nodes.add(q_prod)

                # possible weight for the quant node arg{0}
                node = q_prod.args[0]
                if isinstance(node, torch.fx.Node) and node.op == "get_attr":
                    nodes.add(node)

                # possible nodes for quant node args{1, 2}
                for i in [1, 2]:
                    node = q_prod.args[i]
                    # possible choose_qparam: getitem over a _QPARAM_OPS call
                    if (
                        isinstance(node, torch.fx.Node)
                        and node.op == "call_function"
                        and node.target == operator.getitem
                    ):
                        parent = node.args[0]
                        if (
                            isinstance(parent, torch.fx.Node)
                            and parent.op == "call_function"
                            and parent.target in self._QPARAM_OPS
                        ):
                            nodes.add(node)
                            nodes.add(parent)

                    # possible per_channel scale/zp for the quant node
                    elif isinstance(node, torch.fx.Node) and node.op == "get_attr":
                        nodes.add(node)
        return list(nodes)

    def get_output_deps(self, output_nodes: List[torch.fx.Node]) -> List[torch.fx.Node]:
        """
        For each output node, check all the users and insert them into the
        partition if needed (a trailing q node plus its dq users).
        """
        nodes = []
        for output in output_nodes:
            for node in output.users:
                if node.target in self._Q_OPS:
                    nodes.append(node)
                    users = list(node.users.keys())
                    for dq_user in users:
                        # Bug fix: the message was missing its f-prefix, so
                        # the {placeholders} were never interpolated.
                        assert (
                            dq_user.target in self._DQ_OPS
                        ), f"Expecting a dq node(s) after a q node, but got target {dq_user.target} for {dq_user} node"
                        nodes.append(dq_user)
        return nodes

    def get_nodes(
        self, src_partition: SourcePartition, quant: bool
    ) -> List[torch.fx.Node]:
        """
        Return nodes from the source partition.

        In quant mode the surrounding quantization ops are pulled in by
        following the partition's input and output nodes.
        """
        if quant:
            # Insert quantization ops into src_partition by following the
            # input, output node.
            return (
                src_partition.nodes
                + self.get_input_deps(src_partition.input_nodes)
                + self.get_output_deps(src_partition.output_nodes)
            )
        else:
            return src_partition.nodes

    def qualify_nodes(self, input_nodes: List[torch.fx.Node]) -> bool:
        """
        Each node in the module (post decomposition) must satisfy the
        constraints specified for XNNPACK.

        Disqualify the whole module if one of the nodes fails to satisfy.
        """
        return all(
            XnnpackOperatorSupport.check_constraint(node) for node in input_nodes
        )

    def get_module_partitions(
        self, graph_module: torch.fx.GraphModule, quant: Optional[bool]
    ) -> List[List[torch.fx.Node]]:
        """
        Get all partitions in the torch.fx.GraphModule for the supported
        modules.

        With ``quant=None``, quantized matches take priority: an FP32 match
        is dropped when any of its nodes already belongs to a quantized
        match (tracked via a transient ``quant_match`` meta flag).
        """

        if quant is None:
            module_partitions = self.get_module_partitions(graph_module, True)
            # Mark every node claimed by a quantized match ...
            for node_list in module_partitions:
                for node in node_list:
                    node.meta["quant_match"] = True
            fp32_module_partitions = self.get_module_partitions(graph_module, False)
            # ... and keep only FP32 matches that overlap no quantized match.
            for node_list in fp32_module_partitions:
                for node in node_list:
                    if node.meta.get("quant_match", False):
                        break
                else:
                    module_partitions.append(node_list)
            # Clean up the transient flag so it doesn't leak downstream.
            for node_list in module_partitions:
                for node in node_list:
                    node.meta.pop("quant_match", False)
            return module_partitions

        src_partition_dict = get_source_partitions(
            graph_module.graph, self.get_supported_modules(quant)
        )
        all_partitions = src_partition_dict.values()

        module_partitions = []
        for src_partitions in all_partitions:
            for src_partition in src_partitions:
                partition_nodes = self.get_nodes(src_partition, quant)
                # Drop the whole match if any node violates XNNPACK constraints.
                if self.qualify_nodes(partition_nodes):
                    module_partitions.append(partition_nodes)

        return module_partitions

    def generate_partitions(
        self, graph_module: torch.fx.GraphModule, quant: Optional[bool]
    ) -> List[Any]:
        """
        Generate a list of partitions for an torch.fx.GraphModule.
        Also pass the supported ops to match.
        """
        matched_module_nodes = self.get_module_partitions(graph_module, quant)
        return generate_partitions_from_list_of_nodes(
            graph_module,
            matched_module_nodes,
            XnnpackOperatorSupport(supported_ops=list(self.get_supported_ops(quant))),
        )

    def tag_nodes(self, partitions: List[Partition]) -> None:
        """
        Tag each partition in the list with its delegation tag.

        A partition containing any already-tagged node is skipped so a node
        is never claimed twice.
        """
        for partition in partitions:
            # Add delegation tags, unless some node was tagged previously.
            if any("delegation_tag" in node.meta for node in partition.nodes):
                continue
            # The tag is invariant per partition; compute it (and register
            # the delegation spec) once instead of per node.
            delegation_tag = f"tag{partition.id}"
            for node in partition.nodes:
                node.meta["delegation_tag"] = delegation_tag
            self.partition_tags[delegation_tag] = self.delegation_spec

    # override
    def _partition(
        self, graph_module: torch.fx.GraphModule, quant: Optional[bool]
    ) -> torch.fx.GraphModule:
        """
        Run the partitioner on the given graph module, then tag each partition
        with its delegation tag (and partition id)
        """
        partitions = self.generate_partitions(graph_module, quant)
        if self.check_partitions(partitions):
            self.tag_nodes(partitions)
        return graph_module

    def partition(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Partition using the quant mode configured at construction time."""
        ret = self._partition(graph_module, self.quant)
        return ret

| 798 | + |
531 | 799 | class XnnpackDynamicallyQuantizedPartitioner(XnnpackQuantizedPartitioner):
|
532 | 800 | def __init__(
|
533 | 801 | self,
|
|
0 commit comments