|
9 | 9 | # See the License for the specific language governing permissions and |
10 | 10 | # limitations under the License. |
11 | 11 |
|
| 12 | +import operator |
12 | 13 | from copy import copy |
13 | 14 | from typing import Any, Callable, Optional, Union |
14 | 15 |
|
15 | 16 | import torch |
16 | 17 | import torch.fx |
17 | 18 | from torch.ao.quantization.fx.utils import create_getattr_from_value |
18 | 19 | from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_ |
| 20 | +from torch.fx.node import map_arg |
| 21 | +from torch.fx.passes.infra.pass_base import PassBase |
| 22 | +from torch.fx.passes.infra.pass_base import PassResult |
19 | 23 | from torch.quantization.fake_quantize import FakeQuantize |
20 | 24 |
|
21 | 25 | import nncf |
@@ -741,3 +745,65 @@ def constraint_fn(node: torch.fx.Node): |
741 | 745 | return node.op != "call_function" or node.target not in QUANTIZE_NODE_TARGETS + DEQUANTIZE_NODE_TARGETS |
742 | 746 |
|
743 | 747 | constant_fold(model, constraint_fn=constraint_fn) |
| 748 | + |
| 749 | + |
| 750 | +def _duplicate_dq(gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node): |
| 751 | + with gm.graph.inserting_after(dq_node): |
| 752 | + new_node = gm.graph.node_copy(dq_node) |
| 753 | + |
| 754 | + def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node: |
| 755 | + if n == dq_node: |
| 756 | + return new_node |
| 757 | + else: |
| 758 | + return n |
| 759 | + |
| 760 | + new_args = map_arg(user.args, maybe_replace_node) |
| 761 | + new_kwargs = map_arg(user.kwargs, maybe_replace_node) |
| 762 | + user.args = new_args |
| 763 | + user.kwargs = new_kwargs |
| 764 | + |
| 765 | + |
| 766 | +def _is_sym_size_node(node: torch.fx.Node): |
| 767 | + return ( |
| 768 | + node.op == "call_function" |
| 769 | + and node.target == torch.ops.aten.sym_size.default |
| 770 | + or node.target == torch.ops.aten.sym_numel.default |
| 771 | + or node.target == torch.ops.aten.sym_numel |
| 772 | + or node.target == torch.ops.aten.sym_size |
| 773 | + ) |
| 774 | + |
| 775 | + |
def _filter_sym_size_users(node: torch.fx.Node) -> list[torch.fx.Node]:
    """
    Collect the users of ``node`` that are not symbolic size/numel queries.

    :param node: Node whose users are inspected.
    :return: Users of ``node`` for which ``_is_sym_size_node`` is falsy, in the
        iteration order of ``node.users``.
    """
    # Use truthiness rather than `is False`, so the filter does not depend on
    # the predicate returning the exact `False` singleton.
    return [user for user in node.users if not _is_sym_size_node(user)]
| 779 | + |
| 780 | + |
class DuplicateDQPassNoAnnotations(PassBase):
    """
    Graph pass that gives each consumer of a shared dequantize node its own copy.

    For every ``call_function`` node whose target is in
    ``DEQUANTIZE_NODE_TARGETS`` and that has more than one user (ignoring
    symbolic size/numel users), the node is duplicated once per user.
    Dequantize nodes matching the dynamic-quantization pattern
    ``choose_qparams -> getitem -> quantize -> dequantize`` are left as is.
    """

    @staticmethod
    def _is_dynamic_quantization_dq(dq_node: torch.fx.Node) -> bool:
        """
        Check whether ``dq_node`` closes the dynamic-quantization pattern
        ``choose_qparam - getitem - q - dq``.

        :param dq_node: Dequantize node to inspect.
        :return: True if the pattern is matched, False otherwise.
        """
        producer = dq_node.args[0]
        is_quantize = producer.op == "call_function" and producer.target in QUANTIZE_NODE_TARGETS
        if not is_quantize:
            return False

        qparam_item = producer.args[1]
        is_getitem = (
            isinstance(qparam_item, torch.fx.node.Node)
            and qparam_item.op == "call_function"
            and qparam_item.target == operator.getitem
        )
        if not is_getitem:
            return False

        qparam_source = qparam_item.args[0]
        return (
            isinstance(qparam_source, torch.fx.node.Node)
            and qparam_source.op == "call_function"
            and qparam_source.target == torch.ops.quantized_decomposed.choose_qparams.tensor
        )

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Run the pass on ``graph_module`` in place.

        :param graph_module: Module whose graph is transformed.
        :return: ``PassResult`` wrapping ``graph_module``; it is always
            reported as modified.
        """
        for candidate in graph_module.graph.nodes:
            is_dequantize = candidate.op == "call_function" and candidate.target in DEQUANTIZE_NODE_TARGETS
            if not is_dequantize:
                continue

            consumers = _filter_sym_size_users(candidate)
            if len(consumers) < 2:
                continue

            # Do not duplicate dq for dynamic quantization
            # Pattern: choose_qparam - getitem - q - dq
            if self._is_dynamic_quantization_dq(candidate):
                continue

            for consumer in consumers:
                _duplicate_dq(graph_module, candidate, consumer)

        # The original dq node may be left without users after duplication.
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
0 commit comments