diff --git a/examples/vulkan/README.md b/examples/vulkan/README.md new file mode 100644 index 00000000000..71fdd0e4183 --- /dev/null +++ b/examples/vulkan/README.md @@ -0,0 +1,80 @@ +# Vulkan Delegate Export Examples + +This directory contains scripts for exporting models with the Vulkan delegate in ExecuTorch. Vulkan delegation allows you to run your models on devices with Vulkan-capable GPUs, potentially providing significant performance improvements over CPU execution. + +## Scripts + +- `export.py`: Basic export script for models to use with Vulkan delegate +- `aot_compiler.py`: Advanced export script with quantization support + +## Usage + +### Basic Export + +```bash +python -m executorch.examples.vulkan.export -m -o +``` + +### Export with Quantization (Experimental) + +```bash +python -m executorch.examples.vulkan.aot_compiler -m -q -o +``` + +### Dynamic Shape Support + +```bash +python -m executorch.examples.vulkan.export -m -d -o +``` + +### Additional Options + +- `-s/--strict`: Export with strict mode (default: True) +- `-a/--segment_alignment`: Specify segment alignment in hex (default: 0x1000) +- `-e/--external_constants`: Save constants in external .ptd file (default: False) +- `-r/--etrecord`: Generate and save an ETRecord to the given file location + +## Examples + +```bash +# Export MobileNetV2 with Vulkan delegate +python -m executorch.examples.vulkan.export -m mobilenet_v2 -o ./exported_models + +# Export MobileNetV3 with quantization +python -m executorch.examples.vulkan.aot_compiler -m mobilenet_v3 -q -o ./exported_models + +# Export with dynamic shapes +python -m executorch.examples.vulkan.export -m mobilenet_v2 -d -o ./exported_models + +# Export with ETRecord for debugging +python -m executorch.examples.vulkan.export -m mobilenet_v2 -r ./records/mobilenet_record.etrecord -o ./exported_models +``` + +## Supported Operations + +The Vulkan delegate supports various operations including: + +- Basic arithmetic (add, subtract, multiply, divide) +- Activations (ReLU, Sigmoid, Tanh, etc.) +- Convolutions (Conv1d, Conv2d, ConvTranspose2d) +- Pooling operations (MaxPool2d, AvgPool2d) +- Linear/Fully connected layers +- BatchNorm, GroupNorm +- Various tensor operations (cat, reshape, permute, etc.) + +For a complete list of supported operations, refer to the Vulkan delegate implementation in the ExecuTorch codebase. + +## Debugging and Optimization + +If you encounter issues with Vulkan delegation: + +1. Use `-r/--etrecord` to generate an ETRecord for debugging +2. Check if your operations are supported by the Vulkan delegate +3. Ensure your Vulkan drivers are up to date +4. Try using the export script with `--strict False` if strict mode causes issues + +## Requirements + +- Vulkan runtime libraries (libvulkan.so.1) +- A Vulkan-capable GPU with appropriate drivers +- PyTorch with Vulkan support diff --git a/examples/vulkan/__init__.py b/examples/vulkan/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/examples/vulkan/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/examples/vulkan/aot_compiler.py b/examples/vulkan/aot_compiler.py new file mode 100644 index 00000000000..4f95ffa183a --- /dev/null +++ b/examples/vulkan/aot_compiler.py @@ -0,0 +1,204 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Example script for compiling models with Vulkan delegation + +# pyre-unsafe + +import argparse +import logging + +import torch +from executorch.backends.transforms.convert_dtype_pass import I64toI32 +from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) +from executorch.extension.export_util.utils import save_pte_program + +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import Quantizer + +from ..models import MODEL_NAME_TO_MODEL +from ..models.model_factory import EagerModelFactory + + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + + +def quantize_and_lower_module( + model: torch.nn.Module, + sample_inputs, + quantizer: Quantizer, + dynamic_shapes=None, +) -> torch.nn.Module: + """Quantize a model and lower it with Vulkan delegation""" + compile_options = {} + if dynamic_shapes is not None: + compile_options["require_dynamic_shapes"] = True + + edge_compile_config = EdgeCompileConfig( + _skip_dim_order=False, # Proper handling for Vulkan memory format + ) + + program = torch.export.export_for_training( + model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True + ).module() + + program = prepare_pt2e(program, quantizer) + # Calibrate + program(*sample_inputs) + + program = convert_pt2e(program) + + program = torch.export.export(program, sample_inputs, dynamic_shapes=dynamic_shapes) + + edge_program = to_edge_transform_and_lower( + program, + compile_config=edge_compile_config, + transform_passes=[ + I64toI32(edge_compile_config._skip_dim_order), + ], + partitioner=[VulkanPartitioner(compile_options)], + ) + + return edge_program + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model_name", + required=True, + help=f"Model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}", + ) + parser.add_argument( + "-q", + "--quantize", + action="store_true", + required=False, + default=False, + help="Produce a quantized model. Note: Quantization support may vary by model.", + ) + parser.add_argument( + "-d", + "--delegate", + action="store_true", + required=False, + default=True, + help="Produce a Vulkan delegated model", + ) + parser.add_argument( + "-y", + "--dynamic", + action="store_true", + required=False, + default=False, + help="Enable dynamic shape support", + ) + parser.add_argument( + "-r", + "--etrecord", + required=False, + default="", + help="Generate and save an ETRecord to the given file location", + ) + parser.add_argument("-o", "--output_dir", default=".", help="output directory") + + args = parser.parse_args() + + model, example_inputs, _, dynamic_shapes = EagerModelFactory.create_model( + *MODEL_NAME_TO_MODEL[args.model_name] + ) + + model = model.eval() + + if args.dynamic and dynamic_shapes is None: + logging.warning("Dynamic shapes requested but not available for this model.") + + dynamic_shapes_to_use = dynamic_shapes if args.dynamic else None + + # Configure Edge compilation + edge_compile_config = EdgeCompileConfig( + _skip_dim_order=False, # Proper handling for Vulkan memory format + _check_ir_validity=True, + ) + + # Setup compile options + compile_options = {} + if dynamic_shapes_to_use is not None: + compile_options["require_dynamic_shapes"] = True + + if args.quantize: + logging.info("Quantization for Vulkan not fully supported yet. Using experimental path.") + try: + # Try to import quantization utilities if available + try: + from ..quantization.utils import get_quantizer_for_model + quantizer = get_quantizer_for_model(args.model_name) + except ImportError: + # If the specific utility isn't available, create a basic quantizer + logging.warning("Quantization utils not found. Using default quantizer.") + from torchao.quantization.pt2e.quantizer import get_default_quantizer + quantizer = get_default_quantizer() + + edge = quantize_and_lower_module( + model, example_inputs, quantizer, dynamic_shapes=dynamic_shapes_to_use + ) + except (ImportError, NotImplementedError) as e: + logging.error(f"Quantization failed: {e}") + logging.info("Falling back to non-quantized path") + # Export the model using torch.export + program = torch.export.export( + model, example_inputs, dynamic_shapes=dynamic_shapes_to_use, strict=True + ) + + # Transform and lower with Vulkan partitioner + edge = to_edge_transform_and_lower( + program, + compile_config=edge_compile_config, + transform_passes=[ + I64toI32(edge_compile_config._skip_dim_order), + ], + partitioner=[VulkanPartitioner(compile_options)], + generate_etrecord=args.etrecord, + ) + else: + # Standard non-quantized path + # Export the model using torch.export + program = torch.export.export( + model, example_inputs, dynamic_shapes=dynamic_shapes_to_use, strict=True + ) + + # Transform and lower with Vulkan partitioner + edge = to_edge_transform_and_lower( + program, + compile_config=edge_compile_config, + transform_passes=[ + I64toI32(edge_compile_config._skip_dim_order), + ], + partitioner=[VulkanPartitioner(compile_options)], + generate_etrecord=args.etrecord, + ) + + logging.info(f"Exported and lowered graph:\n{edge.exported_program().graph}") + + exec_prog = edge.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=False) + ) + + if args.etrecord: + exec_prog.get_etrecord().save(args.etrecord) + logging.info(f"Saved ETRecord to {args.etrecord}") + + quant_tag = "q8" if args.quantize else "fp32" + model_name = f"{args.model_name}_vulkan_{quant_tag}" + save_pte_program(exec_prog, model_name, args.output_dir) + logging.info(f"Model exported and saved as {model_name}.pte in {args.output_dir}") diff --git a/examples/vulkan/export.py b/examples/vulkan/export.py new file mode 100644 index 00000000000..8e7e49ebe65 --- /dev/null +++ b/examples/vulkan/export.py @@ -0,0 +1,150 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Example script for exporting models to flatbuffer with the Vulkan delegate + +# pyre-unsafe + +import argparse +import logging +import torch + +from executorch.backends.transforms.convert_dtype_pass import I64toI32 +from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, +) +from executorch.exir import to_edge_transform_and_lower +from executorch.extension.export_util.utils import save_pte_program + +from ..models import MODEL_NAME_TO_MODEL +from ..models.model_factory import EagerModelFactory +from torch.export import export + + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model_name", + required=True, + help=f"provide a model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}", + ) + + parser.add_argument( + "-s", + "--strict", + action=argparse.BooleanOptionalAction, + default=True, + help="whether to export with strict mode. Default is True", + ) + + parser.add_argument( + "-a", + "--segment_alignment", + required=False, + help="specify segment alignment in hex. Default is 0x1000. Use 0x4000 for iOS", + ) + + parser.add_argument( + "-e", + "--external_constants", + action=argparse.BooleanOptionalAction, + default=False, + help="Save constants in external .ptd file. Default is False", + ) + + parser.add_argument( + "-d", + "--dynamic", + action=argparse.BooleanOptionalAction, + default=False, + help="Enable dynamic shape support. Default is False", + ) + + parser.add_argument( + "-r", + "--etrecord", + required=False, + default="", + help="Generate and save an ETRecord to the given file location", + ) + + parser.add_argument("-o", "--output_dir", default=".", help="output directory") + + args = parser.parse_args() + + if args.model_name not in MODEL_NAME_TO_MODEL: + raise RuntimeError( + f"Model {args.model_name} is not a valid name. " + f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}." + ) + + model, example_inputs, _, dynamic_shapes = EagerModelFactory.create_model( + *MODEL_NAME_TO_MODEL[args.model_name] + ) + + # Prepare model + model.eval() + + # Setup compile options + compile_options = {} + if args.dynamic or dynamic_shapes is not None: + compile_options["require_dynamic_shapes"] = True + + # Configure Edge compilation + edge_compile_config = EdgeCompileConfig( + _skip_dim_order=False, # Proper handling for Vulkan memory format + ) + + logging.info(f"Exporting model {args.model_name} with Vulkan delegate") + + # Export the model using torch.export + if dynamic_shapes is not None: + program = export(model, example_inputs, dynamic_shapes=dynamic_shapes, strict=args.strict) + else: + program = export(model, example_inputs, strict=args.strict) + + # Transform and lower with Vulkan partitioner + edge_program = to_edge_transform_and_lower( + program, + compile_config=edge_compile_config, + transform_passes=[ + I64toI32(edge_compile_config._skip_dim_order), + ], + partitioner=[VulkanPartitioner(compile_options)], + generate_etrecord=args.etrecord, + ) + + logging.info(f"Exported and lowered graph:\n{edge_program.exported_program().graph}") + + # Configure backend options + backend_config = ExecutorchBackendConfig(external_constants=args.external_constants) + if args.segment_alignment is not None: + backend_config.segment_alignment = int(args.segment_alignment, 16) + + # Create executorch program + exec_prog = edge_program.to_executorch(config=backend_config) + + # Save ETRecord if requested + if args.etrecord: + exec_prog.get_etrecord().save(args.etrecord) + logging.info(f"Saved ETRecord to {args.etrecord}") + + # Save the program + model_name = f"{args.model_name}_vulkan" + save_pte_program(exec_prog, model_name, args.output_dir) + logging.info(f"Model exported and saved as {model_name}.pte in {args.output_dir}") + + +if __name__ == "__main__": + with torch.no_grad(): + main() # pragma: no cover