Commit 413e23f

navsud authored and facebook-github-bot committed
update custom annotations for QAT (#13747)
Summary: Some custom annotations were PTQ-only; this updates them to also support QAT, which is needed for QAT of the static transformer.

Reviewed By: sxu

Differential Revision: D80190466
1 parent: 41730fa · commit: 413e23f

File tree

1 file changed (+11, -4)


backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 11 additions & 4 deletions
@@ -92,7 +92,9 @@ def annotate_mimi_decoder(gm: torch.fx.GraphModule):
             break
 
 
-def annotate_linear_16a8w_in_affine_layer(gm: torch.fx.GraphModule) -> None:
+def annotate_linear_16a8w_in_affine_layer(
+    gm: torch.fx.GraphModule, is_qat: bool = False
+) -> None:
     def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
         input_qspec_map = {}
         input_act = node.args[0]
@@ -108,9 +110,14 @@ def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None
             _annotated=True,
         )
 
-    quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
-        torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
-    )
+    if is_qat:
+        quantization_config_16a8w_per_channel = get_qat_per_channel_quant_config(
+            torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
+        )
+    else:
+        quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
+            torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
+        )
     for node in gm.graph.nodes:
         if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
             if "nn_module_stack" in node.meta:
