Skip to content

Commit 0aeb300

Browse files
Sicheng Jiassjia
authored andcommitted
[ET-VK][examples] Create export script for Vulkan examples
ghstack-source-id: df9c51b ghstack-comment-id: 3175671173 Pull-Request: #13286
1 parent 2d4533a commit 0aeb300

File tree

4 files changed

+439
-0
lines changed

4 files changed

+439
-0
lines changed

examples/vulkan/README.md

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Vulkan Delegate Export Examples
2+
3+
This directory contains scripts for exporting models with the Vulkan delegate in ExecuTorch. Vulkan delegation allows you to run your models on devices with Vulkan-capable GPUs, potentially providing significant performance improvements over CPU execution.
4+
5+
## Scripts
6+
7+
- `export.py`: Basic export script for models to use with Vulkan delegate
8+
- `aot_compiler.py`: Advanced export script with quantization support
9+
10+
## Usage
11+
12+
### Basic Export
13+
14+
```bash
15+
python -m executorch.examples.vulkan.export -m <model_name> -o <output_dir>
16+
```
17+
18+
### Export with Quantization (Experimental)
19+
20+
```bash
21+
python -m executorch.examples.vulkan.aot_compiler -m <model_name> -q -o <output_dir>
22+
```
23+
24+
### Dynamic Shape Support
25+
26+
```bash
27+
python -m executorch.examples.vulkan.export -m <model_name> -d -o <output_dir>
28+
```
29+
30+
### Additional Options
31+
32+
- `-s/--strict`: Export with strict mode (default: True)
33+
- `-a/--segment_alignment`: Specify segment alignment in hex (default: 0x1000)
34+
- `-e/--external_constants`: Save constants in external .ptd file (default: False)
35+
- `-r/--etrecord`: Generate and save an ETRecord to the given file location
36+
37+
## Examples
38+
39+
```bash
40+
# Export MobileNetV2 with Vulkan delegate
41+
python -m executorch.examples.vulkan.export -m mv2 -o ./exported_models
42+
43+
# Export MobileNetV3 with quantization
44+
python -m executorch.examples.vulkan.aot_compiler -m mv3 -q -o ./exported_models
45+
46+
# Export with dynamic shapes
47+
python -m executorch.examples.vulkan.export -m mv2 -d -o ./exported_models
48+
49+
# Export with ETRecord for debugging
50+
python -m executorch.examples.vulkan.export -m mv2 -r ./records/mobilenet_record.etrecord -o ./exported_models
51+
```
52+
53+
## Supported Operations
54+
55+
The Vulkan delegate supports various operations including:
56+
57+
- Basic arithmetic (add, subtract, multiply, divide)
58+
- Activations (ReLU, Sigmoid, Tanh, etc.)
59+
- Convolutions (Conv1d, Conv2d, ConvTranspose2d)
60+
- Pooling operations (MaxPool2d, AvgPool2d)
61+
- Linear/Fully connected layers
62+
- BatchNorm, GroupNorm
63+
- Various tensor operations (cat, reshape, permute, etc.)
64+
65+
For a complete list of supported operations, refer to the Vulkan delegate implementation in the ExecuTorch codebase.
66+
67+
## Debugging and Optimization
68+
69+
If you encounter issues with Vulkan delegation:
70+
71+
1. Use `-r/--etrecord` to generate an ETRecord for debugging
72+
2. Check if your operations are supported by the Vulkan delegate
73+
3. Ensure your Vulkan drivers are up to date
74+
4. Try using the export script with `--strict False` if strict mode causes issues
75+
76+
## Requirements
77+
78+
- Vulkan runtime libraries (libvulkan.so.1)
79+
- A Vulkan-capable GPU with appropriate drivers
80+
- PyTorch with Vulkan support

examples/vulkan/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.

examples/vulkan/aot_compiler.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Example script for compiling models with Vulkan delegation
8+
9+
# pyre-unsafe
10+
11+
import argparse
import logging

import torch
from executorch.backends.transforms.convert_dtype_pass import I64toI32
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.exir import (
    EdgeCompileConfig,
    EdgeProgramManager,
    ExecutorchBackendConfig,
    to_edge_transform_and_lower,
)
from executorch.extension.export_util.utils import save_pte_program

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import Quantizer

