diff --git a/backends/openvino/CMakeLists.txt b/backends/openvino/CMakeLists.txt index f5b957da881..736ed6d8603 100644 --- a/backends/openvino/CMakeLists.txt +++ b/backends/openvino/CMakeLists.txt @@ -53,30 +53,6 @@ target_sources( executorch_target_link_options_shared_lib(openvino_backend) -if(EXECUTORCH_BUILD_OPENVINO_EXECUTOR_RUNNER) - # Build executor runner binary for openvino backend - list(APPEND openvino_executor_runner_libs openvino_backend executorch) - - set(_openvino_executor_runner__srcs - ${EXECUTORCH_ROOT}/examples/portable/executor_runner/executor_runner.cpp - ${EXECUTORCH_ROOT}/extension/data_loader/file_data_loader.cpp - ${EXECUTORCH_ROOT}/extension/evalue_util/print_evalue.cpp - ${EXECUTORCH_ROOT}/extension/runner_util/inputs.cpp - ${EXECUTORCH_ROOT}/extension/runner_util/inputs_portable.cpp - ) - add_executable(openvino_executor_runner ${_openvino_executor_runner__srcs}) - - list(APPEND openvino_executor_runner_libs) - - target_link_libraries( - openvino_executor_runner gflags portable_ops_lib - ${openvino_executor_runner_libs} - ) - target_compile_options( - openvino_executor_runner PUBLIC ${_common_compile_options} - ) -endif() - # Install OpenVINO backend library to the lib directory install( TARGETS openvino_backend diff --git a/backends/openvino/README.md b/backends/openvino/README.md index a67cf12eca2..5ce38ade56f 100644 --- a/backends/openvino/README.md +++ b/backends/openvino/README.md @@ -18,6 +18,11 @@ For more information on the supported hardware, please refer to [OpenVINO System executorch ├── backends │ └── openvino +│ ├── quantizer +│ ├── observers +│ └── nncf_observers.py +│ ├── __init__.py +│ └── quantizer.py │ ├── runtime │ ├── OpenvinoBackend.cpp │ └── OpenvinoBackend.h @@ -42,11 +47,23 @@ executorch Before you begin, ensure you have openvino installed and configured on your system. -### Build OpenVINO from Source +### Use OpenVINO from Release Packages + +1. Download the OpenVINO release package from [here](https://docs.openvino.ai/2025/get-started/install-openvino.html). Make sure to select your configuration and click on **OpenVINO Archives** under the distribution section to download the appropriate archive for your platform. + +2. Extract the release package from the archive and set the environment variables. + + ```bash + tar -zxf openvino_toolkit_.tgz + cd openvino_toolkit_ + source setupvars.sh + ``` + +### (Optional) Build OpenVINO from Source ```bash git clone https://github.com/openvinotoolkit/openvino.git -cd openvino && git checkout b16b776ac119dafda51f69a80f1e6b7376d02c3b +cd openvino git submodule update --init --recursive sudo ./install_build_dependencies.sh mkdir build && cd build @@ -59,44 +76,45 @@ cd source setupvars.sh ``` -### Use OpenVINO from Release Packages - -1. Download the OpenVINO release package from [here](https://docs.openvino.ai/2025/get-started/install-openvino.html). Make sure to select your configuration and click on **OpenVINO Archives** under the distribution section to download the appropriate archive for your platform. - -2. Extract the release package from the archive and set the environment variables. - - ```bash - tar -zxf openvino_toolkit_.tgz - cd openvino_toolkit_ - source setupvars.sh - ``` - For more information about OpenVINO build, refer to the [OpenVINO Build Instructions](https://github.com/openvinotoolkit/openvino/blob/master/docs/dev/build_linux.md). ### Setup Follow the steps below to setup your build environment: -1. 
**Setup ExecuTorch Environment**: Refer to the [Environment Setup](https://pytorch.org/executorch/main/getting-started-setup#environment-setup) guide for detailed instructions on setting up the ExecuTorch environment.
-2. **Setup OpenVINO Backend Environment**
-- Install the dependent libs. Ensure that you are inside `executorch/backends/openvino/` directory
+1. **Create a Virtual Environment**
+- Create a virtual environment and activate it by executing the commands below.
   ```bash
-   pip install -r requirements.txt
+   python -m venv env
+   source env/bin/activate
   ```
-  Note: To achieve optimal performance with NNCF quantization, you should install the latest development version of NNCF (version 2.16.0.dev0+191b53d9 or higher).
-3. Navigate to `scripts/` directory.
-
-4. **Build OpenVINO Backend C++ Libraries and Executor Runner**: Once the prerequisites are in place, run the `openvino_build.sh` script to start the build process. By default, OpenVINO backend will be built under `cmake-out/backends/openvino/` as `libopenvino_backend.a`
-
+2. **Clone ExecuTorch Repository from GitHub**
+- Clone the ExecuTorch repository by executing the command below.
   ```bash
-   ./openvino_build.sh
+   git clone --recurse-submodules https://github.com/pytorch/executorch.git
   ```
-  **Build OpenVINO Backend Python Package with Pybindings**: To build and install the OpenVINO backend Python package with Python bindings, run the `openvino_build.sh` script with the `--enable_python` argument. This will compile and install the ExecuTorch Python package with the OpenVINO backend into your Python environment. This option will also enable python bindings required to execute OpenVINO backend tests and `aot_optimize_and_infer.py` script inside `executorch/examples/openvino` folder.
-
+3. **Build ExecuTorch with OpenVINO Backend**
+- Ensure that you are inside the `executorch/backends/openvino/scripts` directory. The following command builds and installs ExecuTorch with the OpenVINO backend and also compiles the C++ runtime libraries and binaries into `/cmake-out` for quick inference testing.
   ```bash
+   ./openvino_build.sh
+   ```
+- Optionally, the `openvino_build.sh` script can be used to build the Python package or the C++ libraries/binaries separately.
+
+  **Build OpenVINO Backend Python Package with Pybindings**: To build and install the OpenVINO backend Python package with Python bindings, run the `openvino_build.sh` script with the `--enable_python` argument as shown in the command below. This will compile and install the ExecuTorch Python package with the OpenVINO backend into your Python environment. This option also enables the Python bindings required to execute the OpenVINO backend tests and the `aot_optimize_and_infer.py` script inside the `executorch/examples/openvino` folder.
+  ```bash
   ./openvino_build.sh --enable_python
   ```
+  **Build C++ Runtime Libraries for OpenVINO Backend**: Run the `openvino_build.sh` script with the `--cpp_runtime` flag to build the C++ runtime libraries as shown in the command below. The compiled library files and binaries can be found in the `/cmake-out` directory. The binary located at `/cmake-out/executor_runner` can be used to run inference with vision models.
+  ```bash
+  ./openvino_build.sh --cpp_runtime
+  ```
+  **Build C++ Llama Runner**: First, ensure the C++ runtime libraries are built by following the earlier instructions. 
Then, run the `openvino_build.sh` script with the `--llama_runner flag` to compile the LlaMA runner as shown the below command, which enables executing inference with models exported using export_llama. The resulting binary is located at: `/cmake-out/examples/models/llama/llama_main` + ```bash + ./openvino_build.sh --llama_runner + ``` + +For more information about ExecuTorch environment setup, refer to the [Environment Setup](https://pytorch.org/executorch/main/getting-started-setup#environment-setup) guide. ### Run diff --git a/backends/openvino/partitioner.py b/backends/openvino/partitioner.py index 4975dc657c6..0d407e33f6e 100644 --- a/backends/openvino/partitioner.py +++ b/backends/openvino/partitioner.py @@ -26,6 +26,13 @@ from torch.fx.passes.operator_support import OperatorSupportBase +class PatternNode: + op_types: dict[str, Optional[list]] = {} + + def __init__(self): + self.op_types = {} + + class OpenvinoOperatorsSupport(OperatorSupportBase): extended_support_dict = { "torch.ops.dim_order_ops._clone_dim_order.default": None, @@ -36,6 +43,7 @@ def __init__( self, op_types_to_skip: Optional[set] = None, op_names_to_skip: Optional[set] = None, + enabled_ops_by_name: Optional[set] = None, ) -> None: """ Initializes the OpenvinoOperatorsSupport class. @@ -47,9 +55,12 @@ def __init__( op_types_to_skip = set() if op_names_to_skip is None: op_names_to_skip = set() + if enabled_ops_by_name is None: + enabled_ops_by_name = set() self._op_types_to_skip = op_types_to_skip self._op_names_to_skip = op_names_to_skip + self._enabled_ops_by_name = enabled_ops_by_name def is_node_supported(self, _, node: torch.fx.Node) -> bool: """ @@ -66,6 +77,10 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool: op_type = node.target.__name__ else: op_type = str(node.target) + + if node.name in self._enabled_ops_by_name: + return True + supported_ops = ( OperatorSupport(options)._support_dict | self.extended_support_dict ) @@ -105,6 +120,7 @@ def __init__( self.delegation_spec = DelegationSpec(OpenvinoBackend.__name__, compile_spec) self._op_types_to_skip = op_types_to_skip self._op_names_to_skip = op_names_to_skip + self._enabled_ops_by_name: set = set() def ops_to_not_decompose( self, @@ -123,9 +139,72 @@ def ops_to_not_decompose( torch.ops.aten.upsample_bilinear2d.vec, torch.ops.aten.upsample_nearest2d.default, torch.ops.aten.upsample_nearest2d.vec, + torch.ops.aten.stack.default, ] return (ops_not_decompose, None) + def check_pattern( + self, node: torch.fx.Node, pattern: type[PatternNode], enabled_ops: list + ) -> bool: + if node.op == "call_function": + if ("call_function" + ":" + str(node.target.__name__)) in pattern.op_types: # type: ignore[union-attr] + pt_input_nodes = node.all_input_nodes + pattern_input_ops = pattern.op_types[ + "call_function" + ":" + str(node.target.__name__) # type: ignore[union-attr] + ] + if pattern_input_ops is None: + enabled_ops.append(node) + return True + if len(pt_input_nodes) != len(pattern_input_ops): + return False + for i in range(len(pt_input_nodes)): + if not self.check_pattern( + pt_input_nodes[i], pattern_input_ops[i], enabled_ops + ): + return False + enabled_ops.append(node) + return True + elif node.op == "get_attr": + if "get_attr" in pattern.op_types: + return True + else: + return False + elif node.op == "placeholder": + if "placeholder" in pattern.op_types: + return True + else: + return False + return False + + def capture_nncf_patterns(self, graph_module: torch.fx.GraphModule): + const_node = PatternNode + const_node.op_types["get_attr"] = 
None + const_node.op_types["placeholder"] = None + bitwise_right_shift_node = PatternNode + bitwise_right_shift_node.op_types[ + "call_function:aten.bitwise_right_shift.Tensor_Scalar" + ] = [const_node] + bitwise_and_node = PatternNode + bitwise_and_node.op_types["call_function:aten.bitwise_and.Scalar"] = [ + const_node + ] + stack_node = PatternNode + stack_node.op_types["call_function:aten.stack.default"] = [ + bitwise_and_node, + bitwise_right_shift_node, + ] + + for node in graph_module.graph.nodes: + if ( + str(node.op) == "call_function" + and str(node.target.__name__) == "aten.stack.default" + ): + enabled_ops: list = [] + pattern_match = self.check_pattern(node, stack_node, enabled_ops) + if pattern_match: + for pattern_op in enabled_ops: + self._enabled_ops_by_name.add(pattern_op.name) + def partition(self, exported_program: ExportedProgram) -> PartitionResult: """ Partitions an exported program into supported and unsupported segments. @@ -133,9 +212,14 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: :param exported_program: The exported program. :return: A PartitionResult containing the partitioned graph and delegation tags. """ + self.capture_nncf_patterns(exported_program.graph_module) partitioner = CapabilityBasedPartitioner( exported_program.graph_module, - OpenvinoOperatorsSupport(self._op_types_to_skip, self._op_names_to_skip), + OpenvinoOperatorsSupport( + self._op_types_to_skip, + self._op_names_to_skip, + self._enabled_ops_by_name, + ), allows_single_node_partition=True, ) partition_list = partitioner.propose_partitions() diff --git a/backends/openvino/quantizer/__init__.py b/backends/openvino/quantizer/__init__.py index df038483f2f..5aae52ef3e8 100644 --- a/backends/openvino/quantizer/__init__.py +++ b/backends/openvino/quantizer/__init__.py @@ -1,3 +1,3 @@ -from .quantizer import OpenVINOQuantizer, quantize_model +from .quantizer import OpenVINOQuantizer, QuantizationMode, quantize_model -__all__ = ["OpenVINOQuantizer", "quantize_model"] +__all__ = ["OpenVINOQuantizer", "quantize_model", "QuantizationMode"] diff --git a/backends/openvino/quantizer/observers.py b/backends/openvino/quantizer/observers.py new file mode 100644 index 00000000000..6cda4561604 --- /dev/null +++ b/backends/openvino/quantizer/observers.py @@ -0,0 +1,186 @@ +# Copyright (c) Intel Corporation +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. See the license file found in the +# LICENSE file in the root directory of this source tree. 
+ +# mypy: disable-error-code=import-not-found + +from abc import ABC, abstractmethod +from typing import Optional, Tuple + +import torch + +from nncf.experimental.torch.fx.node_utils import ( # type: ignore[import-untyped] + get_tensor_constant_from_node, +) +from nncf.experimental.torch.fx.transformations import ( # type: ignore[import-untyped] + constant_update, + module_insertion, + node_removal, +) +from nncf.quantization.algorithms.weight_compression.config import ( # type: ignore[import-untyped] + WeightCompressionParameters, +) +from nncf.quantization.algorithms.weight_compression.weight_lowering import ( # type: ignore[import-untyped] + do_integer_quantization, +) +from nncf.tensor.tensor import Tensor as NNCFTensor # type: ignore[import-untyped] +from nncf.torch.graph.transformations.commands import ( # type: ignore[import-untyped] + PTTargetPoint, + TargetType, +) +from nncf.torch.quantization.layers import ( # type: ignore[import-untyped] + BaseWeightsDecompressor, + INT4AsymmetricWeightsDecompressor, + INT4SymmetricWeightsDecompressor, + INT8AsymmetricWeightsDecompressor, + INT8SymmetricWeightsDecompressor, +) +from torchao.quantization.pt2e import ObserverBase + + +class WeightObserverBase(ObserverBase, ABC): + """ + Base implementation of an NNCF observer that defines the rules for compressing layer weights into the OpenVINO representation. + """ + + def __init__( + self, + wc_param: WeightCompressionParameters, + dtype: torch.dtype, + **kwargs, + ) -> None: + """ + :param wc_param: Weight compression parameters container. + :param dtype: target dtype for the quantization. + """ + super().__init__(dtype=dtype, is_dynamic=False) + self._wc_param = wc_param + + def calculate_qparams( # type: ignore[override] + self, + weight: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Calculates quantization parameters: quantized weight, quantization scale and quantization zero point. + + :param weight: FP weight to be used for calculating qparams. + :return: A tuple containing the quantized weight, quantization scale and quantization zero point. + """ + wc_param = self._wc_param + wc_config = wc_param.compression_config + reduction_axes = wc_param.reduction_axes + q_weight, scale, zp = do_integer_quantization( + NNCFTensor(weight), wc_config, reduction_axes=reduction_axes + ) + zp = zp.data if zp is not None else None + return q_weight.data, scale.data, zp + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + def convert( + self, model: torch.fx.GraphModule, observer_node: torch.fx.Node + ) -> None: + """ + Replaces the given observer node from the given model with a quantized + weight and a OpenVINO specific decompression module. + + :param model: A `torch.fx.GraphModule` representing the statically traced model + with observer nodes attached and calibrated. + :param observer_node: The `torch.fx.Node` corresponding to the observer module for + the weight that is being transformed into a compressed representation. + """ + weight_node = observer_node.args[0] + original_weight = get_tensor_constant_from_node(weight_node, model) + q_weight, scale, zero_point = self.calculate_qparams(original_weight) + + decompressor = self._create_decompressor( + scale, zero_point, q_weight, original_weight + ) + packed_q_weight = decompressor.pack_weight(q_weight) + + # Weight port id is 0 since observer is inserted for a single weight only. 
+ constant_update(model, observer_node, packed_q_weight, input_port_id=0) + + compressed_weight_name = observer_node.all_input_nodes[0].name + decompressor_suffix = "_".join( + compressed_weight_name.replace(".", "_").split("_")[:-2] + ) + decompressor_name = f"{decompressor.quantization_mode}_weights_decompressor_{decompressor_suffix}" + + module_insertion( + model, + decompressor, + [ + PTTargetPoint( + TargetType.OPERATOR_POST_HOOK, + target_node_name=compressed_weight_name, + ) + ], + decompressor_name, + ) + node_removal(model, observer_node, 0) + + @abstractmethod + def _create_decompressor( + self, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + q_weight: torch.Tensor, + original_weight: torch.Tensor, + ) -> BaseWeightsDecompressor: + """ + Returns a respective NNCF decompressor for different types of quantization. + + :param scale: Calculated scale quantization parameter. + :param zero_point: Calculated zero_point quantization parameter. + :param q_weight: Calculated quantized weight. + :param original_weight: FP weight. + :return: NNCF observer according to the qmode which creates the decompression subgraph supported by OpenVINO. + """ + + +class INT4WeightObserver(WeightObserverBase): + """ + OpenVINO INT4 Weight Compression observer. + """ + + def _create_decompressor( + self, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + q_weight: torch.Tensor, + original_weight: torch.Tensor, + ) -> BaseWeightsDecompressor: + if zero_point is None: + return INT4SymmetricWeightsDecompressor( + scale, q_weight.shape, original_weight.shape, original_weight.dtype + ) + return INT4AsymmetricWeightsDecompressor( + scale, + zero_point, + q_weight.shape, + original_weight.shape, + original_weight.dtype, + ) + + +class INT8WeightObserver(WeightObserverBase): + """ + OpenVINO INT8 Weight Compression per channel observer. + """ + + def _create_decompressor( + self, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + q_weight: torch.Tensor, + original_weight: torch.Tensor, + ) -> BaseWeightsDecompressor: + if zero_point is None: + return INT8SymmetricWeightsDecompressor(scale, original_weight.dtype) + return INT8AsymmetricWeightsDecompressor( + scale, zero_point, original_weight.dtype + ) diff --git a/backends/openvino/quantizer/quantizer.py b/backends/openvino/quantizer/quantizer.py index edce272ff9b..bef1ef3274f 100644 --- a/backends/openvino/quantizer/quantizer.py +++ b/backends/openvino/quantizer/quantizer.py @@ -15,8 +15,17 @@ import nncf.experimental.torch.fx as nncf_fx # type: ignore[import-untyped] import torch.fx - +from executorch.backends.openvino.quantizer.observers import ( + INT4WeightObserver, + INT8WeightObserver, +) from nncf.common.graph.graph import NNCFGraph # type: ignore[import-untyped] +from nncf.quantization.algorithms.weight_compression.config import ( # type: ignore[import-untyped] + WeightCompressionParameters, +) +from nncf.quantization.quantize_model import ( # type: ignore[import-untyped] + get_weight_compression_configuration, +) from torchao.quantization.pt2e import ( HistogramObserver, PerChannelMinMaxObserver, @@ -30,7 +39,8 @@ Quantizer, SharedQuantizationSpec, ) -from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY + +QUANT_ANNOTATION_KEY = "quantization_annotation" class QuantizationMode(Enum): @@ -40,11 +50,19 @@ class QuantizationMode(Enum): - INT8_SYM: INT8 symmetric quantization for both activations and weights. - INT8_MIXED: INT8 asymmetric quantization for activations, symmetric for weights. 
- INT8_TRANSFORMER: Optimized INT8 quantization for transformer-based models + - INT8WO_SYM: INT8 symmetric quantization for weights only. + - INT8WO_ASYM: INT8 asymmetric quantization for weights only. + - INT4WO_SYM: INT4 symmetric quantization for weights only. + - INT4WO_ASYM: INT4 asymmetric quantization for weights only """ INT8_SYM = "int8_sym" INT8_MIXED = "int8_mixed" INT8_TRANSFORMER = "int8_transformer" + INT8WO_SYM = "int8wo_sym" + INT8WO_ASYM = "int8wo_asym" + INT4WO_SYM = "int4wo_sym" + INT4WO_ASYM = "int4wo_asym" class OpenVINOQuantizer(Quantizer): @@ -53,10 +71,17 @@ class OpenVINOQuantizer(Quantizer): optimally for the inference via OpenVINO. """ + WEIGHTS_ONLY_COMPRESSION_MODES = ( + QuantizationMode.INT4WO_SYM, + QuantizationMode.INT4WO_ASYM, + QuantizationMode.INT8WO_SYM, + QuantizationMode.INT8WO_ASYM, + ) + def __init__( self, *, - mode: Optional[QuantizationMode] = QuantizationMode.INT8_SYM, + mode: QuantizationMode = QuantizationMode.INT8_SYM, **kwargs, ): """ @@ -65,22 +90,37 @@ def __init__( - INT8_MIXED: INT8 asymmetric quantization for activations, symmetric for weights. - INT8_TRANSFORMER: Optimized INT8 quantization for transformer-based models Default value is INT8_SYM. + - INT4_SYM: Symmetric INT4 Weights-Only Compression + - INT4_ASYM: Asymmetric INT4 Weights-Only Compression :param kwargs: Arguments to pass to the NNCF MinMaxQuantization algorithm. """ - if mode == QuantizationMode.INT8_SYM: - preset = quantization.structs.QuantizationPreset.PERFORMANCE - model_type = None - elif mode == QuantizationMode.INT8_MIXED: - preset = quantization.structs.QuantizationPreset.MIXED - model_type = None + self.mode = mode + if self.mode not in OpenVINOQuantizer.WEIGHTS_ONLY_COMPRESSION_MODES: + if mode == QuantizationMode.INT8_SYM: + preset = quantization.structs.QuantizationPreset.PERFORMANCE + model_type = None + elif mode == QuantizationMode.INT8_MIXED: + preset = quantization.structs.QuantizationPreset.MIXED + model_type = None + else: + preset = None + model_type = nncf.parameters.ModelType.TRANSFORMER + self._algo = ( + nncf.quantization.algorithms.min_max.algorithm.MinMaxQuantization( + preset=preset, model_type=model_type, **kwargs + ) + ) else: - preset = None - model_type = nncf.parameters.ModelType.TRANSFORMER - self._min_max_algo = ( - nncf.quantization.algorithms.min_max.algorithm.MinMaxQuantization( - preset=preset, model_type=model_type, **kwargs + weight_compression_configuration = get_weight_compression_configuration( + mode.value.replace( + "wo", "" + ), # Mode value has to match NNCF CompressWeightsMode + **kwargs, + ) + subset_size = 1 # Doesn't really matter in this case since it is data-free. Should just be +ve + self._algo = nncf.quantization.algorithms.weight_compression.algorithm.WeightCompression( + subset_size=subset_size, **weight_compression_configuration ) - ) def set_ignored_scope( self, @@ -101,7 +141,7 @@ def set_ignored_scope( :param validate: If set to True, then a RuntimeError will be raised if any ignored scope does not match in the model graph. 
""" - self._min_max_algo.set_ignored_scope( + self._algo.set_ignored_scope( nncf.IgnoredScope( names=names or [], patterns=patterns or [], @@ -114,27 +154,73 @@ def set_ignored_scope( def get_nncf_quantization_setup( self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph ) -> quantization.quantizer_setup.SingleConfigQuantizerSetup: - self._min_max_algo._set_backend_entity(model) - return self._min_max_algo.find_quantization_setup(model, nncf_graph) + self._algo._set_backend_entity(model) + return self._algo.find_quantization_setup(model, nncf_graph) - def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: - nncf_graph = nncf_fx.nncf_graph_builder.GraphConverter.create_nncf_graph(model) - quantization_setup = self.get_nncf_quantization_setup(model, nncf_graph) + def _annotate_weight_compression( + self, + model: torch.fx.GraphModule, + graph: torch.fx.Graph, + nncf_graph: NNCFGraph, + node_vs_torch_annotation: DefaultDict[torch.fx.Node, QuantizationAnnotation], + ) -> DefaultDict[torch.fx.Node, QuantizationAnnotation]: + """ + Annotates the model graph with weight-only quantization specs. - graph = model.graph - node_vs_torch_annotation: DefaultDict[torch.fx.Node, QuantizationAnnotation] = ( - defaultdict(QuantizationAnnotation) + Identifies compressible nodes in the NNCF graph and attaches the corresponding + TorchAO quantization specifications to their weight edges for later transformation. + + :param model: The FX GraphModule to annotate. + :param graph: The underlying FX graph. + :param nncf_graph: The corresponding NNCF graph. + :param node_vs_torch_annotation: A mapping of FX nodes to quantization annotations. + :return: Updated mapping of FX nodes with weight compression annotations. + """ + self._algo.set_backend_entity(model) + all_wc_params, _ = self._algo.get_weight_compression_parameters( + model, nncf_graph ) + for wc_param in all_wc_params: + node_with_weight = wc_param.node_with_weight + target_node = nncf_fx.node_utils.get_graph_node_by_name( + graph, node_with_weight.node_name + ) + annotation = node_vs_torch_annotation[target_node] + edge_or_node = self._get_weight_edge(target_node, nncf_graph) + qspec = self._get_torch_ao_qspec_from_nncf_config_for_wc(wc_param=wc_param) + self._fill_torch_ao_annotation(edge_or_node, qspec, annotation) + + return node_vs_torch_annotation + + def _annotate_post_training_quantization( + self, + model: torch.fx.GraphModule, + graph: torch.fx.Graph, + nncf_graph: NNCFGraph, + node_vs_torch_annotation: DefaultDict[torch.fx.Node, QuantizationAnnotation], + ) -> DefaultDict[torch.fx.Node, QuantizationAnnotation]: + """ + Annotates the model graph with post-training quantization configurations. + + :param model: The FX GraphModule to annotate. + :param graph: The underlying FX graph. + :param nncf_graph: The corresponding NNCF graph. + :param node_vs_torch_annotation: A mapping of FX nodes to quantization annotations. + :return: Updated mapping of FX nodes with post-training quantization annotations. 
+ """ + quantization_setup = self.get_nncf_quantization_setup(model, nncf_graph) + for qp in quantization_setup.quantization_points.values(): edge_or_node, annotation = self._get_edge_or_node_and_annotation( graph, nncf_graph, qp, node_vs_torch_annotation ) - qspec: QuantizationSpecBase = self._get_torch_ao_qspec_from_qp(qp) + qspec: QuantizationSpecBase = ( + self._get_torch_ao_qspec_from_nncf_config_for_ptq(qp) + ) self._fill_torch_ao_annotation(edge_or_node, qspec, annotation) for quantizer_ids in quantization_setup.unified_scale_groups.values(): - root_quantizer_id = self._get_unified_scales_root_quantizer_id( nncf_graph, quantizer_ids, quantization_setup ) @@ -145,14 +231,12 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: for q_id in quantizer_ids ): qps = [ - quantization_setup.quantization_points[q_id] - for q_id in quantizer_ids + quantization_setup.quantization_points[qid] for qid in quantizer_ids ] - msg = ( + raise nncf.InternalError( "Different quantization configs are set to one unified scale group:" f"{[(qp.insertion_point.__dict__, str(qp.qconfig)) for qp in qps]}" ) - raise nncf.InternalError(msg) root_target_node = nncf_fx.node_utils.get_graph_node_by_name( graph, root_qp.insertion_point.target_node_name @@ -165,16 +249,35 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: if quantizer_id == root_quantizer_id: continue - qspec = SharedQuantizationSpec(root_edge_or_node) + qspec = SharedQuantizationSpec(root_edge_or_node) # type: ignore[assignment] qp = quantization_setup.quantization_points[quantizer_id] edge_or_node, annotation = self._get_edge_or_node_and_annotation( graph, nncf_graph, qp, node_vs_torch_annotation ) self._fill_torch_ao_annotation(edge_or_node, qspec, annotation) + return node_vs_torch_annotation + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + nncf_graph = nncf_fx.nncf_graph_builder.GraphConverter.create_nncf_graph(model) + graph = model.graph + node_vs_torch_annotation: DefaultDict[torch.fx.Node, QuantizationAnnotation] = ( + defaultdict(QuantizationAnnotation) + ) + + if self.mode in OpenVINOQuantizer.WEIGHTS_ONLY_COMPRESSION_MODES: + node_vs_torch_annotation = self._annotate_weight_compression( + model, graph, nncf_graph, node_vs_torch_annotation + ) + else: + node_vs_torch_annotation = self._annotate_post_training_quantization( + model, graph, nncf_graph, node_vs_torch_annotation + ) + for node, annotation in node_vs_torch_annotation.items(): - assert Q_ANNOTATION_KEY not in node.meta - node.meta[Q_ANNOTATION_KEY] = annotation + assert QUANT_ANNOTATION_KEY not in node.meta + node.meta[QUANT_ANNOTATION_KEY] = annotation + return model @staticmethod @@ -236,6 +339,35 @@ def _get_edge_or_node_and_annotation( edge_or_node = OpenVINOQuantizer._get_edge_or_node(target_node, qp, nncf_graph) return edge_or_node, annotation + @staticmethod + def _get_weight_edge( + target_node: torch.fx.Node, + nncf_graph: NNCFGraph, + ) -> tuple[torch.fx.Node, torch.fx.Node]: + """ + Returns the FX node corresponding to the weight tensor input of a given operator node. + Uses the NNCF graph to identify which input port of the target node holds the weight. + If multiple weight ports are present, a warning is issued and only the first one is used. + + :param target_node: FX node representing a weighted operation (e.g., Linear, Conv). + :param nncf_graph: NNCFGraph used to determine weight port indices. 
+ :return: Edge represented by a Tuple of (weight_node, target_node), where weight_node is the FX node supplying the weight. + """ + nncf_node = nncf_graph.get_node_by_name(target_node.name) + weights_ports_ids = nncf.torch.model_graph_manager.get_weight_tensor_port_ids( + nncf_node, nncf_graph + ) + if len(weights_ports_ids) > 1: + # TODO(dlyakhov): support quantization for nodes with several weights + nncf.common.logging.nncf_logger.warning( + f"Quantization of the weighted node {target_node.name}" + " is not yet supported by the OpenVINOQuantizer." + f" Only the weight on port ID {weights_ports_ids[0]} will be quantized." + f" Quantizable weights are located on ports: {weights_ports_ids}." + ) + weight_node = target_node.all_input_nodes[weights_ports_ids[0]] + return (weight_node, target_node) + @staticmethod def _get_edge_or_node( target_node: torch.fx.Node, @@ -252,22 +384,7 @@ def _get_edge_or_node( """ ip = qp.insertion_point if qp.is_weight_quantization_point(): - nncf_node = nncf_graph.get_node_by_name(target_node.name) - weights_ports_ids = ( - nncf.torch.model_graph_manager.get_weight_tensor_port_ids( - nncf_node, nncf_graph - ) - ) - if len(weights_ports_ids) > 1: - # TODO(dlyakhov): support quantization for nodes with several weights - nncf.common.logging.nncf_logger.warning( - f"Quantization of the weighted node {target_node.name}" - " is not yet supported by the OpenVINOQuantizer." - f" Only the weight on port ID {weights_ports_ids[0]} will be quantized." - f" Quantizable weights are located on ports: {weights_ports_ids}." - ) - weight_node = target_node.all_input_nodes[weights_ports_ids[0]] - return (weight_node, target_node) + OpenVINOQuantizer._get_weight_edge(target_node, nncf_graph) if ip.input_port_id is None: return target_node @@ -294,22 +411,78 @@ def _fill_torch_ao_annotation( annotation_to_update.input_qspec_map[edge_or_node[0]] = qspec @staticmethod - def _get_torch_ao_qspec_from_qp( + def _get_torch_ao_qspec_from_nncf_config_for_wc( + wc_param: WeightCompressionParameters, + ) -> QuantizationSpec: + """ + Returns a TorchAO QuantizationSpec based on NNCF weight compression parameter. + + :param wc_param: NNCF Weight compression parameters for the node. + :return: A TorchAO QuantizationSpec. 
+ """ + observer: Type[UniformQuantizationObserverBase] + + extra_args: Dict[str, Any] = {} + + qmode = wc_param.compression_config.mode + extra_args["wc_param"] = wc_param + is_asym_mode = wc_param.compression_config.is_asym_mode + if qmode in [ + nncf.CompressWeightsMode.INT4_ASYM, + nncf.CompressWeightsMode.INT4_SYM, + ]: + observer = INT4WeightObserver # type: ignore[type-abstract] + quant_min = -8 if not is_asym_mode else 0 + quant_max = 7 if not is_asym_mode else 15 + dtype = torch.int8 + channel_axis = 0 + torch_qscheme = torch_qscheme = ( + torch.per_channel_symmetric + if not is_asym_mode + else torch.per_channel_affine + ) + else: + observer = INT8WeightObserver # type: ignore[type-abstract] + quant_min = -128 if not is_asym_mode else 0 + quant_max = 127 if not is_asym_mode else 255 + dtype = torch.int8 + channel_axis = 0 + torch_qscheme = ( + torch.per_channel_symmetric + if not is_asym_mode + else torch.per_channel_affine + ) + return QuantizationSpec( + dtype=dtype, + observer_or_fake_quant_ctr=observer.with_args(**extra_args), + quant_min=quant_min, + quant_max=quant_max, + qscheme=torch_qscheme, + ch_axis=channel_axis, + is_dynamic=False, + ) + + @staticmethod + def _get_torch_ao_qspec_from_nncf_config_for_ptq( qp: quantization.quantizer_setup.QuantizationPointBase, ) -> QuantizationSpec: """ - Retrieves the quantization configuration from the given quantization point and - converts it into a QuantizationSpec. + Returns a TorchAO QuantizationSpec based on NNCF quantization point. - :param qp: An instance of QuantizationPointBase. - :return: A QuantizationSpec retrieved and converted from the quantization point. + :param qp: Quantization point from NNCF. + :return: A TorchAO QuantizationSpec. """ + observer: Type[UniformQuantizationObserverBase] + # Eps value is copied from nncf/torch/quantization/layers.py - extra_args = {"eps": 1e-16} - qconfig = qp.qconfig - is_weight = qp.is_weight_quantization_point() + extra_args: Dict[str, Any] = {"eps": 1e-16} - observer: Type[UniformQuantizationObserverBase] + is_weight = qp.is_weight_quantization_point() + qconfig = qp.qconfig + dtype = torch.int8 + quant_min = None + quant_max = None + channel_axis = None if qconfig.per_channel: torch_qscheme = ( @@ -329,6 +502,11 @@ def _get_torch_ao_qspec_from_qp( quant_max = 127 dtype = torch.int8 channel_axis = 0 + torch_qscheme = ( + torch.per_channel_symmetric + if qconfig.mode is quantization.structs.QuantizationScheme.SYMMETRIC + else torch.per_channel_affine + ) else: observer = ( HistogramObserver diff --git a/backends/openvino/requirements.txt b/backends/openvino/requirements.txt index 316633e9004..519818d0aac 100644 --- a/backends/openvino/requirements.txt +++ b/backends/openvino/requirements.txt @@ -1,2 +1,2 @@ transformers -git+https://github.com/openvinotoolkit/nncf@6b0fc1c#egg=nncf +git+https://github.com/openvinotoolkit/nncf@3d753ac#egg=nncf diff --git a/backends/openvino/runtime/OpenvinoBackend.cpp b/backends/openvino/runtime/OpenvinoBackend.cpp index 8ec40d7f7c6..bac006ce916 100644 --- a/backends/openvino/runtime/OpenvinoBackend.cpp +++ b/backends/openvino/runtime/OpenvinoBackend.cpp @@ -114,6 +114,26 @@ exr::Error OpenvinoBackend::execute( ov_type, input_shape, input_tensor.mutable_data_ptr()); infer_request->set_input_tensor(i, ov_input_tensor); + + if (args[i]->isInt()) { + int64_t* val = &(args[i]->payload.copyable_union.as_int); + + // Create OpenVINO tensor from integer input + ov::Tensor ov_input_tensor(ov::element::i64, ov::Shape{1}, val); + 
infer_request->set_input_tensor(i, ov_input_tensor); + } else { + auto input_tensor = args[i]->toTensor(); + ov::Shape input_shape( + input_tensor.sizes().begin(), input_tensor.sizes().end()); + + // Convert input tensor to OpenVINO tensor + ov::element::Type ov_type = + convert_to_openvino_type(input_tensor.scalar_type()); + ov::Tensor ov_input_tensor( + ov_type, input_shape, input_tensor.mutable_data_ptr()); + + infer_request->set_input_tensor(i, ov_input_tensor); + } } // Set outputs @@ -165,10 +185,14 @@ ov::element::Type OpenvinoBackend::convert_to_openvino_type( switch (scalar_type) { case exa::ScalarType::Float: return ov::element::f32; + case exa::ScalarType::Half: + return ov::element::f16; case exa::ScalarType::Int: return ov::element::i32; case exa::ScalarType::Char: return ov::element::i8; + case exa::ScalarType::Byte: + return ov::element::u8; case exa::ScalarType::Long: return ov::element::i64; case exa::ScalarType::Bool: diff --git a/backends/openvino/scripts/openvino_build.sh b/backends/openvino/scripts/openvino_build.sh index 5a26f0b6dae..6d7853b96e5 100755 --- a/backends/openvino/scripts/openvino_build.sh +++ b/backends/openvino/scripts/openvino_build.sh @@ -7,55 +7,106 @@ set -e EXECUTORCH_ROOT=$(realpath "$(dirname "$0")/../../..") echo EXECUTORCH_ROOT=${EXECUTORCH_ROOT} -main() { - build_type=${1:-"--cpp_runtime"} +install_requirements() { + echo "Installing Requirements For OpenVINO Backend" + cd "$EXECUTORCH_ROOT" + pip install -r backends/openvino/requirements.txt +} - # If the first arguments is --cpp_runtime (default), build libraries for C++ runtime - if [[ -z "$build_type" || "$build_type" == "--cpp_runtime" ]]; then - echo "Building C++ Runtime Libraries" +build_cpp_runtime() { + echo "Building C++ Runtime Libraries" + + # Set build directory + local build_dir="cmake-out" + + # Enter the Executorch root directory + cd "$EXECUTORCH_ROOT" + rm -rf "${build_dir}" + + # Configure the project with CMake + # Note: Add any additional configuration options you need here + cmake -DCMAKE_INSTALL_PREFIX="${build_dir}" \ + -DCMAKE_BUILD_TYPE=Release \ + -DEXECUTORCH_BUILD_OPENVINO=ON \ + -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ + -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \ + -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ + -DEXECUTORCH_BUILD_EXTENSION_LLM=ON \ + -DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \ + -B"${build_dir}" + + + # Build the project + cmake --build ${build_dir} --target install --config Release -j$(nproc) +} - # Set build directory - local build_dir="cmake-out" +build_llama_runner() { + echo "Building Export Llama Runner" - # Create and enter the build directory - cd "$EXECUTORCH_ROOT" - rm -rf "${build_dir}" + # Set build directory + local build_dir="cmake-out" - # Configure the project with CMake - # Note: Add any additional configuration options you need here - cmake -DCMAKE_INSTALL_PREFIX="${build_dir}" \ - -DCMAKE_BUILD_TYPE=Release \ - -DEXECUTORCH_BUILD_OPENVINO=ON \ - -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ - -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ - -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ - -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ - -DEXECUTORCH_BUILD_OPENVINO_EXECUTOR_RUNNER=ON \ - -B"${build_dir}" + # Enter the Executorch root directory + cd "$EXECUTORCH_ROOT" + # Configure the project with CMake + # Note: 
Add any additional configuration options you need here + cmake -DCMAKE_INSTALL_PREFIX="${build_dir}" \ + -DCMAKE_BUILD_TYPE=Release \ + -B"${build_dir}"/examples/models/llama \ + examples/models/llama + # Build the export llama runner + cmake --build cmake-out/examples/models/llama -j$(nproc) --config Release +} - # Build the project - cmake --build ${build_dir} --target install --config Release -j$(nproc) +build_python_enabled() { + echo "Building Python Package with Pybinding" - # If the first arguments is --enable_python, build python package with python bindings - elif [[ "$build_type" == "--enable_python" ]]; then - echo "Building Python Package with Pybinding" + # Enter the Executorch root directory + cd "$EXECUTORCH_ROOT" + ./install_executorch.sh --clean + + # Set parameters to configure the project with CMake + # Note: Add any additional configuration options you need here + export CMAKE_ARGS="-DEXECUTORCH_BUILD_OPENVINO=ON \ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON" + export CMAKE_BUILD_ARGS="--target openvino_backend" - # Create and enter the build directory - cd "$EXECUTORCH_ROOT" - ./install_executorch.sh --clean + # Build the package + ./install_executorch.sh --minimal - # Set parameters to configure the project with CMake - # Note: Add any additional configuration options you need here - export CMAKE_ARGS="-DEXECUTORCH_BUILD_OPENVINO=ON \ - -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON" - export CMAKE_BUILD_ARGS="--target openvino_backend" + # Install torchao + pip install third-party/ao +} + +main() { + build_type=${1:-"--build_all"} + + # If the first arguments is --build_all (default), build python package, C++ runtime, and llama runner binary + if [[ -z "$build_type" || "$build_type" == "--build_all" ]]; then + install_requirements + build_python_enabled + build_cpp_runtime + build_llama_runner - # Build the package - ./install_executorch.sh --minimal + # If the first arguments is --cpp_runtime, build libraries for C++ runtime + elif [[ "$build_type" == "--cpp_runtime" ]]; then + build_cpp_runtime - # Install torchao - pip install third-party/ao + # If the first arguments is --llama_runner, build export llama runner binary + # Note: c++ runtime with openvino backend should be built before building export llama runner + elif [[ "$build_type" == "--llama_runner" ]]; then + build_llama_runner + + # If the first arguments is --enable_python, build python package with python bindings + elif [[ "$build_type" == "--enable_python" ]]; then + install_requirements + build_python_enabled else echo "Error: Argument is not valid: $build_type" diff --git a/docs/source/build-run-openvino.md b/docs/source/build-run-openvino.md index d06a6eb82c8..9b4c48fee5a 100644 --- a/docs/source/build-run-openvino.md +++ b/docs/source/build-run-openvino.md @@ -92,7 +92,7 @@ The exported model will be saved as 'resnet50.pte' in the current directory. ### Build C++ OpenVINO Examples -After building the OpenVINO backend following the [instructions](#setup) above, the executable will be saved in `/cmake-out/backends/openvino/`. +After building the OpenVINO backend following the [instructions](#setup) above, the executable will be saved in `/cmake-out/`. The executable requires a model file (`.pte` file generated in the aot step) and the number of inference executions. 
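For reference, the AOT step that produces such a `.pte` can be sketched as follows. This is a minimal sketch: `OpenvinoPartitioner` and the `"device"` `CompileSpec` come from this backend, while the ResNet-50 model and the generic `torch.export`/`to_edge_transform_and_lower` calls are illustrative assumptions (see `examples/openvino/aot_optimize_and_infer.py` for the full example script).

```python
import torch
import torchvision.models as models

from executorch.backends.openvino.partitioner import OpenvinoPartitioner
from executorch.exir import to_edge_transform_and_lower
from executorch.exir.backend.backend_details import CompileSpec

# Any torch.export-able vision model works the same way; ResNet-50 matches the AOT example above.
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224),)

exported_program = torch.export.export(model, sample_inputs)

# "device" mirrors the CompileSpec key used by the partitioner helpers in this change.
compile_specs = [CompileSpec("device", "CPU".encode())]
edge_program = to_edge_transform_and_lower(
    exported_program,
    partitioner=[OpenvinoPartitioner(compile_specs)],
)
executorch_program = edge_program.to_executorch()

with open("resnet50.pte", "wb") as f:
    f.write(executorch_program.buffer)
```

The resulting `resnet50.pte` is what `executor_runner` consumes below.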
@@ -101,7 +101,7 @@ The executable requires a model file (`.pte` file generated in the aot step) and Run inference with a given model for 10 executions: ``` -./openvino_executor_runner \ +./executor_runner \ --model_path=model.pte \ --num_executions=10 ``` diff --git a/examples/models/llama/CMakeLists.txt b/examples/models/llama/CMakeLists.txt index e7c73c0cffc..db0e38b8b74 100644 --- a/examples/models/llama/CMakeLists.txt +++ b/examples/models/llama/CMakeLists.txt @@ -189,6 +189,13 @@ if(TARGET mpsdelegate) executorch_target_link_options_shared_lib(mpsdelegate) endif() +# Openvino backend +if(TARGET openvino_backend) + find_package(OpenVINO REQUIRED) + list(APPEND link_libraries openvino_backend) + executorch_target_link_options_shared_lib(openvino_backend) +endif() + if(TARGET coremldelegate) find_library(SQLITE_LIBRARY sqlite3) list( diff --git a/examples/models/llama/README.md b/examples/models/llama/README.md index 5f7f4505c45..0d1728a0c6c 100644 --- a/examples/models/llama/README.md +++ b/examples/models/llama/README.md @@ -94,6 +94,8 @@ Llama 3.2 1B and 3B performance was measured on Android OnePlus 12 device. The p +[Please visit this section to try it on OpenVINO backend](../../openvino/llama/README.md). + ## Llama 3/3.1 8B Since Llama 3 8B model needs at least 4-bit quantization to fit even within some of the highend phones, results presented here correspond to 4-bit groupwise post-training quantized (PTQ) model. diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 3369b9bd97b..7fa9357f23b 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -36,12 +36,14 @@ from executorch.extension.llm.export.partitioner_lib import ( get_coreml_partitioner, get_mps_partitioner, + get_openvino_partitioner, get_qnn_partitioner, get_vulkan_partitioner, get_xnnpack_partitioner, ) from executorch.extension.llm.export.quantizer_lib import ( get_coreml_quantizer, + get_ov_quantizer, get_pt2e_quantization_params, get_pt2e_quantizers, get_qnn_quantizer, @@ -203,6 +205,8 @@ def build_args_parser() -> argparse.ArgumentParser: choices=[ "xnnpack_dynamic", "xnnpack_dynamic_qc4", + "openvino_4wo", + "openvino_8wo", "qnn_8a8w", "qnn_16a16w", "qnn_16a4w", @@ -471,6 +475,14 @@ def build_args_parser() -> argparse.ArgumentParser: action="store_true", help="Delegate llama2 to qnn backend (Qualcomm), please use it --kv_cahce=True", ) + parser.add_argument("--openvino", action="store_true") + parser.add_argument( + "--openvino_device", + type=str, + default="CPU", + choices=["CPU", "GPU", "NPU"], + help="Specify the device for Openvino (CPU, GPU or NPU).", + ) parser.add_argument( "--expand_rope_table", @@ -781,6 +793,14 @@ def get_quantizer_and_quant_params(llm_config): llm_config.quantization.pt2e_quantize.value, llm_config.quantization.qmode ) quantizers.append(qnn_quantizer) + if llm_config.backend.openvino.enabled and llm_config.quantization.pt2e_quantize: + assert not quantizers, "Should not enable both xnnpack and openvino" + group_size = llm_config.quantization.group_size + group_size = group_size if group_size else 128 + ov_quantizer = get_ov_quantizer( + llm_config.quantization.pt2e_quantize.value, group_size + ) + quantizers.append(ov_quantizer) if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize: assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml" coreml_quantizer = get_coreml_quantizer( @@ -887,6 +907,34 @@ def _to_edge_and_lower_llama_xnnpack( 
return builder.to_executorch(passes=additional_passes) +def _to_edge_and_lower_llama_openvino( + builder_exported, + modelname, + quantizers, + additional_passes, + openvino_device: str = "CPU", + verbose: bool = False, +) -> LLMEdgeManager: # noqa: C901 + partitioners = [] + + # Add OpenVINO partitioner + partitioners.append(get_openvino_partitioner(openvino_device)) + modelname = f"openvino_{modelname}" + + logging.info("Lowering model using following partitioner(s): ") + for partitioner in partitioners: + logging.info(f"--> {partitioner.__class__.__name__}") + + builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower( + partitioners + ) + + if verbose: + print_delegation_info(builder.edge_manager.exported_program().graph_module) + + return builder.to_executorch(passes=additional_passes) + + def _to_edge_and_lower_llama( # noqa: C901 builder_exported, modelname, @@ -1131,6 +1179,15 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 generate_etrecord=llm_config.debug.generate_etrecord, verbose=llm_config.debug.verbose, ) + elif llm_config.backend.openvino.enabled: + builder = _to_edge_and_lower_llama_openvino( + builder_exported, + modelname, + quantizers, + additional_passes, + openvino_device=llm_config.backend.openvino.device, + verbose=llm_config.debug.verbose, + ) else: builder = _to_edge_and_lower_llama( builder_exported, diff --git a/examples/openvino/README.md b/examples/openvino/README.md index 8856ccdce4e..83e3daf6849 100644 --- a/examples/openvino/README.md +++ b/examples/openvino/README.md @@ -9,7 +9,10 @@ Below is the layout of the `examples/openvino` directory, which includes the nec ``` examples/openvino ├── README.md # Documentation for examples (this file) -└── aot_optimize_and_infer.py # Example script to export and execute models +├── aot_optimize_and_infer.py # Example script to export and execute models +└── llama + ├── README.md # Documentation for Llama example + └── llama3_2_ov_4wo.yaml # Configuration file for exporting Llama3.2 with OpenVINO backend ``` # Build Instructions for Examples @@ -154,7 +157,7 @@ Build the backend libraries and executor runner by executing the script below in ```bash ./openvino_build.sh ``` -The executable is saved in `/cmake-out/backends/openvino/` +The executable is saved in `/cmake-out/` ### Run the Example with Executor Runner @@ -163,9 +166,9 @@ Now, run the example using the executable generated in the above step. The execu #### Command Syntax: ``` -cd ../../cmake-out/backends/openvino +cd ../../cmake-out -./openvino_executor_runner \ +./executor_runner \ --model_path= \ --num_executions= ``` @@ -179,7 +182,7 @@ cd ../../cmake-out/backends/openvino Run inference with a given model for 10 iterations: ``` -./openvino_executor_runner \ +./executor_runner \ --model_path=model.pte \ --num_executions=10 ``` diff --git a/examples/openvino/llama/README.md b/examples/openvino/llama/README.md new file mode 100644 index 00000000000..a98645b3918 --- /dev/null +++ b/examples/openvino/llama/README.md @@ -0,0 +1,45 @@ + +# Export Llama with OpenVINO Backend + +## Download the Model +Follow the [instructions](../../../examples/models/llama/README.md#step-2-prepare-model) to download the required model files. Export Llama with OpenVINO backend is only verified with Llama-3.2-1B variants at this time. + +## Environment Setup +Follow the [instructions](../../../backends/openvino/README.md) of **Prerequisites** and **Setup** in `backends/openvino/README.md` to set up the OpenVINO backend. 
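Before exporting, you can optionally confirm that the backend is importable from the active environment; this quick check is purely illustrative and uses only names provided by the OpenVINO backend:

```python
# Optional sanity check: these imports should succeed once the setup above is complete.
from executorch.backends.openvino.partitioner import OpenvinoPartitioner  # noqa: F401
from executorch.backends.openvino.quantizer import QuantizationMode

# Lists both the PTQ modes and the weights-only compression modes.
print([mode.value for mode in QuantizationMode])
```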
+
+## Export the Model
+Navigate into `/examples/openvino/llama` and execute the commands below to export the model. Update the model file paths to match the location where your model is downloaded, and set `backend.openvino.device` to the target hardware you want to compile the model for (`CPU`, `GPU`, or `NPU`). The exported model will be generated in the same directory with the filename `llama3_2_ov.pte`. To modify the output name, change `output_name` under `export` in the `llama3_2_ov_4wo.yaml` file.
+
+```
+LLAMA_CHECKPOINT=/consolidated.00.pth
+LLAMA_PARAMS=/params.json
+LLAMA_TOKENIZER=/tokenizer.model
+
+python -m executorch.extension.llm.export.export_llm \
+    --config llama3_2_ov_4wo.yaml \
+    +backend.openvino.device="CPU" \
+    +base.model_class="llama3_2" \
+    +base.checkpoint="${LLAMA_CHECKPOINT:?}" \
+    +base.params="${LLAMA_PARAMS:?}" \
+    +base.tokenizer_path="${LLAMA_TOKENIZER:?}"
+```
+
+### Compress Model Weights and Export
+The OpenVINO backend also offers quantization support when exporting Llama models. The available modes are INT4 group-wise or per-channel weight compression and INT8 per-channel weight compression. Select a mode by setting the `pt2e_quantize` option under `quantization` in the `llama3_2_ov_4wo.yaml` file: use `openvino_4wo` for INT4 or `openvino_8wo` for INT8 weight compression (the file defaults to `openvino_4wo`). To modify the group size, set the `group_size` option under `quantization` in the same file. By default, a group size of 128 is used to achieve optimal performance with the NPU.
+
+## Build OpenVINO C++ Runtime with Llama Runner
+First, build the backend libraries by executing the script below in the `/backends/openvino/scripts` folder:
+```bash
+./openvino_build.sh --cpp_runtime
+```
+Then, build the Llama runner by executing the script below (with the `--llama_runner` argument), also in the `/backends/openvino/scripts` folder:
+```bash
+./openvino_build.sh --llama_runner
+```
+The executable is saved at `/cmake-out/examples/models/llama/llama_main`.
+
+## Execute Inference Using Llama Runner
+Update the tokenizer file path to match the location where your model is downloaded, and replace the prompt with your own. 
+``` +./cmake-out/examples/models/llama/llama_main --model_path=/examples/openvino/llama/llama3_2.pte --tokenizer_path=/tokenizer.model --prompt="Your custom prompt" +``` diff --git a/examples/openvino/llama/llama3_2_ov_4wo.yaml b/examples/openvino/llama/llama3_2_ov_4wo.yaml new file mode 100644 index 00000000000..8fb1d7a1c09 --- /dev/null +++ b/examples/openvino/llama/llama3_2_ov_4wo.yaml @@ -0,0 +1,21 @@ +base: + metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' + +model: + use_kv_cache: true + dtype_override: fp32 + enable_dynamic_shape: false + +quantization: + pt2e_quantize: "openvino_4wo" + group_size: 128 + +export: + output_name: "llama3_2_ov.pte" + +backend: + openvino: + enabled: true + +debug: + verbose: false diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index f15aad9e000..223e5335994 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -281,6 +281,8 @@ class Pt2eQuantize(str, Enum): xnnpack_dynamic = "xnnpack_dynamic" xnnpack_dynamic_qc4 = "xnnpack_dynamic_qc4" + openvino_4wo = "openvino_4wo" + openvino_8wo = "openvino_8wo" qnn_8a8w = "qnn_8a8w" qnn_16a16w = "qnn_16a16w" qnn_16a4w = "qnn_16a4w" @@ -454,6 +456,18 @@ class MPSConfig: enabled: bool = False +@dataclass +class OpenvinoConfig: + """ + Configures the QNN backend. + """ + + enabled: bool = False + device: str = "CPU" + nncf_compression: bool = False + nncf_compression_group_size: int = 32 + + @dataclass class TorchAOKernelsConfig: """ @@ -476,6 +490,7 @@ class BackendConfig: vulkan: VulkanConfig = field(default_factory=VulkanConfig) qnn: QNNConfig = field(default_factory=QNNConfig) mps: MPSConfig = field(default_factory=MPSConfig) + openvino: OpenvinoConfig = field(default_factory=OpenvinoConfig) torchao: TorchAOKernelsConfig = field(default_factory=TorchAOKernelsConfig) @@ -647,6 +662,16 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901 if hasattr(args, "mps"): llm_config.backend.mps.enabled = args.mps + # Openvino + if hasattr(args, "openvino"): + llm_config.backend.openvino.enabled = args.openvino + if hasattr(args, "openvino_device"): + llm_config.backend.openvino.device = args.openvino_device + if hasattr(args, "nncf_compression"): + llm_config.backend.openvino.nncf_compression = args.nncf_compression + if hasattr(args, "group_size") and args.group_size: + llm_config.backend.openvino.nncf_compression_group_size = args.group_size + # TorchAoKernels if any( hasattr(args, a) diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py index 5fe220f7dd9..03ac2bd91e4 100644 --- a/extension/llm/export/partitioner_lib.py +++ b/extension/llm/export/partitioner_lib.py @@ -68,6 +68,19 @@ def get_mps_partitioner(use_kv_cache: bool = False): return MPSPartitioner(compile_specs) # pyre-fixme[16] +def get_openvino_partitioner(device: str): + try: + from executorch.backends.openvino.partitioner import OpenvinoPartitioner + from executorch.exir.backend.backend_details import CompileSpec + except ImportError: + raise ImportError( + "Please install the OpenVINO backend following https://github.com/pytorch/executorch/tree/main/backends/openvino" + ) + + compile_specs = [CompileSpec("device", device.encode())] + return OpenvinoPartitioner(compile_specs) + + def get_coreml_partitioner( ios: int = 15, embedding_quantize: Optional[str] = None, diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index 
2d87c86d113..592a6666dfa 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -215,6 +215,47 @@ def get_qnn_quantizer( return qnn_quantizer, quant_dtype +def get_ov_quantizer( + pt2e_quantize: str, + group_size: int = 128, +): + try: + from executorch.backends.openvino.quantizer import ( + OpenVINOQuantizer, + QuantizationMode, + ) + except ImportError: + raise ImportError("Please install nncf via backends/openvino/requirements.txt") + + backend, quant_config = pt2e_quantize.split("_") + assert ( + backend == "openvino" + ), f"The quantization config is for backend {backend} instead of openvino." + assert ( + group_size + ), "Group Size None is Not Supported. It should be set to -1 for per-channel." + + quantization_params = {} + + if quant_config == "4wo": + quantization_params["mode"] = QuantizationMode.INT4WO_SYM + quantization_params["group_size"] = group_size + quantization_params["ratio"] = 1 + + elif quant_config == "8wo": + quantization_params["mode"] = QuantizationMode.INT8WO_ASYM + quantization_params["group_size"] = -1 + quantization_params["ratio"] = None + + else: + raise AssertionError( + f"No support for quant type {quant_config}. Support 8a4w, 8a8w only." + ) + ov_quantizer = OpenVINOQuantizer(**quantization_params) + + return ov_quantizer + + def get_coreml_quantizer(pt2e_quantize: str): try: from coremltools.optimize.torch.quantization.quantization_config import ( diff --git a/tools/cmake/executorch-config.cmake b/tools/cmake/executorch-config.cmake index 3df8e947459..ba18aede63e 100644 --- a/tools/cmake/executorch-config.cmake +++ b/tools/cmake/executorch-config.cmake @@ -85,6 +85,7 @@ set(optional_lib_list quantized_kernels quantized_ops_lib quantized_ops_aot_lib + openvino_backend torchao_ops_executorch torchao_kernels_aarch64 )
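The weights-only compression modes introduced above can also be driven directly, outside of `export_llm`. The sketch below is a minimal, hypothetical example: `OpenVINOQuantizer`, `QuantizationMode.INT4WO_SYM`, and the `group_size`/`ratio` keyword arguments come from this change, while the `torch.export` capture and the `prepare_pt2e`/`convert_pt2e` entry points are assumed from torchao's PT2E flow (the exact import path may vary with the torchao version).

```python
import torch

from executorch.backends.openvino.quantizer import OpenVINOQuantizer, QuantizationMode
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

model = torch.nn.Sequential(torch.nn.Linear(128, 256), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 128),)

# INT4 symmetric weights-only compression; group_size and ratio are forwarded to NNCF,
# matching what get_ov_quantizer passes for the "openvino_4wo" configuration.
quantizer = OpenVINOQuantizer(mode=QuantizationMode.INT4WO_SYM, group_size=128, ratio=1)

captured = torch.export.export(model, example_inputs).module()
prepared = prepare_pt2e(captured, quantizer)
prepared(*example_inputs)  # weights-only compression is data-free; one pass keeps the flow uniform
quantized = convert_pt2e(prepared)
```

After conversion, the graph should carry the OpenVINO-specific decompression modules inserted by the new observers and can then be lowered with `OpenvinoPartitioner` in the usual way.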