14 | 14 | QuantizationConfig, |
15 | 15 | ) |
16 | 16 | from executorch.exir.dialects._ops import ops as exir_ops |
17 | | -from torch.ao.quantization.observer import FixedQParamsObserver, MinMaxObserver |
| 17 | +from torch.ao.quantization.observer import MinMaxObserver |
18 | 18 | from torch.ao.quantization.quantizer import ( |
19 | 19 | QuantizationAnnotation, |
20 | | - QuantizationSpec, |
21 | 20 | SharedQuantizationSpec, |
22 | 21 | ) |
23 | 22 | from torch.fx import Node |
24 | 23 |
25 | 24 |
26 | | -def annotate_linear_16a8w_in_affine_layer(gm: torch.fx.GraphModule) -> None: |
27 | | - def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None: |
28 | | - input_qspec_map = {} |
29 | | - input_act = node.args[0] |
30 | | - input_spec = quantization_config.input_activation |
31 | | - input_qspec_map[input_act] = input_spec |
32 | | - |
33 | | - weight = node.args[1] |
34 | | - input_qspec_map[weight] = quantization_config.weight |
35 | | - |
36 | | - node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
37 | | - input_qspec_map=input_qspec_map, |
38 | | - output_qspec=quantization_config.output_activation, |
39 | | - _annotated=True, |
40 | | - ) |
41 | | - |
42 | | - quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config( |
43 | | - torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver |
44 | | - ) |
45 | | - for node in gm.graph.nodes: |
46 | | - if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default: |
47 | | - if "nn_module_stack" in node.meta: |
48 | | - module_values_list = list(node.meta["nn_module_stack"].values()) |
49 | | - full_qualified_name = module_values_list[-1][0] |
50 | | - if full_qualified_name == "output.conv": |
51 | | - annotate_conv2d( |
52 | | - node, quantization_config=quantization_config_16a8w_per_channel |
53 | | - ) |
54 | | - |
55 | | - |
56 | | -def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict): |
57 | | - for node in gm.graph.nodes: |
58 | | - if node.op == "output": |
59 | | - for index, prefill_output in enumerate(node.args[0]): |
60 | | - kv_quant_attr = kv_quant_attrs[index] |
61 | | - fixed_observer = FixedQParamsObserver.with_args( |
62 | | - scale=kv_quant_attr[0], |
63 | | - zero_point=kv_quant_attr[1], |
64 | | - quant_min=kv_quant_attr[2], |
65 | | - quant_max=kv_quant_attr[3], |
66 | | - dtype=kv_quant_attr[4], |
67 | | - qscheme=torch.torch.per_tensor_affine, |
68 | | - ) |
69 | | - |
70 | | - fixed_output_spec = QuantizationSpec( |
71 | | - quant_min=kv_quant_attr[2], |
72 | | - quant_max=kv_quant_attr[3], |
73 | | - dtype=kv_quant_attr[4], |
74 | | - ch_axis=0, |
75 | | - observer_or_fake_quant_ctr=fixed_observer, |
76 | | - ) |
77 | | - |
78 | | - input_qspec_map = {} |
79 | | - for input in prefill_output.args: |
80 | | - if isinstance(input, Node): |
81 | | - input_qspec_map[input] = fixed_output_spec |
82 | | - |
83 | | - prefill_output.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
84 | | - input_qspec_map=input_qspec_map, |
85 | | - output_qspec=fixed_output_spec, |
86 | | - _annotated=True, |
87 | | - ) |
88 | | - |
89 | | - |
90 | | -def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None: # noqa: C901 |
| 25 | +def annotate_matmul_16a8w( # noqa: C901 |
| 26 | + gm: torch.fx.GraphModule, traverse_input1=True |
| 27 | +) -> None: |
91 | 28 | """ |
92 | 29 | This function is specific for matmul op 16a8w. |
93 | 30 | For k, we will tag such as the below, and |
@@ -205,7 +142,8 @@ def annotate_matmul_input1(node: Node): |
205 | 142 | for node in gm.graph.nodes: |
206 | 143 | if node.op == "call_function" and node.target == torch.ops.aten.matmul.default: |
207 | 144 | annotate_matmul(node, quantization_config_16a8w) |
208 | | - annotate_matmul_input1(node.args[1]) |
| 145 | + if traverse_input1: |
| 146 | + annotate_matmul_input1(node.args[1]) |
209 | 147 |
210 | 148 |
211 | 149 | def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None: # noqa: C901 |
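
For reference, a minimal usage sketch of the updated `annotate_matmul_16a8w` signature shown above. The call sites here are hypothetical; `gm` is assumed to be a `torch.fx.GraphModule` captured for the QNN backend, as in the surrounding ExecuTorch examples.

```python
# Hypothetical call sites illustrating the new `traverse_input1` flag.
# `gm` is assumed to be a captured torch.fx.GraphModule.

# Default behavior: annotate matmul as 16a8w and also walk its second input.
annotate_matmul_16a8w(gm)

# With the new flag disabled: annotate only the matmul nodes themselves,
# skipping the traversal of node.args[1].
annotate_matmul_16a8w(gm, traverse_input1=False)
```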