
Commit 03836c8

Author: Joey Tsai (committed)
[qat proto]

- Add QAT proto
- Add unit test test_qnn_backend_linear_qat
- Test command:

```bash
python backends/qualcomm/tests/test_qnn_delegate.py -H $HOST -s $DEVICE \
  -b $build-android/ -m "SM8650" -r $EXECUTORCH_ROOT \
  -k TestQNNQuantizedOperator.test_qnn_backend_linear_qat
```
1 parent 1f2b9aa commit 03836c8

File tree (4 files changed, +93 −1 lines)

- backends/qualcomm/quantizer/quantizer.py (+2)
- backends/qualcomm/quantizer/utils.py (+45)
- backends/qualcomm/tests/test_qnn_delegate.py (+20)
- backends/qualcomm/tests/utils.py (+26 −1)


backends/qualcomm/quantizer/quantizer.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -28,6 +28,7 @@
     get_default_16bit_qnn_ptq_config,
     get_default_8bit_qnn_ptq_config,
     get_ptq_per_channel_quant_config,
+    get_default_8bit_qat_proto,
     OP_ANNOTATOR,
     QuantizationConfig,
 )
@@ -39,6 +40,7 @@
     "get_16a8w_qnn_ptq_config",
     "get_default_16bit_qnn_ptq_config",
     "get_default_8bit_qnn_ptq_config",
+    "get_default_8bit_qat_proto",
 ]


```

backends/qualcomm/quantizer/utils.py

Lines changed: 45 additions & 0 deletions
```diff
@@ -15,12 +15,19 @@
 from torch._ops import OpOverload
 from torch._subclasses import FakeTensor

+from torch.ao.quantization.fake_quantize import (
+    default_fake_quant,
+    default_per_channel_weight_fake_quant,
+    FusedMovingAvgObsFakeQuantize,
+)
+
 from torch.ao.quantization.observer import (
     FixedQParamsObserver,
     MinMaxObserver,
     MovingAverageMinMaxObserver,
     PerChannelMinMaxObserver,
     UniformQuantizationObserverBase,
+    MovingAveragePerChannelMinMaxObserver,
 )

 from torch.ao.quantization.quantizer import (
@@ -179,6 +186,44 @@ def _derive_bias_qparams_fn(
     )


+def get_default_8bit_qat_proto(act_symmetric: bool = False) -> QuantizationConfig:
+
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.uint8,
+        qscheme=(
+            torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
+        ),
+        ch_axis=0,
+        observer_or_fake_quant_ctr=default_fake_quant,
+    )
+
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=torch.iinfo(torch.int8).min + 1,
+        quant_max=torch.iinfo(torch.int8).max,
+        qscheme=torch.per_tensor_symmetric,
+        ch_axis=0,
+        observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver),
+    )
+
+    bias_quantization_spec = QuantizationSpec(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        observer_or_fake_quant_ctr=default_fake_quant,
+    )
+
+    quantization_config = QuantizationConfig(
+        input_activation=act_quantization_spec,
+        output_activation=act_quantization_spec,
+        weight=weight_quantization_spec,
+        bias=bias_quantization_spec,
+    )
+
+    return quantization_config
+
+
 def get_default_8bit_qnn_ptq_config(
     act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
 ) -> QuantizationConfig:
```
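
For orientation (not part of the diff): the weight spec above replaces the PTQ observer with `FusedMovingAvgObsFakeQuantize`, which tracks a moving-average min/max and fake-quantizes the tensor in the forward pass, so quantization error is visible to the optimizer during QAT while gradients still flow to the float weights. A minimal standalone sketch of that behavior; the toy tensor and hand-picked qparams below are illustrative assumptions, not values read out of the config:

```python
import torch
from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
from torch.ao.quantization.observer import MovingAverageMinMaxObserver

# Symmetric int8 fake-quant roughly mirroring the weight spec (quant_min = -127).
fake_quant = FusedMovingAvgObsFakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
    quant_min=-127,
    quant_max=127,
)()

weight = torch.randn(8, 4, requires_grad=True)  # hypothetical toy weight tensor
fq_weight = fake_quant(weight)   # quantize -> dequantize; output stays float
fq_weight.sum().backward()       # straight-through estimator passes gradients to `weight`
print(weight.grad.shape)         # torch.Size([8, 4])
```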

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -1045,6 +1045,26 @@ def test_qnn_backend_linear(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_linear_qat(self):
+        """
+        Prototype to test QAT model.
+        """
+        module = Linear()  # noqa: F405
+        sample_input = (torch.randn([3, 4]),)
+
+        module = self.get_prepared_qat_module(module, sample_input)
+
+        optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
+        criterion = torch.nn.CrossEntropyLoss()
+        output = module(*sample_input)
+        loss = criterion(output, module(*sample_input))
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        module = torch.ao.quantization.quantize_pt2e.convert_pt2e(module)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_log_softmax(self):
         module = LogSoftmax()  # noqa: F405
         sample_input = (torch.randn([1, 4, 8, 8]),)
```
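
A side note on the loss above: the test feeds a second forward pass to `CrossEntropyLoss` as the target, which is enough to drive `backward()` through the fake-quant nodes in this prototype. With real data the step would conventionally use integer class labels; a minimal illustrative sketch (the logits and labels below are placeholders, not part of the test):

```python
import torch

criterion = torch.nn.CrossEntropyLoss()
logits = torch.randn(3, 2, requires_grad=True)  # stand-in for module(*sample_input)
labels = torch.randint(0, 2, (3,))              # hypothetical integer class labels
loss = criterion(logits, labels)
loss.backward()
```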

backends/qualcomm/tests/utils.py

Lines changed: 26 additions & 1 deletion
```diff
@@ -20,6 +20,7 @@
 from executorch.backends.qualcomm.quantizer.quantizer import (
     get_16a4w_qnn_ptq_config,
     get_default_16bit_qnn_ptq_config,
+    get_default_8bit_qat_proto,
     QnnQuantizer,
     QuantDtype,
 )
@@ -44,7 +45,7 @@
 from executorch.exir.pass_base import ExportPass
 from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
 from executorch.exir.program import ExecutorchProgram, ExecutorchProgramManager
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e, prepare_qat_pt2e


 def generate_context_binary(
@@ -426,6 +427,30 @@ def get_qdq_module(
         self.assertTrue(nodes.intersection(q_and_dq))
         return quantized_module

+    def get_prepared_qat_module(
+        self,
+        module: torch.nn.Module,
+        inputs: Tuple[torch.Tensor],
+        is_conv_per_channel: Optional[bool] = True,
+        is_linear_per_channel: Optional[bool] = False,
+        custom_quant_annotations: Tuple[Callable] = (),
+        quant_dtype: QuantDtype = QuantDtype.use_8a8w,
+    ) -> torch.fx.GraphModule:
+        m = torch.export.export_for_training(module, inputs).module()
+
+        quantizer = QnnQuantizer()
+        quantizer.add_custom_quant_annotations(custom_quant_annotations)
+        quantizer.set_per_channel_conv_quant(is_conv_per_channel)
+        quantizer.set_per_channel_linear_quant(is_linear_per_channel)
+
+        if quant_dtype == QuantDtype.use_8a8w:
+            quantizer.set_bit8_op_quant_config(get_default_8bit_qat_proto())
+        else:
+            raise RuntimeError("Should not be here")
+
+        prepared = prepare_qat_pt2e(m, quantizer)
+        return torch.ao.quantization.move_exported_model_to_train(prepared)
+
     def split_graph(self, graph_module: torch.fx.GraphModule, division: int):
         class SplitGraph(ExportPass):
             """
```
