109 changes: 109 additions & 0 deletions docs/source/backends-qualcomm.md
@@ -365,6 +365,115 @@ The model, inputs, and output location are passed to `qnn_executorch_runner` by

Please refer to `$EXECUTORCH_ROOT/examples/qualcomm/scripts/` and `$EXECUTORCH_ROOT/examples/qualcomm/oss_scripts/` for the list of supported models.

## How to Support a Custom Model in HTP Backend

### Step-by-Step Implementation Guide

For reference, see [this simple example](https://github.com/pytorch/executorch/blob/main/examples/qualcomm/scripts/export_example.py) and the [more complicated examples](https://github.com/pytorch/executorch/tree/main/examples/qualcomm/scripts).

#### Step 1: Prepare Your Model
```python
import torch

# Initialize your custom model
model = YourModelClass().eval() # Your custom PyTorch model

# Create example inputs (adjust shape as needed)
example_inputs = (torch.randn(1, 3, 224, 224),) # Example input tensor
```
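
If you do not have a custom model on hand, a minimal stand-in for `YourModelClass` (a hypothetical module, used only for illustration) is enough to walk through the rest of this guide:

```python
import torch
import torch.nn as nn

class YourModelClass(nn.Module):
    """Tiny stand-in model matching the (1, 3, 224, 224) example input."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = self.pool(self.relu(self.conv(x)))
        return self.fc(x.flatten(1))
```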

#### Step 2: [Optional] Quantize Your Model
Choose between the two quantization approaches, post-training quantization (PTQ) or quantization-aware training (QAT):
```python
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, prepare_qat_pt2e, convert_pt2e

quantizer = QnnQuantizer()
m = torch.export.export(model, example_inputs, strict=True).module()

# PTQ (Post-Training Quantization)
if quantization_type == "ptq":
    prepared_model = prepare_pt2e(m, quantizer)
    # Calibration loop would go here
    prepared_model(*example_inputs)

# QAT (Quantization-Aware Training)
elif quantization_type == "qat":
    prepared_model = prepare_qat_pt2e(m, quantizer)
    # Training loop would go here
    for _ in range(training_steps):
        prepared_model(*example_inputs)

# Convert to quantized model
quantized_model = convert_pt2e(prepared_model)
```

The `QnnQuantizer` is configurable, with the default setting being **8a8w**. For advanced users, refer to the [`QnnQuantizer`](https://github.com/pytorch/executorch/blob/main/backends/qualcomm/quantizer/quantizer.py) source for details; a configuration sketch follows the options below.

##### Supported Quantization Schemes
- **8a8w** (default)
- **16a16w**
- **16a8w**
- **16a4w**
- **16a4w_block**

##### Customization Options
- **Per-node annotation**: Use `custom_quant_annotations`.
- **Per-module (`nn.Module`) annotation**: Use `submodule_qconfig_list`.

##### Additional Features
- **Node exclusion**: Discard specific nodes via `discard_nodes`.
- **Blockwise quantization**: Configure block sizes with `block_size_map`.


For practical examples, see [`test_qnn_delegate.py`](https://github.com/pytorch/executorch/blob/main/backends/qualcomm/tests/test_qnn_delegate.py).
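
As a rough sketch of how these knobs might be set, the snippet below configures a `QnnQuantizer` for a non-default scheme. The method names (`set_default_quant_config`, `add_discard_nodes`, `set_block_size_map`) and the `QuantDtype` enum values are assumptions drawn from the current quantizer source and may change between versions; verify them against [`quantizer.py`](https://github.com/pytorch/executorch/blob/main/backends/qualcomm/quantizer/quantizer.py) before use.

```python
# Hypothetical configuration sketch -- method names and signatures are
# assumptions; check backends/qualcomm/quantizer/quantizer.py for your version.
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype

quantizer = QnnQuantizer()

# Choose a scheme other than the 8a8w default, e.g. 16a8w.
quantizer.set_default_quant_config(
    QuantDtype.use_16a8w,
    is_qat=False,                # set True when using the QAT flow
    is_conv_per_channel=True,    # per-channel weight quantization for convs
    is_linear_per_channel=True,  # per-channel weight quantization for linears
)

# Exclude specific FX nodes from quantization ("discard_nodes").
quantizer.add_discard_nodes(["node_name_to_skip"])

# For blockwise schemes such as 16a4w_block, provide per-node block sizes.
quantizer.set_block_size_map({"conv_node_name": (1, 64, 1, 1)})
```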


#### Step 3: Configure Compile Specs
During this step, you will need to specify the target SoC, data type, and other QNN compile-spec options.
```python
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
from executorch.backends.qualcomm.utils.utils import (
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
)

# HTP Compiler Configuration
backend_options = generate_htp_compiler_spec(
use_fp16=not quantized, # False for quantized models
)

# QNN Compiler Spec
compile_spec = generate_qnn_executorch_compiler_spec(
soc_model=QcomChipset.SM8650, # Your target SoC
backend_options=backend_options,
)
```
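
If you prefer to pick the SoC by name rather than by enum value, the updated `export_example.py` in this PR resolves a string such as `"SM8650"` through `get_soc_to_chipset_map`; a minimal sketch, continuing from the snippet above:

```python
from executorch.backends.qualcomm.utils.utils import get_soc_to_chipset_map

# Map a SoC name string (e.g. taken from a CLI flag) to its QcomChipset value.
# The set of valid names depends on your ExecuTorch version; "SM8650" is the
# default used by the example script.
soc_name = "SM8650"

compile_spec = generate_qnn_executorch_compiler_spec(
    soc_model=get_soc_to_chipset_map()[soc_name],
    backend_options=backend_options,  # from the HTP compiler spec above
)
```
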
#### Step 4: Lower and Export the Model
```python
from executorch.backends.qualcomm.utils.utils import to_edge_transform_and_lower_to_qnn
from executorch.exir import ExecutorchBackendConfig

# Lower to QNN backend
delegated_program = to_edge_transform_and_lower_to_qnn(
quantized_model if quantized else model,
example_inputs,
compile_spec
)

# Export to ExecuTorch format
executorch_program = delegated_program.to_executorch(
config=ExecutorchBackendConfig(extract_delegate_segments=False)
)

# Save the compiled model
model_name = "custom_model_qnn.pte"
with open(model_name, "wb") as f:
f.write(executorch_program.buffer)
print(f"Model successfully exported to {model_name}")
```
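
Before saving, it can be useful to check how much of the graph was actually delegated to the QNN backend. This is a hedged sketch: `get_delegation_info` is the helper used in other ExecuTorch backend docs and is assumed to live in `executorch.devtools.backend_debug`; confirm it is available in your build.

```python
from executorch.devtools.backend_debug import get_delegation_info

# Summarize which ops were delegated to QNN and which stayed on the CPU path.
graph_module = delegated_program.exported_program().graph_module
delegation_info = get_delegation_info(graph_module)
print(delegation_info.get_summary())
```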

## What is coming?

- Improve the performance for llama3-8B-Instruct and support batch prefill.
1 change: 1 addition & 0 deletions docs/source/quantization-overview.md
@@ -30,6 +30,7 @@ Not all quantization options are supported by all backends. Consult backend-spec

* [XNNPACK quantization](backends-xnnpack.md#quantization)
* [CoreML quantization](backends-coreml.md#quantization)
* [QNN quantization](backends-qualcomm.md#step-2-optional-quantize-your-model)



58 changes: 45 additions & 13 deletions examples/qualcomm/scripts/export_example.py
@@ -4,10 +4,10 @@

import torch
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
from executorch.backends.qualcomm.utils.utils import (
generate_htp_compiler_spec,
generate_qnn_executorch_compiler_spec,
get_soc_to_chipset_map,
to_edge_transform_and_lower_to_qnn,
)
from executorch.devtools import generate_etrecord
@@ -16,7 +16,11 @@
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.extension.export_util.utils import save_pte_program

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
prepare_qat_pt2e,
)


def main() -> None:
Expand All @@ -43,6 +47,20 @@ def main() -> None:
help="The folder to store the exported program",
)

parser.add_argument(
"--soc",
type=str,
default="SM8650",
help="Specify the SoC model.",
)

parser.add_argument(
"-q",
"--quantization",
choices=["ptq", "qat"],
help="Run post-traininig quantization.",
)

args = parser.parse_args()

if args.model_name not in MODEL_NAME_TO_MODEL:
@@ -51,27 +69,41 @@ def main() -> None:
f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
)

# Get model and example inputs
model, example_inputs, _, _ = EagerModelFactory.create_model(
*MODEL_NAME_TO_MODEL[args.model_name]
)

# Get quantizer
quantizer = QnnQuantizer()

# Typical pytorch 2.0 quantization flow
m = torch.export.export(model.eval(), example_inputs, strict=True).module()
m = prepare_pt2e(m, quantizer)
# Calibration
m(*example_inputs)
# Get the quantized model
m = convert_pt2e(m)
if args.quantization:
print("Quantizing model...")
# This is the quantized model path
quantizer = QnnQuantizer()
# Typical pytorch 2.0 quantization flow
m = torch.export.export(model.eval(), example_inputs, strict=True).module()
if args.quantization == "qat":
m = prepare_qat_pt2e(m, quantizer)
# Training loop
m(*example_inputs)
elif args.quantization == "ptq":
m = prepare_pt2e(m, quantizer)
# Calibration
m(*example_inputs)
else:
raise RuntimeError(f"Unknown quantization type {args.quantization}")
# Get the quantized model
m = convert_pt2e(m)
else:
# This is the fp model path
m = model

# Capture program for edge IR and delegate to QNN backend
use_fp16 = True if args.quantization is None else False
backend_options = generate_htp_compiler_spec(
use_fp16=False,
use_fp16=use_fp16,
)
compile_spec = generate_qnn_executorch_compiler_spec(
soc_model=QcomChipset.SM8550,
soc_model=get_soc_to_chipset_map()[args.soc],
backend_options=backend_options,
)
delegated_program = to_edge_transform_and_lower_to_qnn(