
Commit 1247545

Qualcomm AI Engine Direct - qat proto
* [qat proto]
  - Add qat proto
  - Add unit test test_qnn_backend_linear_qat
  - Test command:

    ```bash
    python backends/qualcomm/tests/test_qnn_delegate.py -H $HOST -s $DEVICE -b $build-android/ -m "SM8650" -r $EXECUTORCH_ROOT -k TestQNNQuantizedOperator.test_qnn_backend_linear_qat
    ```

* [Fix lint]

Co-authored-by: Joey Tsai <[email protected]>
Pull Request resolved: #6222
1 parent 3ea8538 commit 1247545

File tree

4 files changed: +97 −1 lines changed

- backends/qualcomm/quantizer/quantizer.py
- backends/qualcomm/quantizer/utils.py
- backends/qualcomm/tests/test_qnn_delegate.py
- backends/qualcomm/tests/utils.py

backends/qualcomm/quantizer/quantizer.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -26,6 +26,7 @@
     get_16a4w_qnn_ptq_config,
     get_16a8w_qnn_ptq_config,
     get_default_16bit_qnn_ptq_config,
+    get_default_8bit_qat_proto,
     get_default_8bit_qnn_ptq_config,
     get_ptq_per_channel_quant_config,
     OP_ANNOTATOR,
@@ -39,6 +40,7 @@
     "get_16a8w_qnn_ptq_config",
     "get_default_16bit_qnn_ptq_config",
     "get_default_8bit_qnn_ptq_config",
+    "get_default_8bit_qat_proto",
 ]
```
backends/qualcomm/quantizer/utils.py

Lines changed: 45 additions & 0 deletions
```diff
@@ -15,6 +15,11 @@
 from torch._ops import OpOverload
 from torch._subclasses import FakeTensor
 
+from torch.ao.quantization.fake_quantize import (
+    default_fake_quant,
+    FusedMovingAvgObsFakeQuantize,
+)
+
 from torch.ao.quantization.observer import (
     FixedQParamsObserver,
     MinMaxObserver,
@@ -179,6 +184,46 @@ def _derive_bias_qparams_fn(
     )
 
 
+def get_default_8bit_qat_proto(act_symmetric: bool = False) -> QuantizationConfig:
+
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.uint8,
+        qscheme=(
+            torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
+        ),
+        ch_axis=0,
+        observer_or_fake_quant_ctr=default_fake_quant,
+    )
+
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=torch.iinfo(torch.int8).min + 1,
+        quant_max=torch.iinfo(torch.int8).max,
+        qscheme=torch.per_tensor_symmetric,
+        ch_axis=0,
+        observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args(
+            observer=MovingAverageMinMaxObserver
+        ),
+    )
+
+    bias_quantization_spec = QuantizationSpec(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        observer_or_fake_quant_ctr=default_fake_quant,
+    )
+
+    quantization_config = QuantizationConfig(
+        input_activation=act_quantization_spec,
+        output_activation=act_quantization_spec,
+        weight=weight_quantization_spec,
+        bias=bias_quantization_spec,
+    )
+
+    return quantization_config
+
+
 def get_default_8bit_qnn_ptq_config(
     act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
 ) -> QuantizationConfig:
```
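Note on the config above: unlike the PTQ configs, `observer_or_fake_quant_ctr` points at fake-quantize factories rather than plain observers, so quantization error is simulated in the forward pass during training. Below is a minimal sketch, not part of the diff, of what those factories produce; the explicit quant_min/quant_max/dtype/qscheme arguments are normally derived from the QuantizationSpec by the PT2E flow and are spelled out here only for illustration.

```python
import torch
from torch.ao.quantization.fake_quantize import (
    default_fake_quant,
    FusedMovingAvgObsFakeQuantize,
)
from torch.ao.quantization.observer import MovingAverageMinMaxObserver

# Activation path: default_fake_quant is a FakeQuantize factory
# (uint8 affine, MovingAverageMinMaxObserver underneath).
act_fq = default_fake_quant()

# Weight path: fused observe-and-fake-quantize, mirroring the weight spec above.
# The keyword arguments are illustrative assumptions matching an 8-bit symmetric spec.
weight_fq = FusedMovingAvgObsFakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    quant_min=-127,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
)()

x = torch.randn(3, 4)
y = act_fq(x)                     # observes running min/max, then quantize-dequantizes in float
w = weight_fq(torch.randn(4, 4))  # same idea for weights, via the fused kernel
```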

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -1042,6 +1042,26 @@ def test_qnn_backend_linear(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_linear_qat(self):
+        """
+        Prototype to test qat model
+        """
+        module = Linear()  # noqa: F405
+        sample_input = (torch.randn([3, 4]),)
+
+        module = self.get_prepared_qat_module(module, sample_input)
+
+        optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
+        criterion = torch.nn.CrossEntropyLoss()
+        output = module(*sample_input)
+        loss = criterion(output, module(*sample_input))
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        module = torch.ao.quantization.quantize_pt2e.convert_pt2e(module)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_log_softmax(self):
         module = LogSoftmax()  # noqa: F405
         sample_input = (torch.randn([1, 4, 8, 8]),)
```
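The training step in this test is only a smoke test: the loss target is simply a second forward pass on the same input. In a real QAT flow the prepared module would be fine-tuned on labeled data before conversion. A minimal sketch under that assumption, with `prepared`, `inputs`, and `targets` as hypothetical placeholders:

```python
import torch


def qat_finetune_step(prepared, optimizer, inputs, targets):
    """One hypothetical fine-tuning step on a prepared (fake-quantized) module."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer.zero_grad()
    logits = prepared(*inputs)         # forward pass runs through the fake-quant modules
    loss = criterion(logits, targets)  # targets: integer class labels
    loss.backward()                    # gradients flow via straight-through estimators
    optimizer.step()
    return loss.item()
```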

backends/qualcomm/tests/utils.py

Lines changed: 30 additions & 1 deletion
```diff
@@ -20,6 +20,7 @@
 from executorch.backends.qualcomm.quantizer.quantizer import (
     get_16a4w_qnn_ptq_config,
     get_default_16bit_qnn_ptq_config,
+    get_default_8bit_qat_proto,
     QnnQuantizer,
     QuantDtype,
 )
@@ -44,7 +45,11 @@
 from executorch.exir.pass_base import ExportPass
 from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
 from executorch.exir.program import ExecutorchProgram, ExecutorchProgramManager
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantize_pt2e import (
+    convert_pt2e,
+    prepare_pt2e,
+    prepare_qat_pt2e,
+)
 
 
 def generate_context_binary(
@@ -426,6 +431,30 @@ def get_qdq_module(
         self.assertTrue(nodes.intersection(q_and_dq))
         return quantized_module
 
+    def get_prepared_qat_module(
+        self,
+        module: torch.nn.Module,
+        inputs: Tuple[torch.Tensor],
+        is_conv_per_channel: Optional[bool] = True,
+        is_linear_per_channel: Optional[bool] = False,
+        custom_quant_annotations: Tuple[Callable] = (),
+        quant_dtype: QuantDtype = QuantDtype.use_8a8w,
+    ) -> torch.fx.GraphModule:
+        m = torch.export.export_for_training(module, inputs).module()
+
+        quantizer = QnnQuantizer()
+        quantizer.add_custom_quant_annotations(custom_quant_annotations)
+        quantizer.set_per_channel_conv_quant(is_conv_per_channel)
+        quantizer.set_per_channel_linear_quant(is_linear_per_channel)
+
+        if quant_dtype == QuantDtype.use_8a8w:
+            quantizer.set_bit8_op_quant_config(get_default_8bit_qat_proto())
+        else:
+            raise RuntimeError("Should not be here")
+
+        prepared = prepare_qat_pt2e(m, quantizer)
+        return torch.ao.quantization.move_exported_model_to_train(prepared)
+
     def split_graph(self, graph_module: torch.fx.GraphModule, division: int):
         class SplitGraph(ExportPass):
             """
```