from ..models import MODEL_NAME_TO_MODEL
from ..models.model_factory import EagerModelFactory
29+
30+
31+
# Layout of every log record emitted by this script:
# "[LEVEL timestamp file:line] message".
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
# Configure the root logger once, at import time, at INFO verbosity.
logging.basicConfig(format=FORMAT, level=logging.INFO)
33+
34+
35+
def quantize_and_lower_module(
    model: torch.nn.Module,
    sample_inputs,
    quantizer: Quantizer,
    dynamic_shapes=None,
) -> EdgeProgramManager:
    """Quantize ``model`` with the PT2E flow and lower it with Vulkan delegation.

    The function runs these steps in order:

    1. ``torch.export.export_for_training`` (strict) to obtain a graph module.
    2. ``prepare_pt2e`` with ``quantizer`` to insert observers.
    3. A single calibration forward pass on ``sample_inputs``.
    4. ``convert_pt2e`` to produce the quantized graph module.
    5. ``torch.export.export`` of the quantized module.
    6. ``to_edge_transform_and_lower`` with the ``I64toI32`` pass and a
       ``VulkanPartitioner``.

    Args:
        model: Eager module to quantize and lower.
        sample_inputs: Positional example inputs; used for both exports and
            unpacked (``*sample_inputs``) for the calibration call.
        quantizer: PT2E quantizer handed to ``prepare_pt2e``.
        dynamic_shapes: Optional dynamic-shape spec forwarded to both export
            calls. When it is not ``None`` the Vulkan partitioner is created
            with ``require_dynamic_shapes=True``.

    Returns:
        The object returned by ``to_edge_transform_and_lower`` (an edge
        program manager), not a ``torch.nn.Module``. No ETRecord is requested
        on this path.
    """
    partitioner_options = {}
    if dynamic_shapes is not None:
        partitioner_options["require_dynamic_shapes"] = True

    edge_compile_config = EdgeCompileConfig(
        _skip_dim_order=False,  # Proper handling for Vulkan memory format
    )

    graph_module = torch.export.export_for_training(
        model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
    ).module()

    prepared = prepare_pt2e(graph_module, quantizer)

    # Calibrate: one forward pass so the inserted observers record statistics.
    # Gradients are not needed for calibration, so skip autograd bookkeeping.
    with torch.no_grad():
        prepared(*sample_inputs)

    converted = convert_pt2e(prepared)

    exported = torch.export.export(
        converted, sample_inputs, dynamic_shapes=dynamic_shapes
    )

    return to_edge_transform_and_lower(
        exported,
        compile_config=edge_compile_config,
        transform_passes=[
            # The pass receives the same dim-order setting as the edge config.
            I64toI32(edge_compile_config._skip_dim_order),
        ],
        partitioner=[VulkanPartitioner(partitioner_options)],
    )
72+
73+
74+
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the Vulkan AOT export script.

    Exits through ``parser.error`` (status 2) when ``--model_name`` is not a
    key of ``MODEL_NAME_TO_MODEL``, instead of failing later with a bare
    ``KeyError``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"Model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}",
    )
    parser.add_argument(
        "-q",
        "--quantize",
        action="store_true",
        required=False,
        default=False,
        help="Produce a quantized model. Note: Quantization support may vary by model.",
    )
    # NOTE: ``store_true`` combined with ``default=True`` means this flag is
    # always True and cannot be switched off; it is kept so that existing
    # command lines that pass ``-d`` keep parsing.
    parser.add_argument(
        "-d",
        "--delegate",
        action="store_true",
        required=False,
        default=True,
        help="Produce a Vulkan delegated model",
    )
    parser.add_argument(
        "-y",
        "--dynamic",
        action="store_true",
        required=False,
        default=False,
        help="Enable dynamic shape support",
    )
    parser.add_argument(
        "-r",
        "--etrecord",
        required=False,
        default="",
        help="Generate and save an ETRecord to the given file location",
    )
    parser.add_argument("-o", "--output_dir", default=".", help="output directory")

    args = parser.parse_args()
    if args.model_name not in MODEL_NAME_TO_MODEL:
        parser.error(
            f"Unknown model name {args.model_name!r}. "
            f"Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}"
        )
    return args


