
Commit e94a9e1

cccclai authored and facebook-github-bot committed

Add quantization and partitioner flow in the qualcomm doc (pytorch#12387)

Summary: Add a section describing how to lower a model to HTP, including the quantization step.

Differential Revision: D78117959
1 parent 4e956e6 commit e94a9e1

File tree

3 files changed (+134, -14 lines)

docs/source/backends-qualcomm.md

Lines changed: 88 additions & 0 deletions
@@ -354,6 +354,94 @@ The model, inputs, and output location are passed to `qnn_executorch_runner` by

Please refer to `$EXECUTORCH_ROOT/examples/qualcomm/scripts/` and `$EXECUTORCH_ROOT/examples/qualcomm/oss_scripts/` for the list of supported models.

## How to Support a Custom Model in HTP Backend

### Step-by-Step Implementation Guide

Please refer to [the simple example](https://github.com/pytorch/executorch/blob/main/examples/qualcomm/scripts/export_example.py) and the [more complex examples](https://github.com/pytorch/executorch/tree/main/examples/qualcomm/scripts) for reference.

#### Step 1: Prepare Your Model

```python
import torch

# Initialize your custom model
model = YourModelClass().eval()  # Your custom PyTorch model

# Create example inputs (adjust shape as needed)
example_inputs = (torch.randn(1, 3, 224, 224),)  # Example input tensor
```
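
If you do not yet have a model of your own, the hypothetical `YourModelClass` above can be any small `torch.nn.Module`. The sketch below is purely illustrative (the class name, layer sizes, and 10-class head are assumptions, not part of the HTP flow):

```python
import torch


class YourModelClass(torch.nn.Module):
    """Illustrative stand-in for a custom model; replace with your own module."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = torch.nn.ReLU()
        self.pool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Linear(8, 10)

    def forward(self, x):
        x = self.pool(self.relu(self.conv(x)))
        return self.fc(x.flatten(1))


model = YourModelClass().eval()
example_inputs = (torch.randn(1, 3, 224, 224),)
```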
#### Step 2: [Optional] Quantize Your Model

Choose between two quantization approaches: post-training quantization (PTQ) or quantization-aware training (QAT):

```python
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, prepare_qat_pt2e, convert_pt2e

quantization_type = "ptq"  # or "qat"
training_steps = 10  # number of QAT fine-tuning iterations (illustrative)

quantizer = QnnQuantizer()
m = torch.export.export(model, example_inputs, strict=True).module()

# PTQ (Post-Training Quantization)
if quantization_type == "ptq":
    prepared_model = prepare_pt2e(m, quantizer)
    # Calibration loop would go here
    prepared_model(*example_inputs)

# QAT (Quantization-Aware Training)
elif quantization_type == "qat":
    prepared_model = prepare_qat_pt2e(m, quantizer)
    # Training loop would go here
    for _ in range(training_steps):
        prepared_model(*example_inputs)

# Convert to quantized model
quantized_model = convert_pt2e(prepared_model)
```
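
The single `prepared_model(*example_inputs)` call above is only a placeholder. For PTQ, calibration should run the prepared model over a representative set of inputs so the observers can collect activation statistics; a minimal sketch, assuming a hypothetical `calibration_dataset` iterable of input tuples:

```python
# Hypothetical calibration data; replace with samples drawn from your real dataset.
calibration_dataset = [(torch.randn(1, 3, 224, 224),) for _ in range(100)]

# No gradients are needed for PTQ calibration; each forward pass updates the
# observers inserted by prepare_pt2e.
with torch.no_grad():
    for sample in calibration_dataset:
        prepared_model(*sample)

quantized_model = convert_pt2e(prepared_model)
```

The closer the calibration data is to what the model will see on device, the better the resulting quantization parameters.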
#### Step 3: Configure Compile Specs

In this step, specify the target SoC, data type, and other QNN compiler spec options.

```python
from executorch.backends.qualcomm.utils.utils import (
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
)
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset

quantized = True  # set to False if you skipped Step 2

# HTP Compiler Configuration
backend_options = generate_htp_compiler_spec(
    use_fp16=not quantized,  # False for quantized models
)

# QNN Compiler Spec
compile_spec = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8650,  # Your target SoC
    backend_options=backend_options,
)
```
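
If the target SoC comes in as a string (for example from a command-line flag), the export script in this commit maps it to a `QcomChipset` value with `get_soc_to_chipset_map`; a short sketch, reusing the `backend_options` defined above and an assumed `soc_name` string:

```python
from executorch.backends.qualcomm.utils.utils import get_soc_to_chipset_map

soc_name = "SM8650"  # e.g. parsed from a CLI argument
compile_spec = generate_qnn_executorch_compiler_spec(
    soc_model=get_soc_to_chipset_map()[soc_name],
    backend_options=backend_options,
)
```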
#### Step 4: Lower and Export the Model

```python
from executorch.backends.qualcomm.utils.utils import to_edge_transform_and_lower_to_qnn
from executorch.exir import ExecutorchBackendConfig

# Lower to QNN backend
delegated_program = to_edge_transform_and_lower_to_qnn(
    quantized_model if quantized else model,
    example_inputs,
    compile_spec,
)

# Export to ExecuTorch format
executorch_program = delegated_program.to_executorch(
    config=ExecutorchBackendConfig(extract_delegate_segments=False)
)

# Save the compiled model
model_name = "custom_model_qnn.pte"
with open(model_name, "wb") as f:
    f.write(executorch_program.buffer)
print(f"Model successfully exported to {model_name}")
```
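
Once the program is lowered, you can check how much of the graph was actually delegated to QNN. The sketch below assumes the object returned by `to_edge_transform_and_lower_to_qnn` exposes `exported_program()` and that delegated subgraphs show up as `executorch_call_delegate` call sites; adjust it to your ExecuTorch version if the API differs:

```python
# Count delegate call sites; each one is a subgraph that will run on the QNN backend.
graph_module = delegated_program.exported_program().graph_module
delegate_calls = [
    node
    for node in graph_module.graph.nodes
    if node.op == "call_function" and "executorch_call_delegate" in str(node.target)
]
print(f"{len(delegate_calls)} subgraph(s) delegated to QNN")
```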
## What is coming?

- Improve the performance for llama3-8B-Instruct and support batch prefill.

docs/source/quantization-overview.md

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ Not all quantization options are supported by all backends. Consult backend-spec
 
 * [XNNPACK quantization](backends-xnnpack.md#quantization)
 * [CoreML quantization](backends-coreml.md#quantization)
+* [QNN quantization](backends-qualcomm.md#step-2-optional-quantize-your-model)
 
 
examples/qualcomm/scripts/export_example.py

Lines changed: 45 additions & 14 deletions
@@ -4,10 +4,10 @@
 
 import torch
 from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
-from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
 from executorch.backends.qualcomm.utils.utils import (
     generate_htp_compiler_spec,
     generate_qnn_executorch_compiler_spec,
+    get_soc_to_chipset_map,
     to_edge_transform_and_lower_to_qnn,
 )
 from executorch.devtools import generate_etrecord
@@ -16,8 +16,11 @@
 from executorch.exir.capture._config import ExecutorchBackendConfig
 from executorch.extension.export_util.utils import save_pte_program
 
-from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
-
+from torchao.quantization.pt2e.quantize_pt2e import (
+    convert_pt2e,
+    prepare_pt2e,
+    prepare_qat_pt2e,
+)
 
 def main() -> None:
     parser = argparse.ArgumentParser()
@@ -43,6 +46,20 @@ def main() -> None:
         help="The folder to store the exported program",
     )
 
+    parser.add_argument(
+        "--soc",
+        type=str,
+        default="SM8650",
+        help="Specify the SoC model.",
+    )
+
+    parser.add_argument(
+        "-q",
+        "--quantization",
+        choices=["ptq", "qat"],
+        help="Run post-training quantization (ptq) or quantization-aware training (qat).",
+    )
+
     args = parser.parse_args()
 
     if args.model_name not in MODEL_NAME_TO_MODEL:
@@ -51,27 +68,41 @@ def main() -> None:
             f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
         )
 
+    # Get model and example inputs
     model, example_inputs, _, _ = EagerModelFactory.create_model(
         *MODEL_NAME_TO_MODEL[args.model_name]
     )
 
     # Get quantizer
-    quantizer = QnnQuantizer()
-
-    # Typical pytorch 2.0 quantization flow
-    m = torch.export.export(model.eval(), example_inputs, strict=True).module()
-    m = prepare_pt2e(m, quantizer)
-    # Calibration
-    m(*example_inputs)
-    # Get the quantized model
-    m = convert_pt2e(m)
+    if args.quantization:
+        print("Quantizing model...")
+        # This is the quantized-model path
+        quantizer = QnnQuantizer()
+        # Typical pytorch 2.0 quantization flow
+        m = torch.export.export(model.eval(), example_inputs, strict=True).module()
+        if args.quantization == "qat":
+            m = prepare_qat_pt2e(m, quantizer)
+            # Training loop
+            m(*example_inputs)
+        elif args.quantization == "ptq":
+            m = prepare_pt2e(m, quantizer)
+            # Calibration
+            m(*example_inputs)
+        else:
+            raise RuntimeError(f"Unknown quantization type {args.quantization}")
+        # Get the quantized model
+        m = convert_pt2e(m)
+    else:
+        # This is the fp model path
+        m = model
 
     # Capture program for edge IR and delegate to QNN backend
+    use_fp16 = True if args.quantization is None else False
     backend_options = generate_htp_compiler_spec(
-        use_fp16=False,
+        use_fp16=use_fp16,
     )
     compile_spec = generate_qnn_executorch_compiler_spec(
-        soc_model=QcomChipset.SM8550,
+        soc_model=get_soc_to_chipset_map()[args.soc],
         backend_options=backend_options,
     )
     delegated_program = to_edge_transform_and_lower_to_qnn(
