diff --git a/backends/openvino/__init__.py b/backends/openvino/__init__.py index dac275d3f12..4a69f6b75ff 100644 --- a/backends/openvino/__init__.py +++ b/backends/openvino/__init__.py @@ -1,4 +1,5 @@ from .partitioner import OpenvinoPartitioner from .preprocess import OpenvinoBackend +from .quantizer.quantizer import OpenVINOQuantizer -__all__ = [OpenvinoBackend, OpenvinoPartitioner] +__all__ = [OpenvinoBackend, OpenvinoPartitioner, OpenVINOQuantizer] diff --git a/backends/openvino/quantizer/__init__.py b/backends/openvino/quantizer/__init__.py new file mode 100644 index 00000000000..03ea98e2c5b --- /dev/null +++ b/backends/openvino/quantizer/__init__.py @@ -0,0 +1,3 @@ +from .quantizer import OpenVINOQuantizer + +__all__ = [OpenVINOQuantizer] diff --git a/backends/openvino/quantizer/quantizer.py b/backends/openvino/quantizer/quantizer.py new file mode 100644 index 00000000000..480faeee635 --- /dev/null +++ b/backends/openvino/quantizer/quantizer.py @@ -0,0 +1,318 @@ +# Copyright (c) Intel Corporation +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. See the license file in the root +# directory of this source tree for more details. + +from collections import defaultdict +from enum import Enum +from typing import Dict, List, Optional, Tuple + +import torch.fx +from torch.ao.quantization.observer import HistogramObserver +from torch.ao.quantization.observer import PerChannelMinMaxObserver +from torch.ao.quantization.quantizer.quantizer import EdgeOrNode +from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation +from torch.ao.quantization.quantizer.quantizer import QuantizationSpec +from torch.ao.quantization.quantizer.quantizer import QuantizationSpecBase +from torch.ao.quantization.quantizer.quantizer import Quantizer +from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec + +import nncf +import nncf.common.quantization as quantization +import nncf.experimental.torch.fx as nncf_fx +from nncf.common.graph.graph import NNCFGraph + +QUANT_ANNOTATION_KEY = "quantization_annotation" + + +class QuantizationMode(Enum): + """ + Defines special quantization modes. + + - INT8_SYM: INT8 symmetric quantization for both activations and weights. + - INT8_MIXED: INT8 asymmetric quantization for activations, symmetric for weights. + - INT8_TRANSFORMER: Optimized INT8 quantization for transformer-based models + """ + + INT8_SYM = "int8_sym" + INT8_MIXED = "int8_mixed" + INT8_TRANSFORMER = "int8_transformer" + + +class OpenVINOQuantizer(Quantizer): + """ + Implementation of the Torch AO quantizer which annotates models with quantization annotations + optimally for the inference via OpenVINO. + """ + + def __init__( + self, + *, + mode: Optional[QuantizationMode] = QuantizationMode.INT8_SYM, + **kwargs, + ): + """ + :param mode: Defines special quantization modes. + - INT8_SYM: INT8 symmetric quantization for both activations and weights. + - INT8_MIXED: INT8 asymmetric quantization for activations, symmetric for weights. + - INT8_TRANSFORMER: Optimized INT8 quantization for transformer-based models + Default value is INT8_SYM. + :param kwargs: Arguments to pass to the NNCF MinMaxQuantization algorithm. 
+ """ + if mode == QuantizationMode.INT8_SYM: + preset = quantization.structs.QuantizationPreset.PERFORMANCE + model_type = None + elif mode == QuantizationMode.INT8_MIXED: + preset = quantization.structs.QuantizationPreset.MIXED + model_type = None + else: + preset = None + model_type = nncf.parameters.ModelType.TRANSFORMER + self._min_max_algo = nncf.quantization.algorithms.min_max.algorithm.MinMaxQuantization( + preset=preset, model_type=model_type, **kwargs + ) + + def set_ignored_scope( + self, + names: Optional[List[str]] = None, + patterns: Optional[List[str]] = None, + types: Optional[List[str]] = None, + subgraphs: Optional[List[Tuple[List[str], List[str]]]] = None, + validate: bool = True, + ) -> None: + """ + Provides an option to specify portions of model to be excluded from compression. + The ignored scope defines model sub-graphs that should be excluded from the quantization process. + + :param names: List of ignored node names. + :param patterns: List of regular expressions that define patterns for names of ignored nodes. + :param types: List of ignored operation types. + :param subgraphs: List of ignored subgraphs. + :param validate: If set to True, then a RuntimeError will be raised if any ignored scope does not match + in the model graph. + """ + self._min_max_algo.set_ignored_scope( + nncf.IgnoredScope( + names=names or [], + patterns=patterns or [], + types=types or [], + subgraphs=subgraphs or [], + validate=validate, + ) + ) + + def get_nncf_quantization_setup( + self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph + ) -> quantization.quantizer_setup.SingleConfigQuantizerSetup: + self._min_max_algo._set_backend_entity(model) + return self._min_max_algo.find_quantization_setup(model, nncf_graph) + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + nncf_graph = nncf_fx.nncf_graph_builder.GraphConverter.create_nncf_graph(model) + quantization_setup = self.get_nncf_quantization_setup(model, nncf_graph) + + graph = model.graph + node_vs_torch_annotation = defaultdict(QuantizationAnnotation) + + for qp in quantization_setup.quantization_points.values(): + edge_or_node, annotation = self._get_edge_or_node_and_annotation( + graph, nncf_graph, qp, node_vs_torch_annotation + ) + qspec = self._get_torch_ao_qspec_from_qp(qp) + self._fill_torch_ao_annotation(edge_or_node, qspec, annotation) + + for quantizer_ids in quantization_setup.unified_scale_groups.values(): + + root_quantizer_id = self._get_unified_scales_root_quantizer_id( + nncf_graph, quantizer_ids, quantization_setup + ) + root_qp = quantization_setup.quantization_points[root_quantizer_id] + + if any(root_qp.qconfig != quantization_setup.quantization_points[q_id].qconfig for q_id in quantizer_ids): + qps = [quantization_setup.quantization_points[q_id] for q_id in quantizer_ids] + msg = ( + "Different quantization configs are set to one unified scale group:" + f"{[(qp.insertion_point.__dict__, str(qp.qconfig)) for qp in qps]}" + ) + raise nncf.InternalError(msg) + + root_target_node = nncf_fx.node_utils.get_graph_node_by_name( + graph, root_qp.insertion_point.target_node_name + ) + root_edge_or_node = self._get_edge_or_node(root_target_node, root_qp, nncf_graph) + + for quantizer_id in quantizer_ids: + if quantizer_id == root_quantizer_id: + continue + + qspec = SharedQuantizationSpec(root_edge_or_node) + qp = quantization_setup.quantization_points[quantizer_id] + edge_or_node, annotation = self._get_edge_or_node_and_annotation( + graph, nncf_graph, qp, node_vs_torch_annotation + ) + 
self._fill_torch_ao_annotation(edge_or_node, qspec, annotation) + + for node, annotation in node_vs_torch_annotation.items(): + assert QUANT_ANNOTATION_KEY not in node.meta + node.meta[QUANT_ANNOTATION_KEY] = annotation + return model + + @staticmethod + def _get_unified_scales_root_quantizer_id( + nncf_graph: NNCFGraph, + quantizer_ids: List[int], + quantizer_setup: quantization.quantizer_setup.SingleConfigQuantizerSetup, + ) -> int: + """ + Identifies the earliest quantizer node ID based on the corresponding `nncf_node.node_id` + in the given NNCFGraph. This is required by the `_get_obs_or_fq_map` function. + Refer to: https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/pt2e/prepare.py#L291 + + :param nncf_graph: The NNCFGraph instance. + :param quantizer_ids: The list of quantizer IDs to evaluate. + :param quantizer_setup: The instance of SingleConfigQuantizerSetup. + :return: The ID of the earliest quantizer node in terms of `nncf_node.node_id`. + """ + nncf_node_quantizer_id = None + root_quantizer_id = None + for quantizer_id in quantizer_ids: + target_node_name = quantizer_setup.quantization_points[quantizer_id].insertion_point.target_node_name + nncf_node = nncf_graph.get_node_by_name(target_node_name) + if nncf_node_quantizer_id is None or nncf_node.node_id < nncf_node_quantizer_id: + root_quantizer_id = quantizer_id + nncf_node_quantizer_id = nncf_node.node_id + return root_quantizer_id + + @staticmethod + def _get_edge_or_node_and_annotation( + graph: torch.fx.Graph, + nncf_graph: NNCFGraph, + qp: quantization.quantizer_setup.QuantizationPointBase, + node_vs_torch_annotation: Dict[torch.fx.Node, QuantizationAnnotation], + ) -> Tuple[EdgeOrNode, QuantizationAnnotation]: + """ + Retrieves the edge or node and its corresponding QuantizationAnnotation based on the given graph, + quantization point, and node-to-annotation mapping. + + :param graph: torch.fx.Graph instance. + :param nncf_graph: NNCFGraph instance. + :param qp: QuantizationPointBase instance. + :param node_vs_torch_annotation: A dictionary mapping torch.fx.GraphNode objects to their respective + QuantizationAnnotations. + :return: A tuple containing the EdgeOrNode and its associated QuantizationAnnotation. + """ + target_node = nncf_fx.node_utils.get_graph_node_by_name(graph, qp.insertion_point.target_node_name) + annotation = node_vs_torch_annotation[target_node] + edge_or_node = OpenVINOQuantizer._get_edge_or_node(target_node, qp, nncf_graph) + return edge_or_node, annotation + + @staticmethod + def _get_edge_or_node( + target_node: torch.fx.Node, qp: quantization.quantizer_setup.QuantizationPointBase, nncf_graph: NNCFGraph + ) -> EdgeOrNode: + """ + Returns the edge or node based on the given target node and quantization point. + + :param target_node: Target node instance. + :param qp: QuantizationPointBase instance. + :param graph: NNCFGraph instance. + :return: The corresponding EdgeOrNode derived from the target node and quantization point. + """ + ip = qp.insertion_point + if qp.is_weight_quantization_point(): + nncf_node = nncf_graph.get_node_by_name(target_node.name) + weights_ports_ids = nncf.torch.model_graph_manager.get_weight_tensor_port_ids(nncf_node, nncf_graph) + if len(weights_ports_ids) > 1: + # TODO(dlyakhov): support quantization for nodes with several weights + nncf.common.logging.nncf_logger.warning( + f"Quantization of the weighted node {target_node.name}" + " is not yet supported by the OpenVINOQuantizer." + f" Only the weight on port ID {weights_ports_ids[0]} will be quantized." 
+ f" Quantizable weights are located on ports: {weights_ports_ids}." + ) + weight_node = target_node.all_input_nodes[weights_ports_ids[0]] + return (weight_node, target_node) + + if ip.input_port_id is None: + return target_node + + node = target_node.all_input_nodes[ip.input_port_id] + return (node, target_node) + + @staticmethod + def _fill_torch_ao_annotation( + edge_or_node: EdgeOrNode, + qspec: QuantizationSpecBase, + annotation_to_update: QuantizationAnnotation, + ) -> None: + """ + Helper method to update the annotation_to_update based on the specified edge_or_node and qspec. + + :param edge_or_node: The target EdgeOrNode to be used for the update. + :param qspec: An instance of QuantizationSpecBase representing the quantization specification to apply. + :param annotation_to_update: The annotation to update based on the edge_or_node and qspec. + """ + if isinstance(edge_or_node, torch.fx.Node): + annotation_to_update.output_qspec = qspec + else: + annotation_to_update.input_qspec_map[edge_or_node[0]] = qspec + + @staticmethod + def _get_torch_ao_qspec_from_qp(qp: quantization.quantizer_setup.QuantizationPointBase) -> QuantizationSpec: + """ + Retrieves the quantization configuration from the given quantization point and + converts it into a QuantizationSpec. + + :param qp: An instance of QuantizationPointBase. + :return: A QuantizationSpec retrieved and converted from the quantization point. + """ + # Eps value is copied from nncf/torch/quantization/layers.py + extra_args = {"eps": 1e-16} + qconfig = qp.qconfig + is_weight = qp.is_weight_quantization_point() + + if qconfig.per_channel: + torch_qscheme = ( + torch.per_channel_symmetric + if qconfig.mode is quantization.structs.QuantizationScheme.SYMMETRIC + else torch.per_channel_affine + ) + else: + torch_qscheme = ( + torch.per_tensor_symmetric + if qconfig.mode is quantization.structs.QuantizationScheme.SYMMETRIC + else torch.per_tensor_affine + ) + if is_weight: + observer = PerChannelMinMaxObserver + quant_min = -128 + quant_max = 127 + dtype = torch.int8 + channel_axis = 0 + else: + observer = ( + HistogramObserver + if torch_qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine] + else PerChannelMinMaxObserver + ) + quant_min = 0 + quant_max = 255 + dtype = torch.int8 if qconfig.signedness_to_force else torch.uint8 + channel_axis = 1 # channel dim for activations + return QuantizationSpec( + dtype=dtype, + observer_or_fake_quant_ctr=observer.with_args(**extra_args), + quant_min=quant_min, + quant_max=quant_max, + qscheme=torch_qscheme, + ch_axis=channel_axis, + is_dynamic=False, + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + def transform_for_annotation(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + nncf_fx.transformations.fold_constant_except_qdq(model) + return model diff --git a/backends/openvino/requirements.txt b/backends/openvino/requirements.txt index 7c3de886e27..f00257127a3 100644 --- a/backends/openvino/requirements.txt +++ b/backends/openvino/requirements.txt @@ -6,3 +6,4 @@ tokenizers transformers piq pillow +nncf @ https://github.com/openvinotoolkit/nncf.git diff --git a/examples/openvino/CMakeLists.txt b/examples/openvino/CMakeLists.txt index 4a1917fa3af..10638a7b5f7 100644 --- a/examples/openvino/CMakeLists.txt +++ b/examples/openvino/CMakeLists.txt @@ -55,6 +55,7 @@ target_include_directories(openvino_portable_ops_lib PUBLIC ${_common_include_di # Build Executor Runner add_executable(openvino_executor_runner ${_openvino_executor_runner__srcs}) + 
target_include_directories( openvino_executor_runner PUBLIC ${_common_include_directories} ${EXECUTORCH_ROOT}/cmake-openvino-out/third-party/gflags/include ) diff --git a/examples/openvino/aot/README.md b/examples/openvino/aot/README.md index 6c59f1dad41..884ed55849f 100644 --- a/examples/openvino/aot/README.md +++ b/examples/openvino/aot/README.md @@ -11,30 +11,50 @@ python aot_openvino_compiler.py --suite --model --inp ``` ### **Arguments** -- **`--suite`** (required): - Specifies the model suite to use. +- **`--suite`** (required): + Specifies the model suite to use. Supported values: - `timm` (e.g., VGG16, ResNet50) - `torchvision` (e.g., resnet18, mobilenet_v2) - - `huggingface` (e.g., bert-base-uncased) + - `huggingface` (e.g., bert-base-uncased). NB: Quantization and validation are not supported yet. -- **`--model`** (required): - Name of the model to export. +- **`--model`** (required): + Name of the model to export. Examples: - For `timm`: `vgg16`, `resnet50` - For `torchvision`: `resnet18`, `mobilenet_v2` - For `huggingface`: `bert-base-uncased`, `distilbert-base-uncased` -- **`--input_shape`** (required): - Input shape for the model. Provide this as a **list** or **tuple**. +- **`--input_shape`** (optional): + Input shape for the model. Provide this as a **list** or **tuple**. Examples: - `[1, 3, 224, 224]` (Zsh users: wrap in quotes) - `(1, 3, 224, 224)` -- **`--device`** (optional): - Target device for the compiled model. Default is `CPU`. +- **`--batch_size`** (optional): + Batch size for validation. Default is `1`. + The dataset length must be evenly divisible by the batch size. + +- **`--quantize`** (optional): + Enable model quantization. The `--dataset` argument is required for quantization. The `huggingface` suite is not supported yet. + +- **`--quantization_flow`** (optional): + Specifies how to quantize the torch.fx.GraphModule. + Supported values: + - `nncf`: the NNCF `quantize_pt2e` API (default) + - `pt2e`: the default Torch AO quantization pipeline. + +- **`--validate`** (optional): + Enable model validation. The `--dataset` argument is required for validation. The `huggingface` suite is not supported yet. + +- **`--dataset`** (optional): + Path to an ImageNet-like calibration dataset. + +- **`--device`** (optional): + Target device for the compiled model. Default is `CPU`. Examples: `CPU`, `GPU` + ## **Examples** ### Export a TIMM VGG16 model for the CPU @@ -51,22 +71,31 @@ python aot_openvino_compiler.py --suite torchvision --model resnet50 --input_sha ```bash python aot_openvino_compiler.py --suite huggingface --model bert-base-uncased --input_shape "(1, 512)" --device CPU ``` +### Export and validate a TIMM VGG16 model for the CPU +```bash +python aot_openvino_compiler.py --suite timm --model vgg16 --input_shape [1, 3, 224, 224] --device CPU --validate --dataset /path/to/dataset +``` + +### Export, quantize and validate a TIMM VGG16 model for the CPU +```bash +python aot_openvino_compiler.py --suite timm --model vgg16 --input_shape [1, 3, 224, 224] --device CPU --validate --dataset /path/to/dataset --quantize +``` ## **Notes** -1. **Input Shape in Zsh**: +1. **Input Shape in Zsh**: If you are using Zsh, wrap `--input_shape` in quotes or use a tuple: ```bash --input_shape '[1, 3, 224, 224]' --input_shape "(1, 3, 224, 224)" ``` -2. **Model Compatibility**: +2. **Model Compatibility**: Ensure the specified `model_name` exists in the selected `suite`. Use the corresponding library's documentation to verify model availability. -3. **Output File**: +3. 
**Output File**: The exported model will be saved as `.pte` in the current directory. -4. **Dependencies**: +4. **Dependencies**: - Python 3.8+ - PyTorch - Executorch @@ -75,14 +104,14 @@ python aot_openvino_compiler.py --suite huggingface --model bert-base-uncased -- - Transformers (`pip install transformers`) ## **Error Handling** -- **Model Not Found**: +- **Model Not Found**: If the script raises an error such as: ```bash ValueError: Model not found ``` Verify that the model name is correct for the chosen suite. -- **Unsupported Input Shape**: +- **Unsupported Input Shape**: Ensure `--input_shape` is provided as a valid list or tuple. diff --git a/examples/openvino/aot/aot_openvino_compiler.py b/examples/openvino/aot/aot_openvino_compiler.py index 4674fbbd755..f0844289580 100644 --- a/examples/openvino/aot/aot_openvino_compiler.py +++ b/examples/openvino/aot/aot_openvino_compiler.py @@ -4,18 +4,36 @@ # except in compliance with the License. See the license file in the root # directory of this source tree for more details. +import argparse +import os +import shutil +import subprocess +from itertools import islice +from pathlib import Path + import executorch +import numpy as np import timm import torch +import torchvision.datasets as datasets import torchvision.models as torchvision_models -from transformers import AutoModel -from executorch.exir.backend.backend_details import CompileSpec -from executorch.backends.openvino.preprocess import OpenvinoBackend +from executorch.backends.openvino import OpenVINOQuantizer from executorch.backends.openvino.partitioner import OpenvinoPartitioner -from executorch.exir import EdgeProgramManager, to_edge -from torch.export import export, ExportedProgram +from executorch.exir import EdgeProgramManager +from executorch.exir import to_edge +from executorch.exir.backend.backend_details import CompileSpec +from sklearn.metrics import accuracy_score +from timm.data import resolve_data_config +from timm.data.transforms_factory import create_transform +from torch.ao.quantization.quantize_pt2e import convert_pt2e +from torch.ao.quantization.quantize_pt2e import prepare_pt2e +from torch.export import export from torch.export.exported_program import ExportedProgram -import argparse +from transformers import AutoModel + +import nncf +from nncf.experimental.torch.fx.quantization.quantize_pt2e import quantize_pt2e + # Function to load a model based on the selected suite def load_model(suite: str, model_name: str): @@ -23,30 +41,166 @@ def load_model(suite: str, model_name: str): return timm.create_model(model_name, pretrained=True) elif suite == "torchvision": if not hasattr(torchvision_models, model_name): - raise ValueError(f"Model {model_name} not found in torchvision.") + msg = f"Model {model_name} not found in torchvision." 
+ raise ValueError(msg) return getattr(torchvision_models, model_name)(pretrained=True) elif suite == "huggingface": return AutoModel.from_pretrained(model_name) else: - raise ValueError(f"Unsupported model suite: {suite}") + msg = f"Unsupported model suite: {suite}" + raise ValueError(msg) -def main(suite: str, model_name: str, input_shape, device: str): - # Ensure input_shape is a tuple - if isinstance(input_shape, list): - input_shape = tuple(input_shape) - elif not isinstance(input_shape, tuple): - raise ValueError("Input shape must be a list or tuple.") +def load_calibration_dataset(dataset_path: str, batch_size: int, suite: str, model: torch.nn.Module, model_name: str): + val_dir = f"{dataset_path}/val" + + if suite == "torchvision": + transform = torchvision_models.get_model_weights(model_name).DEFAULT.transforms() + elif suite == "timm": + transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model)) + else: + msg = f"Validation is not supported yet for the suite {suite}" + raise ValueError(msg) + + val_dataset = datasets.ImageFolder(val_dir, transform=transform) + + calibration_dataset = torch.utils.data.DataLoader( + val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True + ) + + return calibration_dataset + + +def dump_inputs(calibration_dataset, dest_path): + input_files, targets = [], [] + for idx, data in enumerate(calibration_dataset): + feature, target = data + targets.extend(target) + file_name = f"{dest_path}/input_{idx}_0.raw" + if not isinstance(feature, torch.Tensor): + feature = torch.tensor(feature) + feature.detach().numpy().tofile(file_name) + input_files.append(file_name) + + return input_files, targets + + +def quantize_model( + captured_model: torch.fx.GraphModule, calibration_dataset: torch.utils.data.DataLoader, use_nncf: bool +) -> torch.fx.GraphModule: + quantizer = OpenVINOQuantizer() + + print("PTQ: Quantize the model") + default_subset_size = 300 + batch_size = calibration_dataset.batch_size + subset_size = (default_subset_size // batch_size) + int(default_subset_size % batch_size > 0) + + def transform(x): + return x[0] + + if use_nncf: + + quantized_model = quantize_pt2e( + captured_model, + quantizer, + subset_size=subset_size, + calibration_dataset=nncf.Dataset(calibration_dataset, transform_func=transform), + fold_quantize=False, + ) + else: + annotated_model = prepare_pt2e(captured_model, quantizer) + + print("PTQ: Calibrate the model...") + for data in islice(calibration_dataset, subset_size): + annotated_model(transform(data)) + + print("PTQ: Convert the quantized model...") + quantized_model = convert_pt2e(annotated_model, fold_quantize=False) + + return quantized_model + + +def validate_model(model_file_name: str, calibration_dataset: torch.utils.data.DataLoader) -> float: + # 1: Dump inputs + dest_path = Path("tmp_inputs") + out_path = Path("tmp_outputs") + for d in [dest_path, out_path]: + if os.path.exists(d): + shutil.rmtree(d) + os.makedirs(d) + + input_files, targets = dump_inputs(calibration_dataset, dest_path) + inp_list_file = dest_path / "in_list.txt" + with open(inp_list_file, "w") as f: + f.write("\n".join(input_files) + "\n") + + # 2: Run the executor + print("Run openvino_executor_runner...") + + subprocess.run( + [ + "../../../cmake-openvino-out/examples/openvino/openvino_executor_runner", + f"--model_path={model_file_name}", + f"--input_list_path={inp_list_file}", + f"--output_folder_path={out_path}", + ] + ) + + # 3: load the outputs and compare with the targets + predictions = 
[] + for i in range(len(input_files)): + tensor = np.fromfile(out_path / f"output_{i}_0.raw", dtype=np.float32) + predictions.extend(torch.tensor(tensor).reshape(-1, 1000).argmax(-1)) + + return accuracy_score(predictions, targets) + + +def main( + suite: str, + model_name: str, + input_shape, + quantize: bool, + validate: bool, + dataset_path: str, + device: str, + batch_size: int, + quantization_flow: str, +): # Load the selected model model = load_model(suite, model_name) model = model.eval() + if dataset_path: + calibration_dataset = load_calibration_dataset(dataset_path, batch_size, suite, model, model_name) + input_shape = tuple(next(iter(calibration_dataset))[0].shape) + print(f"Input shape retrieved from the calibration dataset: {input_shape}") + # Ensure input_shape is a tuple + elif isinstance(input_shape, (list, tuple)): + input_shape = tuple(input_shape) + else: + msg = "Input shape must be a list or tuple." + raise ValueError(msg) # Provide input - example_args = (torch.randn(*input_shape), ) + example_args = (torch.randn(*input_shape),) - # Export to aten dialect using torch.export + # Export the model to the aten dialect aten_dialect: ExportedProgram = export(model, example_args) + if quantize: + if suite == "huggingface": + msg = f"Quantization of {suite} models is not supported yet." + raise ValueError(msg) + + # Quantize model + if not dataset_path: + msg = "Quantization requires a calibration dataset." + raise ValueError(msg) + quantized_model = quantize_model( + aten_dialect.module(), calibration_dataset, use_nncf=quantization_flow == "nncf" + ) + + aten_dialect: ExportedProgram = export(quantized_model, example_args) + # Convert to edge dialect edge_program: EdgeProgramManager = to_edge(aten_dialect) to_be_lowered_module = edge_program.exported_program() @@ -59,22 +213,84 @@ def main(suite: str, model_name: str, input_shape, device: str): exec_prog = lowered_module.to_executorch(config=executorch.exir.ExecutorchBackendConfig()) # Serialize and save it to a file - with open(f"{model_name}.pte", "wb") as file: + model_file_name = f"{model_name}_{'int8' if quantize else 'fp32'}.pte" + with open(model_file_name, "wb") as file: exec_prog.write_to_file(file) - print(f"Model exported and saved as {model_name}.pte on {device}.") + print(f"Model exported and saved as {model_file_name} on {device}.") + + if validate: + if suite == "huggingface": + msg = f"Validation of {suite} models is not supported yet." + raise ValueError(msg) + + if not dataset_path: + msg = "Validation requires a calibration dataset."
+ raise ValueError(msg) + + print("Start validation of the model:") + acc_top1 = validate_model(model_file_name, calibration_dataset) + print(f"acc@1: {acc_top1}") + if __name__ == "__main__": # Argument parser for dynamic inputs parser = argparse.ArgumentParser(description="Export models with executorch.") - parser.add_argument("--suite", type=str, required=True, choices=["timm", "torchvision", "huggingface"], - help="Select the model suite (timm, torchvision, huggingface).") + parser.add_argument( + "--suite", + type=str, + required=True, + choices=["timm", "torchvision", "huggingface"], + help="Select the model suite (timm, torchvision, huggingface).", + ) parser.add_argument("--model", type=str, required=True, help="Model name to be loaded.") - parser.add_argument("--input_shape", type=eval, required=True, - help="Input shape for the model as a list or tuple (e.g., [1, 3, 224, 224] or (1, 3, 224, 224)).") - parser.add_argument("--device", type=str, default="CPU", - help="Target device for compiling the model (e.g., CPU, GPU). Default is CPU.") + parser.add_argument( + "--input_shape", + type=eval, + help="Input shape for the model as a list or tuple (e.g., [1, 3, 224, 224] or (1, 3, 224, 224)).", + ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size for the validation. Default batch_size == 1." + " The dataset length must be evenly divisible by the batch size.", + ) + parser.add_argument("--quantize", action="store_true", help="Enable model quantization.") + parser.add_argument( + "--validate", + action="store_true", + help="Enable model validation. --dataset argument is required for the validation.", + ) + parser.add_argument("--dataset", type=str, help="Path to the validation dataset.") + parser.add_argument( + "--device", + type=str, + default="CPU", + help="Target device for compiling the model (e.g., CPU, GPU). Default is CPU.", + ) + parser.add_argument( + "--quantization_flow", + type=str, + choices=["pt2e", "nncf"], + default="nncf", + help="Select the quantization flow (nncf or pt2e):" + " pt2e is the default torch.ao quantization flow, while" + " nncf is a custom method with additional algorithms to improve model performance.", + ) args = parser.parse_args() # Run the main function with parsed arguments - main(args.suite, args.model, args.input_shape, args.device) + # Disable nncf patching as export of the patched model is not supported. 
+ with nncf.torch.disable_patching(): + main( + args.suite, + args.model, + args.input_shape, + args.quantize, + args.validate, + args.dataset, + args.device, + args.batch_size, + args.quantization_flow, + ) diff --git a/examples/openvino/executor_runner/openvino_executor_runner.cpp b/examples/openvino/executor_runner/openvino_executor_runner.cpp index 7615b63649a..c3922c793a3 100644 --- a/examples/openvino/executor_runner/openvino_executor_runner.cpp +++ b/examples/openvino/executor_runner/openvino_executor_runner.cpp @@ -9,8 +9,10 @@ #include #include #include +#include #include #include +#include #include @@ -25,22 +27,16 @@ // Define a fixed-size memory pool for the method allocator (4 MB) static uint8_t method_allocator_pool[4 * 1024U * 1024U]; // 4 MB -// Define command-line flags for model path, the number of iterations, input list path, and output folder path +// Define command-line flags for model path, the number of iterations, input +// list path, and output folder path +DEFINE_string(model_path, "", + "Path to the model serialized in flatbuffer format (required)."); +DEFINE_int32(num_iter, 1, "Number of inference iterations (default is 1)."); +DEFINE_string(input_list_path, "", + "Path to the input list file which includes the list of raw " + "input tensor files (optional)."); DEFINE_string( - model_path, - "", - "Path to the model serialized in flatbuffer format (required)."); -DEFINE_int32( - num_iter, - 1, - "Number of inference iterations (default is 1)."); -DEFINE_string( - input_list_path, - "", - "Path to the input list file which includes the list of raw input tensor files (optional)."); -DEFINE_string( - output_folder_path, - "", + output_folder_path, "", "Path to the output folder to save raw output tensor files (optional)."); using executorch::extension::FileDataLoader; @@ -57,7 +53,87 @@ using executorch::runtime::Result; using executorch::runtime::Span; using executorch::runtime::TensorInfo; -int main(int argc, char** argv) { +std::function build_set_input_tensor( + Result &method, std::vector &inputs, + const std::vector> input_paths) { + return [&inputs, &method, input_paths](size_t idx) -> void { + const MethodMeta method_meta = method->method_meta(); + for (int input_index = 0; input_index < method->inputs_size(); + ++input_index) { + + Result tensor_meta = + method_meta.input_tensor_meta(input_index); + auto input_data_ptr = inputs[input_index].toTensor().data_ptr(); + + std::ifstream fin(input_paths[idx][input_index], std::ios::binary); + fin.seekg(0, fin.end); + size_t file_size = fin.tellg(); + + ET_CHECK_MSG( + file_size == tensor_meta->nbytes(), + "Input(%d) size mismatch. 
file bytes: %zu, tensor bytes: %zu", + input_index, file_size, tensor_meta->nbytes()); + + fin.seekg(0, fin.beg); + fin.read(static_cast(input_data_ptr), file_size); + fin.close(); + } + }; +} + +std::function +build_dump_outputs(std::vector &outputs, const size_t output_size, + const std::string output_folder_path) { + return [&outputs, output_folder_path, output_size](size_t idx) -> void { + for (size_t output_index = 0; output_index < output_size; output_index++) { + auto output_tensor = outputs[output_index].toTensor(); + auto output_file_name = output_folder_path + "/output_" + + std::to_string(idx) + "_" + + std::to_string(output_index) + ".raw"; + std::ofstream fout(output_file_name.c_str(), std::ios::binary); + fout.write(output_tensor.const_data_ptr(), output_tensor.nbytes()); + fout.close(); + } + }; +} + +std::vector> +get_inputs_paths(const char *input_list_path) { + size_t idx = 0; + + auto split = [](std::string s, std::string delimiter) { + size_t pos_start = 0, pos_end, delim_len = delimiter.length(); + std::string token; + std::vector res; + + while ((pos_end = s.find(delimiter, pos_start)) != std::string::npos) { + token = s.substr(pos_start, pos_end - pos_start); + pos_start = pos_end + delim_len; + res.push_back(token); + } + res.push_back(s.substr(pos_start)); + return res; + }; + + // Read raw input tensor file names from input list file and + // iterate each raw input tensor file to read values + std::ifstream input_list(input_list_path); + if (!input_list.is_open()) { + ET_CHECK_MSG(false, "Failed to read input list file: %s", input_list_path); + } + std::string file_path; + auto retval = std::vector>(); + while (std::getline(input_list, file_path)) { + auto input_files = split(file_path, " "); + if (input_files.size() == 0) { + break; + } + retval.push_back(input_files); + } + return retval; +} + +int main(int argc, char **argv) { // Initialize the runtime environment executorch::runtime::runtime_init(); @@ -68,22 +144,21 @@ int main(int argc, char** argv) { if (FLAGS_model_path.empty()) { std::cerr << "Error: --model_path is required." 
<< std::endl; std::cerr << "Usage: " << argv[0] - << " --model_path= --num_iter=" << std::endl; + << " --model_path= --num_iter=" + << std::endl; return 1; } // Retrieve the model path and number of iterations - const char* model_path = FLAGS_model_path.c_str(); + const char *model_path = FLAGS_model_path.c_str(); int num_iterations = FLAGS_num_iter; std::cout << "Model path: " << model_path << std::endl; std::cout << "Number of iterations: " << num_iterations << std::endl; // Load the model using FileDataLoader Result loader = FileDataLoader::from(model_path); - ET_CHECK_MSG( - loader.ok(), - "FileDataLoader::from() failed: 0x%" PRIx32, - static_cast(loader.error())); + ET_CHECK_MSG(loader.ok(), "FileDataLoader::from() failed: 0x%" PRIx32, + static_cast(loader.error())); // Load the program from the loaded model Result program = Program::load(&loader.get()); @@ -93,8 +168,9 @@ int main(int argc, char** argv) { } ET_LOG(Info, "Model file %s is loaded.", model_path); - // Retrieve the method name from the program (assumes the first method is used) - const char* method_name = nullptr; + // Retrieve the method name from the program (assumes the first method is + // used) + const char *method_name = nullptr; { const auto method_name_result = program->get_method_name(0); ET_CHECK_MSG(method_name_result.ok(), "Program has no methods"); @@ -104,11 +180,8 @@ int main(int argc, char** argv) { // Retrieve metadata about the method Result method_meta = program->method_meta(method_name); - ET_CHECK_MSG( - method_meta.ok(), - "Failed to get method_meta for %s: 0x%" PRIx32, - method_name, - static_cast(method_meta.error())); + ET_CHECK_MSG(method_meta.ok(), "Failed to get method_meta for %s: 0x%" PRIx32, + method_name, static_cast(method_meta.error())); // Set up a memory allocator for the method MemoryAllocator method_allocator{ @@ -133,135 +206,87 @@ int main(int argc, char** argv) { // Load the method into the program Result method = program->load_method(method_name, &memory_manager); - ET_CHECK_MSG( - method.ok(), - "Loading of method %s failed with status 0x%" PRIx32, - method_name, - static_cast(method.error())); + ET_CHECK_MSG(method.ok(), + "Loading of method %s failed with status 0x%" PRIx32, + method_name, static_cast(method.error())); ET_LOG(Info, "Method loaded."); // Prepare the input tensors for the method - auto inputs = prepare_input_tensors(*method); - ET_CHECK_MSG( - inputs.ok(), - "Could not prepare inputs: 0x%" PRIx32, - static_cast(inputs.error())); - - // If the input path list is provided, read input tensors from the files - if (!(FLAGS_input_list_path.empty())) { - const char* input_list_path = FLAGS_input_list_path.c_str(); - ET_LOG(Info, "Loading input tensors from the list provided in %s.", input_list_path); - Error status = Error::Ok; - std::vector inputs(method->inputs_size()); - ET_LOG(Info, "%zu inputs: ", inputs.size()); - status = method->get_inputs(inputs.data(), inputs.size()); - ET_CHECK(status == Error::Ok); - - auto split = [](std::string s, std::string delimiter) { - size_t pos_start = 0, pos_end, delim_len = delimiter.length(); - std::string token; - std::vector res; - - while ((pos_end = s.find(delimiter, pos_start)) != std::string::npos) { - token = s.substr(pos_start, pos_end - pos_start); - pos_start = pos_end + delim_len; - res.push_back(token); - } - res.push_back(s.substr(pos_start)); - return res; - }; - - // Read raw input tensor file names from input list file and - // iterate each raw input tensor file to read values - std::ifstream 
input_list(input_list_path); - if (input_list.is_open()) { - size_t num_inputs = method->inputs_size(); - std::string file_path; - while (std::getline(input_list, file_path)) { - auto input_files = split(file_path, " "); - if (input_files.size() == 0) { - break; - } - for (int input_index = 0; input_index < num_inputs; ++input_index) { - MethodMeta method_meta = method->method_meta(); - Result tensor_meta = - method_meta.input_tensor_meta(input_index); - auto input_data_ptr = inputs[input_index].toTensor().data_ptr(); - - std::ifstream fin(input_files[input_index], std::ios::binary); - fin.seekg(0, fin.end); - size_t file_size = fin.tellg(); - - ET_CHECK_MSG( - file_size == tensor_meta->nbytes(), - "Input(%d) size mismatch. file bytes: %zu, tensor bytes: %zu", - input_index, - file_size, - tensor_meta->nbytes()); - - fin.seekg(0, fin.beg); - fin.read( - static_cast(input_data_ptr), - file_size); - fin.close(); - } - } - } else { - ET_CHECK_MSG(false, - "Failed to read input list file: %s", - input_list_path); - } - } - ET_LOG(Info, "Inputs prepared."); + auto method_inputs = prepare_input_tensors(*method); + ET_CHECK_MSG(method_inputs.ok(), "Could not prepare inputs: 0x%" PRIx32, + static_cast(method_inputs.error())); - // Measure execution time for inference - auto before_exec = std::chrono::high_resolution_clock::now(); Error status = Error::Ok; - for (int i = 0; i < num_iterations; ++i) { - status = method->execute(); + std::vector inputs(method->inputs_size()); + ET_LOG(Info, "Number of input layers: %zu", inputs.size()); + + status = method->get_inputs(inputs.data(), inputs.size()); + ET_CHECK(status == Error::Ok); + + // If the input path list is provided, read input tensors from the files + std::function set_input_tensor; + if (!FLAGS_input_list_path.empty()) { + const char *input_list_path = FLAGS_input_list_path.c_str(); + ET_LOG(Info, "Loading input tensors from the list provided in %s.", + input_list_path); + const auto input_paths = get_inputs_paths(input_list_path); + num_iterations = input_paths.size(); + ET_LOG(Info, "Number of iters is set to the len of the inputs: %u.", + num_iterations); + + set_input_tensor = build_set_input_tensor(method, inputs, input_paths); + } else { + set_input_tensor = [](size_t idx) -> void {}; } - auto after_exec = std::chrono::high_resolution_clock::now(); - double elapsed_time = std::chrono::duration_cast( - after_exec - before_exec) - .count() / 1000.0; - // Log execution time and average time per iteration - ET_LOG( - Info, - "%d inference took %f ms, avg %f ms", - num_iterations, - elapsed_time, - elapsed_time / static_cast(num_iterations)); - ET_CHECK_MSG( - status == Error::Ok, - "Execution of method %s failed with status 0x%" PRIx32, - method_name, - static_cast(status)); - ET_LOG(Info, "Model executed successfully."); + ET_LOG(Info, "%zu Number of output layers: ", method->outputs_size()); - // Retrieve and print the method outputs std::vector outputs(method->outputs_size()); - ET_LOG(Info, "%zu Number of outputs: ", outputs.size()); status = method->get_outputs(outputs.data(), outputs.size()); ET_CHECK(status == Error::Ok); - // If output folder path is provided, save output tensors - // into raw tensor files. 
- if (!(FLAGS_output_folder_path.empty())) { - const char* output_folder_path = FLAGS_output_folder_path.c_str(); - ET_LOG(Info, "Saving output tensors into the output folder: %s.", output_folder_path); - for (size_t output_index = 0; output_index < method->outputs_size(); - output_index++) { - auto output_tensor = outputs[output_index].toTensor(); - auto output_file_name = std::string(output_folder_path) + "/output_" + - std::to_string(output_index) + ".raw"; - std::ofstream fout(output_file_name.c_str(), std::ios::binary); - fout.write( - output_tensor.const_data_ptr(), output_tensor.nbytes()); - fout.close(); + std::function dump_outputs; + if (!FLAGS_output_folder_path.empty()) { + // Retrieve and print the method outputs + + // If output folder path is provided, save output tensors + // into raw tensor files. + const char *output_folder_path = FLAGS_output_folder_path.c_str(); + ET_LOG(Info, "Saving output tensors into the output folder: %s.", + output_folder_path); + dump_outputs = build_dump_outputs(outputs, outputs.size(), + std::string(output_folder_path)); + + } else { + dump_outputs = [](size_t idx) {}; + } + + // Measure execution time for inference + + double total_time_elapsed = 0.; + for (int i = 0; (i < num_iterations and status == Error::Ok); ++i) { + set_input_tensor(i); + auto before_exec = std::chrono::high_resolution_clock::now(); + status = method->execute(); + auto after_exec = std::chrono::high_resolution_clock::now(); + if (status == Error::Ok) { + dump_outputs(i); } + double elapsed_time = std::chrono::duration_cast( + after_exec - before_exec) + .count() / + 1000.0; + total_time_elapsed += elapsed_time; } + // Log execution time and average time per iteration + ET_LOG(Info, "%d inference took %f ms, avg %f ms", num_iterations, + total_time_elapsed, + total_time_elapsed / static_cast(num_iterations)); + ET_CHECK_MSG(status == Error::Ok, + "Execution of method %s failed with status 0x%" PRIx32, + method_name, static_cast(status)); + ET_LOG(Info, "Model executed successfully."); + return 0; } -
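A minimal sketch of the PT2E path exercised by `--quantization_flow pt2e` above, mirroring `quantize_model()` in `aot_openvino_compiler.py`; the torchvision model and random calibration batches are placeholders, not part of this change:

```python
# Sketch of the pt2e quantization flow with the new OpenVINOQuantizer.
# Model choice and random calibration data below are placeholders.
import torch
import torchvision.models as torchvision_models
from executorch.backends.openvino import OpenVINOQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.export import export

model = torchvision_models.resnet18(pretrained=True).eval()
example_args = (torch.randn(1, 3, 224, 224),)

# Capture the model into an FX graph via torch.export.
captured_model = export(model, example_args).module()

# Annotate the graph with OpenVINO-friendly quantization specs.
quantizer = OpenVINOQuantizer()
prepared_model = prepare_pt2e(captured_model, quantizer)

# Calibrate on a few batches (random tensors stand in for a real dataset).
for _ in range(8):
    prepared_model(torch.randn(1, 3, 224, 224))

# Convert observers to quantize/dequantize nodes; Q/DQ is left unfolded
# (fold_quantize=False), as in quantize_model(), so the OpenVINO
# partitioner can consume it.
quantized_model = convert_pt2e(prepared_model, fold_quantize=False)

# The quantized module is then re-exported and lowered like the FP32 path:
# export(...) -> to_edge(...) -> to_backend(OpenvinoPartitioner).
```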