
Commit a8fe653

enable qat for custom annotation in qnn
Differential Revision: D79705374
Pull Request resolved: #13147
1 parent f7f486d commit a8fe653

File tree (3 files changed: +145 −31 lines changed)

backends/qualcomm/quantizer/custom_annotation.py
backends/qualcomm/quantizer/qconfig.py
backends/qualcomm/quantizer/quantizer.py
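Taken together, the three diffs below let the existing custom matmul annotation request QAT (fake-quant) configs instead of PTQ observers. A rough usage sketch follows; the toy model, the direct call to annotate_matmul_16a8w on the exported graph, and the default QnnQuantizer / prepare_qat_pt2e wiring are illustrative assumptions rather than code from this commit.

# Hedged end-to-end sketch (not part of this commit), assuming a recent
# PyTorch/ExecuTorch where torch.export and the PT2E QAT APIs behave as shown.
import torch
from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_matmul_16a8w,
)
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e


class TinyMatmul(torch.nn.Module):
    def forward(self, x, y):
        # Exports to an aten.matmul node, which the annotator looks for.
        return torch.matmul(x, y)


example_inputs = (torch.randn(2, 8, 8), torch.randn(2, 8, 8))
gm = torch.export.export(TinyMatmul(), example_inputs).module()

# Tag the matmul pattern with QAT (fake-quant) configs before prepare.
annotate_matmul_16a8w(gm, is_qat=True)

prepared = prepare_qat_pt2e(gm, QnnQuantizer())
prepared(*example_inputs)  # a real flow would run training steps here
quantized = convert_pt2e(prepared)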


backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 36 additions & 15 deletions
@@ -12,8 +12,11 @@
 )
 from executorch.backends.qualcomm.quantizer.quantizer import (
     get_16a8w_qnn_ptq_config,
+    get_16a8w_qnn_qat_config,
     get_8a8w_qnn_ptq_config,
+    get_8a8w_qnn_qat_config,
     get_ptq_per_channel_quant_config,
+    get_qat_per_channel_quant_config,
     QuantizationConfig,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -154,7 +157,9 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):


 def annotate_matmul_16a8w(  # noqa: C901
-    gm: torch.fx.GraphModule, annotate_conv=True
+    gm: torch.fx.GraphModule,
+    annotate_conv=True,
+    is_qat=False,
 ) -> None:
     """
     This function is specific for matmul op 16a8w.
@@ -238,7 +243,6 @@ def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> None:
     def annotate_single_in_single_out(
         node: Node, quantization_config: QuantizationConfig
     ) -> None:
-
         input_qspec_map = {}
         input_act = node.args[0]
         input_qspec_map[input_act] = quantization_config.input_activation
@@ -252,7 +256,6 @@ def annotate_single_in_single_out(
     def annotate_single_in_share_out(
         node: Node, quantization_config: QuantizationConfig
     ) -> None:
-
         input_qspec_map = {}
         input_act = node.args[0]
         input_qspec_map[input_act] = quantization_config.input_activation
@@ -283,16 +286,27 @@ def annotate_stack(node: Node, quantization_config: QuantizationConfig) -> None:
             _annotated=True,
         )

-    def annotate_matmul_input1(node: Node):
-        quantization_config_8a8w = get_8a8w_qnn_ptq_config(
-            act_symmetric=True, act_observer=MinMaxObserver
-        )
-        quantization_config_8a4w_per_channel = get_ptq_per_channel_quant_config(
-            act_dtype=torch.uint8,
-            weight_dtype=torch.int4,
-            act_observer=MinMaxObserver,
-            act_symmetric=True,
-        )
+    def annotate_matmul_input1(node: Node, is_qat: str):
+        if is_qat:
+            quantization_config_8a8w = get_8a8w_qnn_qat_config(
+                act_symmetric=True, act_observer=MinMaxObserver
+            )
+            quantization_config_8a4w_per_channel = get_qat_per_channel_quant_config(
+                act_dtype=torch.uint8,
+                weight_dtype=torch.int4,
+                act_observer=MinMaxObserver,
+                act_symmetric=True,
+            )
+        else:
+            quantization_config_8a8w = get_8a8w_qnn_ptq_config(
+                act_symmetric=True, act_observer=MinMaxObserver
+            )
+            quantization_config_8a4w_per_channel = get_ptq_per_channel_quant_config(
+                act_dtype=torch.uint8,
+                weight_dtype=torch.int4,
+                act_observer=MinMaxObserver,
+                act_symmetric=True,
+            )
         while isinstance(node, Node) and node.op == "call_function":
             if node.target in [
                 torch.ops.aten.permute.default,
@@ -330,12 +344,19 @@ def annotate_matmul_input1(node: Node):
                 print(f"The node ({node}) is not expected in the input1 of the matmul")
             node = node.args[0]

-    quantization_config_16a8w = get_16a8w_qnn_ptq_config(act_observer=MinMaxObserver)
+    if is_qat:
+        quantization_config_16a8w = get_16a8w_qnn_qat_config(
+            act_observer=MinMaxObserver
+        )
+    else:
+        quantization_config_16a8w = get_16a8w_qnn_ptq_config(
+            act_observer=MinMaxObserver
+        )

     for node in gm.graph.nodes:
         if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
             annotate_matmul(node, quantization_config_16a8w)
-            annotate_matmul_input1(node.args[1])
+            annotate_matmul_input1(node.args[1], is_qat=is_qat)


 def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
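The new is_qat argument only changes which config helpers the nested annotate_matmul_input1 uses. Restated outside the function, the selection reduces to a sketch like the one below (pick_input1_configs is a hypothetical name, not part of the commit):

# Paraphrase of the selection logic added above: pick QAT (FakeQuantize-based)
# or PTQ (observer-based) configs for the matmul input1 chain.
import torch
from executorch.backends.qualcomm.quantizer.quantizer import (
    get_8a8w_qnn_ptq_config,
    get_8a8w_qnn_qat_config,
    get_ptq_per_channel_quant_config,
    get_qat_per_channel_quant_config,
)
from torch.ao.quantization.observer import MinMaxObserver


def pick_input1_configs(is_qat: bool):
    # 8-bit activation / 8-bit weight config for single-in ops on the chain.
    config_8a8w = (get_8a8w_qnn_qat_config if is_qat else get_8a8w_qnn_ptq_config)(
        act_symmetric=True, act_observer=MinMaxObserver
    )
    # 8-bit activation / 4-bit weight per-channel config for weight-bearing ops.
    per_channel_8a4w = (
        get_qat_per_channel_quant_config if is_qat else get_ptq_per_channel_quant_config
    )(
        act_dtype=torch.uint8,
        weight_dtype=torch.int4,
        act_observer=MinMaxObserver,
        act_symmetric=True,
    )
    return config_8a8w, per_channel_8a4w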

backends/qualcomm/quantizer/qconfig.py

Lines changed: 107 additions & 16 deletions
@@ -187,6 +187,65 @@ def get_16a8w_qnn_ptq_config(
     return quantization_config


+def get_16a8w_qnn_qat_config(
+    act_observer=MovingAverageMinMaxObserver,
+) -> QuantizationConfig:
+    extra_args: Dict[str, Any] = {"eps": 2**-20}
+    act_fake_quant_ctr = FakeQuantize.with_args(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.uint16).min,
+        quant_max=torch.iinfo(torch.uint16).max,
+        qscheme=torch.per_tensor_affine,
+        reduce_range=True,
+        observer=act_observer.with_args(**extra_args),
+    )
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.uint16).min,
+        quant_max=torch.iinfo(torch.uint16).max,
+        qscheme=torch.per_tensor_affine,
+        observer_or_fake_quant_ctr=act_fake_quant_ctr,
+    )
+    weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
+        dtype=torch.int8,
+        quant_min=torch.iinfo(torch.int8).min + 1,
+        quant_max=torch.iinfo(torch.int8).max,
+        qscheme=torch.per_tensor_symmetric,
+        reduce_range=True,
+        observer=MovingAverageMinMaxObserver,
+    )
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=torch.iinfo(torch.int8).min + 1,
+        quant_max=torch.iinfo(torch.int8).max,
+        qscheme=torch.per_tensor_symmetric,
+        ch_axis=0,
+        observer_or_fake_quant_ctr=weight_fake_quant_ctr,
+    )
+    bias_fake_quant_ctr = FakeQuantize.with_args(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        observer=MovingAverageMinMaxObserver,
+    )
+    bias_quantization_spec = QuantizationSpec(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        observer_or_fake_quant_ctr=bias_fake_quant_ctr,
+    )
+    quantization_config = QuantizationConfig(
+        input_activation=act_quantization_spec,
+        output_activation=act_quantization_spec,
+        weight=weight_quantization_spec,
+        bias=bias_quantization_spec,
+    )
+
+    return quantization_config
+
+
 def get_16a16w_qnn_ptq_config(
     act_observer=MovingAverageMinMaxObserver,
 ) -> QuantizationConfig:
@@ -459,6 +518,7 @@ def get_qat_per_channel_quant_config(
     act_dtype=torch.uint8,
     weight_dtype=torch.int8,
     act_observer=MovingAverageMinMaxObserver,
+    act_symmetric=False,
 ) -> QuantizationConfig:
     supported_act_types = {
         torch.uint8,
@@ -476,21 +536,38 @@ def get_qat_per_channel_quant_config(
     ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"

     # torch does not support uint16 quantization, use int32 to bypass
-    act_fake_quant_ctr = FakeQuantize.with_args(
-        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
-        quant_min=torch.iinfo(act_dtype).min,
-        quant_max=torch.iinfo(act_dtype).max,
-        qscheme=torch.per_tensor_affine,
-        reduce_range=True,
-        observer=act_observer,
-    )
-    act_quantization_spec = QuantizationSpec(
-        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
-        quant_min=torch.iinfo(act_dtype).min,
-        quant_max=torch.iinfo(act_dtype).max,
-        qscheme=torch.per_tensor_affine,
-        observer_or_fake_quant_ctr=act_fake_quant_ctr,
-    )
+    if act_symmetric:
+        # If zero_point is 128, htp can do optimizations.
+        # If we keep quant_min and quant_max none, observer will default use 128 as zero_point.
+        # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired.
+        act_fake_quant_ctr = FakeQuantize.with_args(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            qscheme=torch.per_tensor_symmetric,
+            reduce_range=True,
+            observer=act_observer,
+        )
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            qscheme=torch.per_tensor_symmetric,
+            ch_axis=0,
+            observer_or_fake_quant_ctr=act_fake_quant_ctr,
+        )
+    else:
+        act_fake_quant_ctr = FakeQuantize.with_args(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            quant_min=torch.iinfo(act_dtype).min,
+            quant_max=torch.iinfo(act_dtype).max,
+            qscheme=torch.per_tensor_affine,
+            reduce_range=True,
+            observer=act_observer,
+        )
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            quant_min=torch.iinfo(act_dtype).min,
+            quant_max=torch.iinfo(act_dtype).max,
+            qscheme=torch.per_tensor_affine,
+            observer_or_fake_quant_ctr=act_fake_quant_ctr,
+        )

     weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
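The act_symmetric branch deliberately leaves quant_min/quant_max unset so the observer defaults place the zero point at 128, as the inline comments note. A small sketch of the observable difference, assuming a standard ExecuTorch checkout:

# Sketch: act_symmetric picks symmetric fake-quant for activations instead of
# the affine default; only the qscheme is inspected here.
import torch
from executorch.backends.qualcomm.quantizer.qconfig import (
    get_qat_per_channel_quant_config,
)

affine_cfg = get_qat_per_channel_quant_config(act_dtype=torch.uint8)
symmetric_cfg = get_qat_per_channel_quant_config(
    act_dtype=torch.uint8, act_symmetric=True
)

print(affine_cfg.input_activation.qscheme)     # torch.per_tensor_affine
print(symmetric_cfg.input_activation.qscheme)  # torch.per_tensor_symmetric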
@@ -513,7 +590,21 @@ def get_qat_per_channel_quant_config(
         observer_or_fake_quant_ctr=weight_fake_quant_ctr,
     )

-    bias_quantization_spec = _derived_bias_quant_spec
+    bias_fake_quant_ctr = FakeQuantize.with_args(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        reduce_range=True,
+        observer=MovingAverageMinMaxObserver,
+    )
+    bias_quantization_spec = QuantizationSpec(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        observer_or_fake_quant_ctr=bias_fake_quant_ctr,
+    )

     quantization_config = QuantizationConfig(
         input_activation=act_quantization_spec,
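With the bias spec now built from a plain int32 FakeQuantize rather than the derived spec, the whole QAT config can be inspected directly. A minimal sketch, again assuming a standard checkout:

# Sketch: the new 16a8w QAT helper returns FakeQuantize-backed specs.
import torch
from executorch.backends.qualcomm.quantizer.qconfig import get_16a8w_qnn_qat_config

cfg = get_16a8w_qnn_qat_config()
print(cfg.input_activation.dtype)  # torch.int32 (stands in for uint16)
print(cfg.weight.qscheme)          # torch.per_tensor_symmetric
print(cfg.bias.dtype)              # torch.int32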

backends/qualcomm/quantizer/quantizer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
     get_16a4w_qnn_ptq_config,
     get_16a4w_qnn_qat_config,
     get_16a8w_qnn_ptq_config,
+    get_16a8w_qnn_qat_config,
     get_8a8w_qnn_ptq_config,
     get_8a8w_qnn_qat_config,
     get_ptq_per_block_quant_config,
@@ -39,6 +40,7 @@
     "QuantDtype",
     "get_16a4w_qnn_ptq_config",
     "get_16a8w_qnn_ptq_config",
+    "get_16a8w_qnn_qat_config",
     "get_16a16w_qnn_ptq_config",
     "get_8a8w_qnn_ptq_config",
     "get_8a8w_qnn_qat_config",