def _get_quantizer(model_name):
    """Return a quantizer for ``model_name``.

    Tries the examples' ``get_quantizer_for_model`` helper first and falls
    back to ``get_default_quantizer`` from torchao when that raises
    ``ImportError``. If the fallback import fails too, the ``ImportError``
    propagates to the caller.
    """
    try:
        from ..quantization.utils import get_quantizer_for_model

        return get_quantizer_for_model(model_name)
    except ImportError:
        # If the specific utility isn't available, create a basic quantizer
        logging.warning("Quantization utils not found. Using default quantizer.")
        from torchao.quantization.pt2e.quantizer import get_default_quantizer

        return get_default_quantizer()


def _export_and_lower(model, example_inputs, dynamic_shapes, generate_etrecord):
    """Export ``model`` without quantization and lower it to Vulkan.

    Args:
        model: Eager module, already in eval mode.
        example_inputs: Example inputs passed to ``torch.export.export``.
        dynamic_shapes: Dynamic-shape spec or ``None``; when not ``None`` the
            partitioner is created with ``require_dynamic_shapes=True``.
        generate_etrecord: Whether lowering should also produce an ETRecord.

    Returns:
        The result of ``to_edge_transform_and_lower``.
    """
    edge_compile_config = EdgeCompileConfig(
        _skip_dim_order=False,  # Proper handling for Vulkan memory format
        _check_ir_validity=True,
    )

    compile_options = {}
    if dynamic_shapes is not None:
        compile_options["require_dynamic_shapes"] = True

    program = torch.export.export(
        model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True
    )

    return to_edge_transform_and_lower(
        program,
        compile_config=edge_compile_config,
        transform_passes=[
            I64toI32(edge_compile_config._skip_dim_order),
        ],
        partitioner=[VulkanPartitioner(compile_options)],
        generate_etrecord=generate_etrecord,
    )


def main() -> None:
    """Export a registered example model to a Vulkan-delegated ``.pte`` file.

    With ``-q`` the experimental quantized path is attempted first. If it
    raises ``ImportError`` or ``NotImplementedError`` the script falls back to
    the non-quantized path, and the output file name reflects what was
    actually produced (``q8`` only when quantization succeeded).
    """
    args = _parse_args()

    model, example_inputs, _, dynamic_shapes = EagerModelFactory.create_model(
        *MODEL_NAME_TO_MODEL[args.model_name]
    )
    model = model.eval()

    if args.dynamic and dynamic_shapes is None:
        logging.warning("Dynamic shapes requested but not available for this model.")

    dynamic_shapes_to_use = dynamic_shapes if args.dynamic else None
    # ``--etrecord`` carries a path; the lowering API only needs an on/off flag.
    want_etrecord = bool(args.etrecord)

    edge = None
    quantized = False
    if args.quantize:
        logging.info("Quantization for Vulkan not fully supported yet. Using experimental path.")
        try:
            quantizer = _get_quantizer(args.model_name)
            edge = quantize_and_lower_module(
                model, example_inputs, quantizer, dynamic_shapes=dynamic_shapes_to_use
            )
            quantized = True
        except (ImportError, NotImplementedError) as e:
            logging.error(f"Quantization failed: {e}")
            logging.info("Falling back to non-quantized path")

    if edge is None:
        # Standard non-quantized path (also used as the quantization fallback).
        edge = _export_and_lower(
            model, example_inputs, dynamic_shapes_to_use, want_etrecord
        )

    logging.info(f"Exported and lowered graph:\n{edge.exported_program().graph}")

    exec_prog = edge.to_executorch(
        config=ExecutorchBackendConfig(extract_delegate_segments=False)
    )

    if want_etrecord:
        if quantized:
            # The quantized lowering call is made without requesting an
            # ETRecord, so there is nothing to fetch from ``exec_prog`` here.
            logging.warning(
                "ETRecord requested, but the quantized path does not generate one; "
                "skipping ETRecord output."
            )
        else:
            exec_prog.get_etrecord().save(args.etrecord)
            logging.info(f"Saved ETRecord to {args.etrecord}")

    # Tag the file by what was really produced, not by what was requested.
    quant_tag = "q8" if quantized else "fp32"
    model_name = f"{args.model_name}_vulkan_{quant_tag}"
    save_pte_program(exec_prog, model_name, args.output_dir)
    logging.info(f"Model exported and saved as {model_name}.pte in {args.output_dir}")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)