diff --git a/aa_torch_fx.py b/aa_torch_fx.py new file mode 100644 index 00000000000..339d33d1598 --- /dev/null +++ b/aa_torch_fx.py @@ -0,0 +1,456 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import copy +import re +import subprocess +import time +import warnings +from itertools import islice +from pathlib import Path + +import numpy as np +import openvino as ov +import openvino.torch # noqa +import pandas as pd +import torch +import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq +import torchvision.models as models +from sklearn.metrics import accuracy_score +from torch._export import capture_pre_autograd_graph +from torch.ao.quantization.quantize_pt2e import convert_pt2e +from torch.ao.quantization.quantize_pt2e import prepare_pt2e +from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer +from torch.fx.passes.graph_drawer import FxGraphDrawer +from torch.jit import TracerWarning +from torchao.utils import benchmark_model as ao_benchmark_model +from torchvision import datasets +from transformers import AutoImageProcessor +from transformers import AutoModelForImageClassification + +import nncf +from nncf.common.logging.track_progress import track +from nncf.common.quantization.structs import QuantizationPreset # noqa +from nncf.parameters import ModelType +from nncf.torch.dynamic_graph.patch_pytorch import disable_patching + +warnings.filterwarnings("ignore", category=TracerWarning) +warnings.filterwarnings("ignore", category=UserWarning) + +DATASET_IMAGENET = "/home/dlyakhov/datasets/imagenet/val" + +hf_models = () + + +def hf_model_builder(model_id: str): + def build(weights): + processor = AutoImageProcessor.from_pretrained(model_id) + model = AutoModelForImageClassification.from_pretrained(model_id) + + class ModelWithProcessing(torch.nn.Module): + def __init__(self, processor, model): + super().__init__() + self.processor = processor + self.model = model + + def forward(self, x): + processed_input = processor(x, return_tensors="pt") + return model(processed_input) + + # return ModelWithProcessing(processor, model) + return model + + class DummyWeights: + def transforms(self): + return models.ResNet18_Weights.DEFAULT.transforms() + + @property + def meta(self): + return {} + + return build, DummyWeights() + + +MODELS_DICT = { + "vit_h_14": (models.vit_h_14, models.ViT_H_14_Weights.DEFAULT), + "vit_b_16": (models.vit_b_16, models.ViT_B_16_Weights.DEFAULT), + "swin_v2_t": (models.swin_v2_t, models.Swin_V2_T_Weights.DEFAULT), + "swin_v2_s": (models.swin_v2_s, models.Swin_V2_S_Weights.DEFAULT), + "resnet18": (models.resnet18, models.ResNet18_Weights.DEFAULT), + "resnet50": (models.resnet50, models.ResNet50_Weights.DEFAULT), + "mobilenet_v2": (models.mobilenet_v2, models.MobileNet_V2_Weights.DEFAULT), + "mobilenet_v3_small": (models.mobilenet_v3_small, models.MobileNet_V3_Small_Weights.DEFAULT), + "mobilenet_v3_large": (models.mobilenet_v3_large, models.MobileNet_V3_Large_Weights.DEFAULT), + 
# "densenet161": (models.densenet161, models.DenseNet161_Weights.DEFAULT), + "vgg16": (models.vgg16, models.VGG16_Weights.DEFAULT), + "efficientnet_b7": (models.efficientnet_b7, models.EfficientNet_B7_Weights.DEFAULT), + "inception_v3": (models.inception_v3, models.Inception_V3_Weights.DEFAULT), + "regnet_x_32gf": (models.regnet_x_32gf, models.RegNet_X_32GF_Weights.DEFAULT), + # "google/vit-base-patch16-224": hf_model_builder("google/vit-base-patch16-224"), + # "convnext_large": (models.convnext_large, models.ConvNeXt_Large_Weights.DEFAULT), + # "convnext_small": (models.convnext_small, models.ConvNeXt_Small_Weights.DEFAULT), +} + + +def measure_time(model, example_inputs, num_iters=1000): + with torch.no_grad(): + model(*example_inputs) + total_time = 0 + for i in range(0, num_iters): + start_time = time.time() + model(*example_inputs) + total_time += time.time() - start_time + average_time = (total_time / num_iters) * 1000 + return average_time + + +def measure_time_ov(model, example_inputs, num_iters=1000): + ie = ov.Core() + compiled_model = ie.compile_model(model, "CPU") + infer_request = compiled_model.create_infer_request() + infer_request.infer(example_inputs) + total_time = 0 + for i in range(0, num_iters): + start_time = time.time() + infer_request.infer(example_inputs) + total_time += time.time() - start_time + average_time = (total_time / num_iters) * 1000 + return average_time + + +def quantize(model, example_inputs, calibration_dataset, subset_size=300): + with torch.no_grad(): + exported_model = capture_pre_autograd_graph(model, example_inputs) + + quantizer = X86InductorQuantizer() + quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) + + prepared_model = prepare_pt2e(exported_model, quantizer) + from tqdm import tqdm + + for inp, _ in islice(tqdm(calibration_dataset), subset_size): + prepared_model(inp) + converted_model = convert_pt2e(prepared_model) + return converted_model + + +def validate(model, val_loader, subset_size=None): + dataset_size = len(val_loader) + + predictions = np.zeros((dataset_size)) + references = -1 * np.ones((dataset_size)) + + with track(total=dataset_size, description="Validation") as pbar: + + for i, (images, target) in enumerate(val_loader): + if subset_size is not None and i >= subset_size: + break + + output_data = model(images).detach().numpy() + predicted_label = np.argmax(output_data, axis=1) + predictions[i] = predicted_label.item() + references[i] = target + pbar.progress.update(pbar.task, advance=1) + acc_top1 = accuracy_score(predictions, references) * 100 + print(acc_top1) + return acc_top1 + + +def validate_ov(model, val_loader): + dataset_size = len(val_loader) + + # Initialize result tensors for async inference support. 
+ predictions = np.zeros((dataset_size)) + references = -1 * np.ones((dataset_size)) + + core = ov.Core() + compiled_model = core.compile_model(model) + + infer_queue = ov.AsyncInferQueue(compiled_model, 4) + with track(total=dataset_size, description="Validation") as pbar: + + def process_result(request, userdata): + output_data = request.get_output_tensor().data + predicted_label = np.argmax(output_data, axis=1) + predictions[userdata] = predicted_label.item() + pbar.progress.update(pbar.task, advance=1) + + infer_queue.set_callback(process_result) + + for i, (images, target) in enumerate(val_loader): + # W/A for memory leaks when using torch DataLoader and OpenVINO + image_copies = copy.deepcopy(images.numpy()) + infer_queue.start_async(image_copies, userdata=i) + references[i] = target + + infer_queue.wait_all() + + acc_top1 = accuracy_score(predictions, references) * 100 + print(acc_top1) + return acc_top1 + + +def run_benchmark(model_path: Path, shape) -> float: + command = f"benchmark_app -m {model_path} -d CPU -api async -t 15" + command += f' -shape="[{",".join(str(x) for x in shape)}]"' + cmd_output = subprocess.check_output(command, shell=True) # nosec + match = re.search(r"Throughput\: (.+?) FPS", str(cmd_output)) + return float(match.group(1)) + + +def torch_ao_sq_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input): + import torch + from torchao.quantization.smoothquant import smooth_fq_linear_to_inference + from torchao.quantization.smoothquant import swap_linear_with_smooth_fq_linear + + # Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor + torch._inductor.config.force_fuse_int_mm_with_mul = True + + # plug in your model + # model = torch.compile(pt_model) + model = pt_model + + # convert linear modules to smoothquant + # linear module in calibration mode + swap_linear_with_smooth_fq_linear(model) + + # Create a data loader for calibration + calibration_loader = val_loader + + # Calibrate the model + model.train() + from tqdm import tqdm + + for batch in tqdm(islice(calibration_loader, 300)): + inputs = batch[0] + model(inputs) + + # set it to inference mode + smooth_fq_linear_to_inference(model) + + # compile the model to improve performance + model = torch.compile(model, mode="max-autotune") + acc1_quant_model = validate(model, val_loader) + print(f"torch ao metric acc@1: {acc1_quant_model}") + result["torch_ao_quant_model_acc"] = acc1_quant_model + + latency = ao_benchmark_model(model, 20, example_input) + print(f"torch ao latency: {latency}") + result["torch_ao_quant_model_latency"] = latency + + +def nncf_fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input): + with disable_patching(): + with torch.no_grad(): + exported_model = capture_pre_autograd_graph(pt_model, (example_input,)) + + def transform(x): + return x[0] + + quant_fx_model = nncf.quantize( + exported_model, nncf.Dataset(val_loader, transform_func=transform), model_type=ModelType.TRANSFORMER + ) + quant_compile_model = torch.compile(quant_fx_model, backend="openvino") + + # acc1_quant_model = validate(quant_compile_model, val_loader) + acc1_quant_model = -1.0 + latency_fx = measure_time(quant_compile_model, (example_input,)) + print(f"latency: {latency_fx}") + result["acc1_nncf_fx_quant_model"] = acc1_quant_model + result["torch_compile_ov_latency_nncf_fx_quant_model"] = latency_fx + + g = FxGraphDrawer(quant_compile_model, f"b_nncf_{pt_model.__class__.__name__}_int8") + 
g.get_dot_graph().write_svg(f"b_nncf_{pt_model.__class__.__name__}_int8.svg") + + # EXPORT TO OV + exported_model = torch.export.export(quant_compile_model, (example_input,)) + ov_quant_model = ov.convert_model(exported_model, example_input=example_input) + quant_file_path = output_dir / "quant.xml" + ov.save_model(ov_quant_model, quant_file_path) + + fps = run_benchmark(quant_file_path, shape_input) + print(f"fps: {fps}") + result["ov_fps_nncf_fx_quant_model"] = fps + + +def fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input): + with disable_patching(): + fp32_pt_model = copy.deepcopy(pt_model) + fp32_compile_model = torch.compile(fp32_pt_model, backend="openvino") + + quant_pt_model = quantize(fp32_compile_model, (example_input,), val_loader) + quant_compile_model = torch.compile(quant_pt_model, backend="openvino") + + g = FxGraphDrawer(quant_pt_model, f"b_pt_{pt_model.__class__.__name__}_int8") + g.get_dot_graph().write_svg(f"b_pt_{pt_model.__class__.__name__}_int8.svg") + + acc1_quant_model = validate(quant_compile_model, val_loader) + result["acc1_quant_model"] = acc1_quant_model + + latency_fx = measure_time(quant_compile_model, (example_input,)) + print(f"latency: {latency_fx}") + result["torch_compile_latency_fps_quant_model"] = latency_fx + + +def nncf_pt_2_ov_quantization(pt_model, val_loader, example_input, output_dir, result, shape_input): + def transform(x): + return x[0] + + nncf_model = nncf.quantize(copy.deepcopy(pt_model), nncf.Dataset(val_loader, transform_func=transform)) + + ov_nncf_model = ov.convert_model(nncf_model, example_input=example_input) + nncf_pt_file_path = output_dir / "nncf_pt.xml" + ov.save_model(ov_nncf_model, nncf_pt_file_path) + acc1_nncf_pt = validate_ov(ov_nncf_model, val_loader) + result["acc1_nncf_pt"] = acc1_nncf_pt + fps = run_benchmark(nncf_pt_file_path, shape_input) + print(f"fps: {fps}") + result["ov_fps_nncf_pt"] = fps + + +def nncf_ov_2_ov_quantization(ov_fp32_model, val_loader, output_dir, result, shape_input): + def transform(x): + return np.array(x[0]) + + from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters + from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters + + advanced_params = AdvancedQuantizationParameters() + # for sq_param in [-1, 0.15, 0.5, 0.75]: + for sq_param in [0.95]: + advanced_params.smooth_quant_alphas = AdvancedSmoothQuantParameters(matmul=sq_param) + + from copy import deepcopy + + fast_bias_correction = True + nncf_ov_int8_model = nncf.quantize( + deepcopy(ov_fp32_model), + nncf.Dataset(val_loader, transform_func=transform), + fast_bias_correction=fast_bias_correction, + model_type=ModelType.TRANSFORMER, + preset=QuantizationPreset.MIXED, + advanced_parameters=advanced_params, + ) + acc1_nncf_ov = validate_ov(nncf_ov_int8_model, val_loader) + result[f"acc1_nncf_ov_{sq_param}"] = acc1_nncf_ov + for precision, model in (("int8", nncf_ov_int8_model), ("fp32", ov_fp32_model)): + nncf_ov_file_path = output_dir / f"nncf_ov_{precision}.xml" + ov.save_model(model, nncf_ov_file_path) + fps = run_benchmark(nncf_ov_file_path, shape_input) + print(f"fps_{precision}: {fps} {sq_param}") + result[f"ov_fps_nncf_ov_{precision}_{sq_param}"] = fps + + latency = measure_time_ov(model, next(iter(val_loader))[0], num_iters=10_000) + print(f"latency_{precision}: {latency}") + result[f"ov_latency_nncf_ov_{precision}_{sq_param}"] = latency + + +def process_model(model_name: str): + + result = {"name": model_name} + model_cls, model_weights = 
MODELS_DICT[model_name] + output_dir = Path("models") / model_name + output_dir.mkdir(exist_ok=True) + ############################################################## + # Prepare dataset + ############################################################## + + val_dataset = datasets.ImageFolder(root=DATASET_IMAGENET, transform=model_weights.transforms()) + val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=2, shuffle=False) + + ############################################################## + # Prepare original model + ############################################################## + + pt_model = model_cls(weights=model_weights) + pt_model = pt_model.eval() + example_input = next(iter(val_loader))[0] + shape_input = list(example_input.shape) + ############################################################## + # Process FP32 Model + ############################################################## + + fp32_pt_model = copy.deepcopy(pt_model) + + orig_infer_acc1 = model_weights.meta.get("_metrics", {}).get("ImageNet-1K", {}).get("acc@1") + print(f"fp32 model metric: {orig_infer_acc1}") + # orig_infer_acc1 = validate(fp32_pt_model, val_loader) + result["acc1_fp32_openvino"] = orig_infer_acc1 + + fp32_pt_model = torch.export.export(fp32_pt_model, (example_input,)) + ov_fp32_model = ov.convert_model(fp32_pt_model, example_input=example_input) + ov_fp32_file_path = None + ov_fp32_file_path = output_dir / "fp32.xml" + ov.save_model(ov_fp32_model, ov_fp32_file_path) + # result["fps_fp32_openvino"] = run_benchmark(ov_fp32_file_path, shape_input) + # print(f"fps_fp32_openvino {result['fps_fp32_openvino']}") + + del fp32_pt_model + ############################################################## + # Process Torch AO Quantize with SQ + ############################################################## + # torch_ao_sq_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input) + + ############################################################## + # with torch.no_grad(): + # exported_model = capture_pre_autograd_graph(pt_model, (example_input,)) + # latency_fx = measure_time(torch.compile(exported_model), (example_input,)) + # print(f"latency: {latency_fx}") + ############################################################# + + ############################################################## + # Process PT Quantize + ############################################################## + fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input) + + ############################################################## + # Process NNCF FX Quantize + ############################################################## + # nncf_fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input) + + ############################################################## + # Process NNCF Quantize by PT + ############################################################## + # nncf_pt_2_ov_quantization(pt_model, val_loader, example_input, output_dir, result, shape_input) + + ############################################################## + # Process NNCF Quantize by OV + ############################################################## + # nncf_ov_2_ov_quantization(ov_fp32_model, val_loader, output_dir, result, shape_input) + + print(result) + return result + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", help="torchvision model name", type=str, default="all") + parser.add_argument("--file_name", help="output csv file_name", 
type=str, default="result.csv") + + args = parser.parse_args() + + results_list = [] + if args.model == "all": + for model_name in MODELS_DICT: + print("---------------------------------------------------") + print(f"name: {model_name}") + results_list.append(process_model(model_name)) + else: + results_list.append(process_model(args.model)) + + df = pd.DataFrame(results_list) + print(df) + df.to_csv(args.file_name) + + +if __name__ == "__main__": + main() diff --git a/examples/llm_compression/openvino/tiny_llama/main.py b/examples/llm_compression/openvino/tiny_llama/main.py index f2be54ce1aa..e5f3893f1ab 100644 --- a/examples/llm_compression/openvino/tiny_llama/main.py +++ b/examples/llm_compression/openvino/tiny_llama/main.py @@ -11,12 +11,12 @@ import time from functools import partial -import datasets import numpy as np import openvino as ov from optimum.intel.openvino import OVModelForCausalLM from transformers import AutoTokenizer +import datasets import nncf diff --git a/examples/llm_compression/openvino/tiny_llama_find_hyperparams/main.py b/examples/llm_compression/openvino/tiny_llama_find_hyperparams/main.py index b3fbce5722b..e34b09bc2f9 100644 --- a/examples/llm_compression/openvino/tiny_llama_find_hyperparams/main.py +++ b/examples/llm_compression/openvino/tiny_llama_find_hyperparams/main.py @@ -17,12 +17,12 @@ import numpy as np import openvino as ov -from datasets import load_dataset from optimum.intel import OVModelForCausalLM from transformers import AutoTokenizer from whowhatbench import Evaluator import nncf +from datasets import load_dataset from nncf.common.logging import nncf_logger DataItem = TypeVar("DataItem") diff --git a/examples/post_training_quantization/openvino/yolov8/main.py b/examples/post_training_quantization/openvino/yolov8/main.py index fd31a0c5fea..a660df332e6 100644 --- a/examples/post_training_quantization/openvino/yolov8/main.py +++ b/examples/post_training_quantization/openvino/yolov8/main.py @@ -8,36 +8,102 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
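The quantize() helper in aa_torch_fx.py earlier in this diff wraps PyTorch's PT2E post-training quantization flow. For reference, a minimal self-contained sketch of that flow is given below; the ResNet-18 model, random calibration tensors, and iteration count are illustrative placeholders, and the torch.ao APIs are the same ones the script imports (recent PyTorch 2.x).

import torch
import torchvision.models as models
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)

model = models.resnet18(weights=None).eval()  # random weights are enough for a flow sketch
example_inputs = (torch.ones(1, 3, 224, 224),)

with torch.no_grad():
    # Capture an FX graph ahead of autograd, as quantize() does.
    exported = capture_pre_autograd_graph(model, example_inputs)

    # Annotate the graph with the default x86 Inductor quantization config.
    quantizer = X86InductorQuantizer()
    quantizer.set_global(get_default_x86_inductor_quantization_config())
    prepared = prepare_pt2e(exported, quantizer)

    # Calibrate on a handful of samples (random tensors stand in for real data).
    for _ in range(10):
        prepared(torch.randn(1, 3, 224, 224))

    # Fold observers into quantize/dequantize ops and compile the result.
    int8_model = torch.compile(convert_pt2e(prepared))
    int8_model(*example_inputs)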
+ +import os + +os.environ["TORCHINDUCTOR_FREEZING"] = "1" + import re import subprocess +import time +from copy import deepcopy from pathlib import Path -from typing import Any, Dict, Tuple +from typing import Dict, Tuple import numpy as np import openvino as ov +import openvino.torch # noqa import torch +from torch._export import capture_pre_autograd_graph +from torch.export import Dim # noqa +from torch.fx.passes.graph_drawer import FxGraphDrawer from tqdm import tqdm from ultralytics.cfg import get_cfg -from ultralytics.data.converter import coco80_to_coco91_class from ultralytics.data.utils import check_det_dataset from ultralytics.engine.validator import BaseValidator as Validator from ultralytics.models.yolo import YOLO -from ultralytics.utils import DATASETS_DIR from ultralytics.utils import DEFAULT_CFG -from ultralytics.utils.metrics import ConfusionMatrix +from ultralytics.utils.torch_utils import de_parallel import nncf ROOT = Path(__file__).parent.resolve() -def validate( +def measure_time(model, example_inputs, num_iters=500): + with torch.no_grad(): + model(*example_inputs) + total_time = 0 + for i in range(0, num_iters): + start_time = time.time() + model(*example_inputs) + total_time += time.time() - start_time + average_time = (total_time / num_iters) * 1000 + return average_time + + +def measure_time_ov(model, example_inputs, num_iters=1000): + ie = ov.Core() + compiled_model = ie.compile_model(model, "CPU") + infer_request = compiled_model.create_infer_request() + infer_request.infer(example_inputs) + total_time = 0 + for i in range(0, num_iters): + start_time = time.time() + infer_request.infer(example_inputs) + total_time += time.time() - start_time + average_time = (total_time / num_iters) * 1000 + return average_time + + +def validate_fx( model: ov.Model, data_loader: torch.utils.data.DataLoader, validator: Validator, num_samples: int = None ) -> Tuple[Dict, int, int]: - validator.seen = 0 - validator.jdict = [] - validator.stats = [] - validator.confusion_matrix = ConfusionMatrix(nc=validator.nc) + # validator.seen = 0 + # validator.jdict = [] + # validator.stats = [] + # validator.confusion_matrix = ConfusionMatrix(nc=validator.nc) + for batch_i, batch in enumerate(data_loader): + if num_samples is not None and batch_i == num_samples: + break + batch = validator.preprocess(batch) + preds = model(batch["img"]) + preds = validator.postprocess(preds) + validator.update_metrics(preds, batch) + stats = validator.get_stats() + return stats, validator.seen, validator.nt_per_class.sum() + + +def print_statistics_short(stats: np.ndarray) -> None: + mp, mr, map50, mean_ap = ( + stats["metrics/precision(B)"], + stats["metrics/recall(B)"], + stats["metrics/mAP50(B)"], + stats["metrics/mAP50-95(B)"], + ) + s = ("%20s" + "%12s" * 4) % ("Class", "Precision", "Recall", "mAP@.5", "mAP@.5:.95") + print(s) + pf = "%20s" + "%12.3g" * 4 # print format + print(pf % ("all", mp, mr, map50, mean_ap)) + + +def validate_ov( + model: ov.Model, data_loader: torch.utils.data.DataLoader, validator: Validator, num_samples: int = None +) -> Tuple[Dict, int, int]: + # validator.seen = 0 + # validator.jdict = [] + # validator.stats = [] + # validator.confusion_matrix = ConfusionMatrix(nc=validator.nc) model.reshape({0: [1, 3, -1, -1]}) compiled_model = ov.compile_model(model) output_layer = compiled_model.output(0) @@ -65,21 +131,19 @@ def print_statistics(stats: np.ndarray, total_images: int, total_objects: int) - print(pf % ("all", total_images, total_objects, mp, mr, map50, mean_ap)) -def 
prepare_validation(model: YOLO, args: Any) -> Tuple[Validator, torch.utils.data.DataLoader]: - validator = model.smart_load("validator")(args) - validator.data = check_det_dataset(args.data) - dataset = validator.data["val"] - print(f"{dataset}") +def prepare_validation(model: YOLO, data: str) -> Tuple[Validator, torch.utils.data.DataLoader]: + # custom = {"rect": True, "batch": 1} # method defaults + # rect: false forces to resize all input pictures to one size + custom = {"rect": False, "batch": 1} # method defaults + args = {**model.overrides, **custom, "mode": "val"} # highest priority args on the right - data_loader = validator.get_dataloader(f"{DATASETS_DIR}/coco128", 1) + validator = model._smart_load("validator")(args=args, _callbacks=model.callbacks) + stride = 32 # default stride + validator.stride = stride # used in get_dataloader() for padding + validator.data = check_det_dataset(data) + validator.init_metrics(de_parallel(model)) - validator = model.smart_load("validator")(args) - - validator.is_coco = True - validator.class_map = coco80_to_coco91_class() - validator.names = model.model.names - validator.metrics.names = validator.names - validator.nc = model.model.model[-1].nc + data_loader = validator.get_dataloader(validator.data.get(validator.args.split), validator.args.batch) return validator, data_loader @@ -104,7 +168,9 @@ def prepare_openvino_model(model: YOLO, model_name: str) -> Tuple[ov.Model, Path return ov.Core().read_model(ir_model_path), ir_model_path -def quantize(model: ov.Model, data_loader: torch.utils.data.DataLoader, validator: Validator) -> ov.Model: +def quantize( + model: ov.Model, data_loader: torch.utils.data.DataLoader, validator: Validator, original_model +) -> ov.Model: def transform_fn(data_item: Dict): """ Quantization transform function. 
Extracts and preprocess input data from dataloader @@ -136,44 +202,186 @@ def transform_fn(data_item: Dict): return quantized_model +NNCF_QUANTIZATION = False + + +def quantize_impl(exported_model, val_loader, validator): + def transform_fn(x): + batch = validator.preprocess(x) + return batch["img"] + + calibration_dataset = nncf.Dataset(val_loader, transform_fn) + dir_name = str(Path(__file__).parent) + if NNCF_QUANTIZATION: + converted_model = nncf.quantize( + exported_model, + calibration_dataset, + ignored_scope=nncf.IgnoredScope( + types=["mul", "sub", "sigmoid"], + subgraphs=[ + nncf.Subgraph( + inputs=["cat_13", "cat_14", "cat_15"], + outputs=["output"], + ) + ], + ), + ) + g = FxGraphDrawer(converted_model, "yolo_nncf_fx_int8") + g.get_dot_graph().write_svg(dir_name + "/yolo_nncf_fx_int8.svg") + + quantized_model = torch.compile(converted_model, backend="openvino") + return quantized_model + else: + from torch.ao.quantization.quantize_pt2e import convert_pt2e + from torch.ao.quantization.quantize_pt2e import prepare_pt2e + from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer + from torch.ao.quantization.quantizer.x86_inductor_quantizer import get_default_x86_inductor_quantization_config + + quantizer = X86InductorQuantizer() + quantizer.set_global(get_default_x86_inductor_quantization_config()) + + prepared_model = prepare_pt2e(exported_model, quantizer) + + for idx, batch in tqdm(enumerate(calibration_dataset.get_inference_data())): + if idx >= 300: + break + prepared_model(batch) + + converted_model = convert_pt2e(prepared_model) + + g = FxGraphDrawer(prepared_model, "yolo_torch_fx_int8") + g.get_dot_graph().write_svg(dir_name + "/yolo_torch_fx_int8.svg") + import torch._inductor.config as config + + config.cpp_wrapper = True + + quantized_model = torch.compile(converted_model) + return quantized_model + + +TORCH_FX = True +MODEL_NAME = "yolov8n" + + def main(): - MODEL_NAME = "yolov8n" model = YOLO(f"{ROOT}/{MODEL_NAME}.pt") - args = get_cfg(cfg=DEFAULT_CFG) - args.data = "coco128.yaml" + # args = get_cfg(cfg=DEFAULT_CFG) + # args.data = "coco128.yaml" # Prepare validation dataset and helper - validator, data_loader = prepare_validation(model, args) + + validator, data_loader = prepare_validation(model, "coco128.yaml") # Convert to OpenVINO model + batch = next(iter(data_loader)) + batch = validator.preprocess(batch) + + if TORCH_FX: + fp_stats, total_images, total_objects = validate_fx(model.model, tqdm(data_loader), validator) + print("Floating-point Torch model validation results:") + print_statistics(fp_stats, total_images, total_objects) + + if NNCF_QUANTIZATION: + fp32_compiled_model = torch.compile(model.model, backend="openvino") + else: + fp32_compiled_model = torch.compile(model.model) + fp32_stats, total_images, total_objects = validate_fx(fp32_compiled_model, tqdm(data_loader), validator) + print("FP32 FX model validation results:") + print_statistics(fp32_stats, total_images, total_objects) + + print("Start quantization...") + # Rebuild model to reset ultralitics cache + model = YOLO(f"{ROOT}/{MODEL_NAME}.pt") + with torch.no_grad(): + model.model.eval() + model.model(batch["img"]) + # dynamic_shapes = ((None, None, Dim("H", min=1, max=29802), Dim("W", min=1, max=29802)),) + dynamic_shapes = ((None, None, None, None),) + exported_model = capture_pre_autograd_graph( + model.model, args=(batch["img"],), dynamic_shapes=dynamic_shapes + ) + quantized_model = quantize_impl(deepcopy(exported_model), data_loader, validator) + + 
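        # quantize_impl() above returns the quantized graph already wrapped in torch.compile
        # (with the OpenVINO backend when NNCF_QUANTIZATION is set), so the INT8 validation
        # and latency measurements below call it exactly like the eager FP32 model, on the
        # same preprocessed (1, 3, H, W) image batches.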
int8_stats, total_images, total_objects = validate_fx(quantized_model, tqdm(data_loader), validator) + print("INT8 FX model validation results:") + print_statistics(int8_stats, total_images, total_objects) + + print("Start FX fp32 model benchmarking...") + fp32_latency = measure_time(fp32_compiled_model, (batch["img"],)) + print(f"fp32 FX latency: {fp32_latency}") + + print("Start FX int8 model benchmarking...") + int8_latency = measure_time(quantized_model, (batch["img"],)) + print(f"FX int8 latency: {int8_latency}") + print(f"Speed up: {fp32_latency / int8_latency}") + return + ov_model, ov_model_path = prepare_openvino_model(model, MODEL_NAME) # Quantize mode in OpenVINO representation - quantized_model = quantize(ov_model, data_loader, validator) + quantized_model = quantize(ov_model, data_loader, validator, model) quantized_model_path = Path(f"{ROOT}/{MODEL_NAME}_openvino_model/{MODEL_NAME}_quantized.xml") ov.save_model(quantized_model, str(quantized_model_path), compress_to_fp16=False) + args = get_cfg(cfg=DEFAULT_CFG) + args.data = "coco128.yaml" # Validate FP32 model - fp_stats, total_images, total_objects = validate(ov_model, tqdm(data_loader), validator) + fp_stats, total_images, total_objects = validate_ov(ov_model, tqdm(data_loader), validator) print("Floating-point model validation results:") print_statistics(fp_stats, total_images, total_objects) # Validate quantized model - q_stats, total_images, total_objects = validate(quantized_model, tqdm(data_loader), validator) + q_stats, total_images, total_objects = validate_ov(quantized_model, tqdm(data_loader), validator) print("Quantized model validation results:") print_statistics(q_stats, total_images, total_objects) - # Benchmark performance of FP32 model - fp_model_perf = benchmark_performance(ov_model_path, args) - print(f"Floating-point model performance: {fp_model_perf} FPS") - - # Benchmark performance of quantized model - quantized_model_perf = benchmark_performance(quantized_model_path, args) - print(f"Quantized model performance: {quantized_model_perf} FPS") + fps = True + latency = True + fp_model_perf = -1 + quantized_model_perf = -1 + if fps: + # Benchmark performance of FP32 model + fp_model_perf = benchmark_performance(ov_model_path, args) + print(f"Floating-point model performance: {fp_model_perf} FPS") + + # Benchmark performance of quantized model + quantized_model_perf = benchmark_performance(quantized_model_path, args) + print(f"Quantized model performance: {quantized_model_perf} FPS") + if latency: + fp_model_latency = measure_time_ov(ov_model, batch["img"]) + print(f"FP32 OV model latency: {fp_model_latency}") + int8_model_latency = measure_time_ov(quantized_model, batch["img"]) + print(f"INT8 OV model latency: {int8_model_latency}") return fp_stats["metrics/mAP50-95(B)"], q_stats["metrics/mAP50-95(B)"], fp_model_perf, quantized_model_perf +def main_export_not_strict(): + model = YOLO(f"{ROOT}/{MODEL_NAME}.pt") + + # Prepare validation dataset and helper + validator, data_loader = prepare_validation(model, "coco128.yaml") + + batch = next(iter(data_loader)) + batch = validator.preprocess(batch) + + model.model(batch["img"]) + ex_model = torch.export.export(model.model, args=(batch["img"],), strict=False) + ex_model = capture_pre_autograd_graph(ex_model.module(), args=(batch["img"],)) + ex_model = torch.compile(ex_model) + + fp_stats, total_images, total_objects = validate_fx(ex_model, tqdm(data_loader), validator) + print("Floating-point ex strict=False") + print_statistics(fp_stats, total_images, 
total_objects) + + quantized_model = quantize_impl(deepcopy(ex_model), data_loader, validator) + int8_stats, total_images, total_objects = validate_fx(quantized_model, tqdm(data_loader), validator) + print("Int8 ex strict=False") + print_statistics(int8_stats, total_images, total_objects) + # No quantized were inserted, metrics are OK + + if __name__ == "__main__": + # main_export_not_strict() main() diff --git a/examples/post_training_quantization/torch/ssd300_vgg16/main.py b/examples/post_training_quantization/torch/ssd300_vgg16/main.py index 1b586f4a995..674e8d60291 100644 --- a/examples/post_training_quantization/torch/ssd300_vgg16/main.py +++ b/examples/post_training_quantization/torch/ssd300_vgg16/main.py @@ -19,6 +19,7 @@ import nncf from nncf.torch import disable_tracing +from torch._export import capture_pre_autograd_graph import openvino as ov import torch import torchvision @@ -27,7 +28,9 @@ from torchmetrics.detection.mean_ap import MeanAveragePrecision from torchvision.models.detection.ssd import SSD from torchvision.models.detection.ssd import GeneralizedRCNNTransform +from torchvision.transforms.functional import pil_to_tensor from torchvision.models.detection.anchor_utils import DefaultBoxGenerator +from torch.export import Dim from nncf.common.logging.track_progress import track from functools import partial @@ -118,6 +121,7 @@ def validate(model: torch.nn.Module, dataset: COCO128Dataset, device: torch.devi metric = MeanAveragePrecision() with torch.no_grad(): for img, target in track(dataset, description="Validating"): + print(img.shape) prediction = model(img.to(device)[None])[0] for k in prediction: prediction[k] = prediction[k].to(torch.device("cpu")) @@ -135,16 +139,38 @@ def transform_fn(data_item: Tuple[torch.Tensor, Dict], device: torch.device) -> def main(): # Download and prepare the COCO128 dataset dataset_path = download_dataset() + # weights = torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights.DEFAULT + # transform = weights.transforms() weights_name = "SSD300_VGG16_Weights.DEFAULT" transform = torchvision.models.get_weight(weights_name).transforms() dataset = COCO128Dataset(dataset_path, lambda img, target: (transform(img), target)) # Get the pretrained ssd300_vgg16 model from torchvision.models model = torchvision.models.get_model("ssd300_vgg16", weights=weights_name) + # model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(weights=weights) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model.to(device) model.eval() + calibration_dataset = nncf.Dataset(dataset, partial(transform_fn, device=device)) + + inp = next(iter(calibration_dataset.get_inference_data())) + # dynamic_shapes = ((None, None, Dim("H"), Dim("W")),) + dynamic_shapes = ((None, None, None, None),) + # dynamic_shapes = ((Dim("batch"), None, None, None),) + _ = model(inp) + # r = validate(model, dataset, device) + # print(r) + compiled_model = capture_pre_autograd_graph(model, args=(inp,), dynamic_shapes=dynamic_shapes) + # compiled_model = torch.compile(model) + print("torch model") + r = validate(model, dataset, device) + print(f"mAP @ 0.5: {r:.3f}") + print("compiled model") + r = validate(compiled_model, dataset, device) + print(f"mAP @ 0.5: {r:.3f}") + return + # Disable NNCF tracing for some methods in order for the model to be properly traced by NNCF disable_tracing(GeneralizedRCNNTransform.normalize) disable_tracing(SSD.postprocess_detections) @@ -198,5 +224,109 @@ def main(): return fp32_map, int8_map, fp32_fps, 
int8_fps, fp32_model_size, int8_model_size +def validate_detr(model: torch.nn.Module, dataset: COCO128Dataset, device: torch.device, processor): + model.to(device) + metric = MeanAveragePrecision() + min_h = 1000000 + max_h = 0 + min_w = 1000000 + max_w = 0 + with torch.no_grad(): + for img, target in track(dataset, description="Validating"): + + inputs = pil_to_tensor(img) + if inputs.shape[0] == 1: + inputs = torch.cat([inputs] * 3) + inputs = inputs[None] + + inputs = processor(images=inputs, return_tensors="pt") + min_h = min(min_h, inputs["pixel_values"].shape[2]) + max_h = max(max_h, inputs["pixel_values"].shape[2]) + min_w = min(min_w, inputs["pixel_values"].shape[3]) + max_w = max(max_w, inputs["pixel_values"].shape[3]) + + output = model(**inputs) + target_sizes = torch.tensor([img.size[::-1]]) + prediction = processor.post_process_object_detection(output, target_sizes=target_sizes, threshold=0.9)[0] + for k in prediction: + prediction[k] = prediction[k].to(torch.device("cpu")) + metric.update([prediction], [target]) + computed_metrics = metric.compute() + print(min_h, max_h, min_w, max_w) + return computed_metrics["map_50"] + + +def get_dert_inputs(processor, dataset): + img = next(iter(dataset))[0] + inputs = pil_to_tensor(img) + inputs = inputs[None] + return processor(images=inputs, return_tensors="pt") + + +def get_image(): + from PIL import Image + import requests + + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw) + + return image + + +def main_detr(): + from transformers import DetrImageProcessor, DetrForObjectDetection # noqa + from transformers import AutoImageProcessor, AutoModelForObjectDetection, ConditionalDetrForObjectDetection # noqa + from transformers import OwlViTProcessor, OwlViTForObjectDetection # noqa + import torch + + device = torch.device("cpu") + # you can specify the revision tag if you don't want the timm dependency + # processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm") + # model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm") + processor = AutoImageProcessor.from_pretrained("microsoft/conditional-detr-resnet-50") + model = ConditionalDetrForObjectDetection.from_pretrained("microsoft/conditional-detr-resnet-50") + # processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32") + # model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32") + model.eval() + + dataset_path = download_dataset() + dataset = COCO128Dataset(dataset_path, lambda img, target: (img, target)) + + h, w = Dim("H", min=454, max=1333), Dim("W", min=748, max=1333) + dynamic_shapes = {"pixel_values": {2: h, 3: w}, "pixel_mask": {2: h, 3: w}} + dynamic_shapes = ((None, None, h, w), (None, h, w)) + ex_inputs = get_dert_inputs(processor, dataset) + # captured_model = capture_pre_autograd_graph(model, args=(), kwargs=ex_inputs, dynamic_shapes=dynamic_shapes) + # captured_model = capture_pre_autograd_graph(model, args=(tuple(ex_inputs.values()),), + # dynamic_shapes=dynamic_shapes) + # captured_model = capture_pre_autograd_graph(model, args=tuple(ex_inputs.values())) + captured_model = capture_pre_autograd_graph(model, args=tuple(ex_inputs.values()), dynamic_shapes=dynamic_shapes) + # captured_model = capture_pre_autograd_graph(model,args=(), kwargs=ex_inputs) + + # compiled_model = torch.compile(model, dynamic=True) + # r = validate_detr(compiled_model, dataset, device, processor) + r = 
validate_detr(captured_model, dataset, device, processor) + print(f"mAP @ 0.5: {r:.3f}") + r = validate_detr(model, dataset, device, processor) + print(f"mAP @ 0.5: {r:.3f}") + + outputs = model(**ex_inputs) + + # convert outputs (bounding boxes and class logits) to COCO API + # let's only keep detections with score > 0.9 + image = get_image() + processor(images=image, return_tensors="pt") + target_sizes = torch.tensor([image.size[::-1]]) + results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0] + + for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): + box = [round(i, 2) for i in box.tolist()] + print( + f"Detected {model.config.id2label[label.item()]} with confidence " + f"{round(score.item(), 3)} at location {box}" + ) + + if __name__ == "__main__": - main() + # main() + main_detr() diff --git a/examples/quantization_aware_training/torch/anomalib/main.py b/examples/quantization_aware_training/torch/anomalib/main.py index 5ad4f348db6..7e30a7ec7df 100644 --- a/examples/quantization_aware_training/torch/anomalib/main.py +++ b/examples/quantization_aware_training/torch/anomalib/main.py @@ -26,6 +26,7 @@ from anomalib.deploy import ExportType from anomalib.engine import Engine from anomalib.models import Stfpm +from torch._export import capture_pre_autograd_graph import nncf @@ -124,6 +125,10 @@ def transform_fn(data_item): # Quantize the inference model using Post-Training Quantization inference_model = model.model + + example_input = torch.ones((1, 3, 255, 255)) + with torch.no_grad(): + capture_pre_autograd_graph(inference_model, example_input) quantized_inference_model = nncf.quantize(model=inference_model, calibration_dataset=calibration_dataset) # Deepcopy the original model and set the quantized inference model diff --git a/examples/quantization_aware_training/torch/resnet18/main.py b/examples/quantization_aware_training/torch/resnet18/main.py index 7ab3b7af14a..026bd44acb8 100644 --- a/examples/quantization_aware_training/torch/resnet18/main.py +++ b/examples/quantization_aware_training/torch/resnet18/main.py @@ -10,15 +10,22 @@ # limitations under the License. 
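Both the SSD and DETR experiments above rely on the dynamic_shapes argument of capture_pre_autograd_graph to keep spatial dimensions symbolic. The sketch below shows that pattern in isolation, assuming a toy convolutional module; the Dim names and bounds are illustrative only, not values taken from the examples.

import torch
from torch._export import capture_pre_autograd_graph
from torch.export import Dim


class TinyConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


model = TinyConv().eval()
example = torch.randn(1, 3, 320, 320)

# Mark height and width as symbolic; batch and channel dims stay static (None).
h, w = Dim("H", min=64, max=1024), Dim("W", min=64, max=1024)
dynamic_shapes = ((None, None, h, w),)

with torch.no_grad():
    captured = capture_pre_autograd_graph(model, args=(example,), dynamic_shapes=dynamic_shapes)
    # The captured graph now accepts other spatial sizes within the declared range.
    captured(torch.randn(1, 3, 480, 640))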
import os + +os.environ["TORCHINDUCTOR_FREEZING"] = "1" + + import re import subprocess +import time import warnings from copy import deepcopy from pathlib import Path from typing import List, Tuple import openvino as ov +import openvino.torch # noqa import torch +import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq import torch.nn as nn import torch.nn.parallel import torch.optim @@ -28,6 +35,11 @@ import torchvision.models as models import torchvision.transforms as transforms from fastdownload import FastDownload +from torch._export import capture_pre_autograd_graph +from torch.ao.quantization.quantize_pt2e import convert_pt2e +from torch.ao.quantization.quantize_pt2e import prepare_pt2e +from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer +from torch.fx.passes.graph_drawer import FxGraphDrawer from torch.jit import TracerWarning import nncf @@ -54,6 +66,18 @@ DATASET_PATH = "~/.cache/nncf/datasets" +def measure_time(model, example_inputs, num_iters): + with torch.no_grad(): + model(*example_inputs) + total_time = 0 + for i in range(0, num_iters): + start_time = time.time() + model(*example_inputs) + total_time += time.time() - start_time + average_time = (total_time / num_iters) * 1000 + return average_time + + def download_dataset() -> Path: downloader = FastDownload(base=DATASET_PATH, archive="downloaded", data="extracted") return downloader.get(DATASET_URL) @@ -102,7 +126,7 @@ def validate(val_loader: torch.utils.data.DataLoader, model: torch.nn.Module, de top1_sum = 0.0 # Switch to evaluate mode. - model.eval() + # model.eval() with torch.no_grad(): for images, target in track(val_loader, total=len(val_loader), description="Validation:"): @@ -230,7 +254,7 @@ def get_model_size(ir_path: str, m_type: str = "Mb") -> float: def main(): torch.manual_seed(0) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device("cpu") print(f"Using {device} device") ############################################################################### @@ -253,11 +277,66 @@ def transform_fn(data_item): # Step 2: Quantize model print(os.linesep + "[Step 2] Quantize model") - quantized_model = nncf.quantize(model, quantization_dataset) - acc1_int8_init = validate(val_loader, quantized_model, device) + with torch.no_grad(): + example_inputs = (torch.ones((1, 3, IMAGE_SIZE, IMAGE_SIZE)),) + exported_model = capture_pre_autograd_graph(model.eval(), example_inputs) + + NNCF_TORCH_FX = False + + if NNCF_TORCH_FX: + quantizer = X86InductorQuantizer() + quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) + + prepared_model = prepare_pt2e(exported_model, quantizer) + from itertools import islice + + from tqdm import tqdm + for data in tqdm(islice(quantization_dataset.get_inference_data(), 300)): + prepared_model(data) + quantized_model = convert_pt2e(prepared_model) + + g = FxGraphDrawer(quantized_model, "acc_resnet18_int8_native") + g.get_dot_graph().write_svg("acc_resnet18_int8_native.svg") + else: + quantized_model = nncf.quantize(exported_model, quantization_dataset) + g = FxGraphDrawer(quantized_model, "acc_resnet18_int8_nncf") + g.get_dot_graph().write_svg("acc_resnet18_int8_nncf.svg") + + # quantized_model = torch.compile(quantized_model) + # acc1_int8_init = validate(val_loader, quantized_model, device) + acc1_int8_init = validate(val_loader, torch.compile(quantized_model), device) print(f"Accuracy@1 of initialized INT8 model: {acc1_int8_init:.3f}") + num_iters = 100 + + print("original model execution 
time: ", measure_time(model, example_inputs, num_iters)) + native_optimized_model_fp32 = torch.compile(exported_model) + print( + "Torch Inductor FP32 model execution time: ", + measure_time(native_optimized_model_fp32, example_inputs, num_iters), + ) + native_optimized_model_int8 = torch.compile(quantized_model) + print( + "Torch Inductor INT8 model execution time: ", + measure_time(native_optimized_model_int8, example_inputs, num_iters), + ) + + ov_optimized_model_fp32 = torch.compile(exported_model, backend="openvino") + print( + "Torch.compile OpenVINO FP32 model execution time: ", + measure_time(ov_optimized_model_fp32, example_inputs, num_iters), + ) + + ov_optimized_model_int8 = torch.compile( + quantized_model, backend="openvino", options={"model_caching": True, "cache_dir": "./model_cache"} + ) + print( + "Torch.compile OpenVINO INT8 model execution time: ", + measure_time(ov_optimized_model_int8, example_inputs, num_iters), + ) + + return ############################################################################### # Step 3: Fine tune quantized model print(os.linesep + "[Step 3] Fine tune quantized model") diff --git a/nncf/common/factory.py b/nncf/common/factory.py index 6616f9dbe3a..d5d13605a07 100644 --- a/nncf/common/factory.py +++ b/nncf/common/factory.py @@ -41,6 +41,10 @@ def create(model: TModel) -> NNCFGraph: if model_backend == BackendType.OPENVINO: from nncf.openvino.graph.nncf_graph_builder import GraphConverter + return GraphConverter.create_nncf_graph(model) + if model_backend == BackendType.TORCH_FX: + from nncf.experimental.torch_fx.nncf_graph_builder import GraphConverter + return GraphConverter.create_nncf_graph(model) if model_backend == BackendType.TORCH: return model.nncf.get_graph() @@ -72,6 +76,10 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer: from nncf.torch.model_transformer import PTModelTransformer return PTModelTransformer(model) + if model_backend == BackendType.TORCH_FX: + from nncf.experimental.torch_fx.model_transformer import FXModelTransformer + + return FXModelTransformer(model) raise nncf.UnsupportedBackendError( "Cannot create backend-specific model transformer because {} is not supported!".format(model_backend.value) ) @@ -99,6 +107,10 @@ def create(model: TModel) -> Engine: from nncf.torch.engine import PTEngine return PTEngine(model) + if model_backend == BackendType.TORCH_FX: + from nncf.experimental.torch_fx.engine import FXEngine + + return FXEngine(model) raise nncf.UnsupportedBackendError( "Cannot create backend-specific engine because {} is not supported!".format(model_backend.value) ) @@ -151,6 +163,10 @@ def create(model: TModel, dataset: Dataset) -> aggregator.StatisticsAggregator: from nncf.torch.statistics.aggregator import PTStatisticsAggregator return PTStatisticsAggregator(dataset) + if model_backend == BackendType.TORCH_FX: + from nncf.experimental.torch_fx.statistics.aggregator import FXStatisticsAggregator + + return FXStatisticsAggregator(dataset) raise nncf.UnsupportedBackendError( "Cannot create backend-specific statistics aggregator because {} is not supported!".format( model_backend.value diff --git a/nncf/common/graph/patterns/manager.py b/nncf/common/graph/patterns/manager.py index 824a3b4a4c5..eae1546aad5 100644 --- a/nncf/common/graph/patterns/manager.py +++ b/nncf/common/graph/patterns/manager.py @@ -48,6 +48,11 @@ def _get_backend_hw_patterns_map(backend: BackendType) -> Dict[HWFusedPatternNam if backend == BackendType.TORCH: from nncf.torch.hardware.fused_patterns import 
PT_HW_FUSED_PATTERNS + registry = PT_HW_FUSED_PATTERNS.registry_dict + return registry + if backend == BackendType.TORCH_FX: + from nncf.torch.hardware.fused_patterns import PT_HW_FUSED_PATTERNS + registry = PT_HW_FUSED_PATTERNS.registry_dict return registry raise ValueError(f"Hardware-fused patterns not implemented for {backend} backend.") @@ -76,6 +81,11 @@ def _get_backend_ignored_patterns_map( if backend == BackendType.TORCH: from nncf.torch.quantization.ignored_patterns import PT_IGNORED_PATTERNS + registry = PT_IGNORED_PATTERNS.registry_dict + return registry + if backend == BackendType.TORCH_FX: + from nncf.torch.quantization.ignored_patterns import PT_IGNORED_PATTERNS + registry = PT_IGNORED_PATTERNS.registry_dict return registry raise ValueError(f"Ignored patterns not implemented for {backend} backend.") diff --git a/nncf/common/utils/backend.py b/nncf/common/utils/backend.py index 9dcd6a57d71..e5de38e5ca4 100644 --- a/nncf/common/utils/backend.py +++ b/nncf/common/utils/backend.py @@ -20,6 +20,7 @@ class BackendType(Enum): TORCH = "Torch" + TORCH_FX = "TorchFX" TENSORFLOW = "Tensorflow" ONNX = "ONNX" OPENVINO = "OpenVINO" @@ -33,6 +34,7 @@ def get_available_backends() -> List[BackendType]: """ frameworks = [ ("torch", BackendType.TORCH), + ("torch.fx", BackendType.TORCH_FX), ("tensorflow", BackendType.TENSORFLOW), ("onnx", BackendType.ONNX), ("openvino.runtime", BackendType.OPENVINO), @@ -51,14 +53,27 @@ def get_available_backends() -> List[BackendType]: def is_torch_model(model: TModel) -> bool: """ - Returns True if the model is an instance of torch.nn.Module, otherwise False. + Returns True if the model is an instance of torch.nn.Module and not a torch.fx.GraphModule, otherwise False. :param model: A target model. - :return: True if the model is an instance of torch.nn.Module, otherwise False. + :return: True if the model is an instance of torch.nn.Module and not torch.fx.GraphModule, otherwise False. """ import torch + import torch.fx - return isinstance(model, torch.nn.Module) + return not isinstance(model, torch.fx.GraphModule) and isinstance(model, torch.nn.Module) + + +def is_torch_fx_model(model: TModel) -> bool: + """ + Returns True if the model is an instance of torch.fx.GraphModule, otherwise False. + + :param model: A target model. + :return: True if the model is an instance of torch.fx.GraphModule, otherwise False. + """ + import torch.fx + + return isinstance(model, torch.fx.GraphModule) def is_tensorflow_model(model: TModel) -> bool: @@ -118,6 +133,9 @@ def get_backend(model: TModel) -> BackendType: """ available_backends = get_available_backends() + if BackendType.TORCH_FX in available_backends and is_torch_fx_model(model): + return BackendType.TORCH_FX + if BackendType.TORCH in available_backends and is_torch_model(model): return BackendType.TORCH diff --git a/nncf/experimental/torch_fx/__init__.py b/nncf/experimental/torch_fx/__init__.py new file mode 100644 index 00000000000..2e49d63977d --- /dev/null +++ b/nncf/experimental/torch_fx/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nncf/experimental/torch_fx/engine.py b/nncf/experimental/torch_fx/engine.py new file mode 100644 index 00000000000..5f9dc2ac221 --- /dev/null +++ b/nncf/experimental/torch_fx/engine.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Tuple, Union + +import torch +from torch import nn + +from nncf.common.engine import Engine + + +class FXEngine(Engine): + """ + Engine for the Pytorch FX backend. + """ + + def __init__(self, model: nn.Module): + """ + Constructor. + + :param model: Pytorch module to infer. + """ + + self._model = model + + def infer( + self, input_data: Union[torch.Tensor, Tuple[torch.Tensor], Dict[str, torch.Tensor]] + ) -> Union[torch.Tensor, Dict[str, Any]]: + """ + Runs Torch model on the provided input. + + :param input_data: Inputs for the model. + :return: Model outputs. + """ + + if isinstance(input_data, dict): + return self._model(**input_data) + if isinstance(input_data, tuple): + return self._model(*input_data) + return self._model(input_data) diff --git a/nncf/experimental/torch_fx/model_transformer.py b/nncf/experimental/torch_fx/model_transformer.py new file mode 100644 index 00000000000..48b3cf0c1f1 --- /dev/null +++ b/nncf/experimental/torch_fx/model_transformer.py @@ -0,0 +1,183 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
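The backend checks added above distinguish torch.fx.GraphModule instances from plain torch.nn.Module ones, and get_backend() tests the FX case first. A quick sketch of how that dispatch behaves, using symbolic_trace to produce a GraphModule, is shown below; it only exercises the isinstance checks, not the full NNCF pipeline.

import torch
import torch.fx


class Small(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


eager_model = Small()
fx_model = torch.fx.symbolic_trace(eager_model)

# Mirrors is_torch_model() / is_torch_fx_model(): a GraphModule is still an nn.Module,
# so the FX check has to run before the Torch check for the dispatch to be unambiguous.
print(isinstance(eager_model, torch.fx.GraphModule))  # False -> BackendType.TORCH
print(isinstance(fx_model, torch.fx.GraphModule))     # True  -> BackendType.TORCH_FX
print(isinstance(fx_model, torch.nn.Module))          # True, hence the ordering matters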
+ +from collections import defaultdict + +# from functools import partial +from typing import Callable, List, Union + +import torch +import torch.fx +from torch.fx.passes.split_utils import split_by_tags + +from nncf.common.graph.model_transformer import ModelTransformer +from nncf.common.graph.transformations.commands import Command +from nncf.common.graph.transformations.commands import TargetType +from nncf.common.graph.transformations.commands import TransformationPriority +from nncf.common.graph.transformations.commands import TransformationType +from nncf.torch.graph.transformations.commands import PTModelExtractionCommand +from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.graph.transformations.layout import PTTransformationLayout + + +class FXModuleInsertionCommand(Command): + def __init__( + self, + target_points: List[PTTargetPoint], + module_to_insert: torch.nn.Module, + priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY, + ): + super().__init__(TransformationType.INSERT) + self.target_points = target_points + self.module_to_insert = module_to_insert + self.priority = priority + + +class FXApplyTransformationCommand(Command): + def __init__( + self, + transformation_fn: Callable[[torch.fx.GraphModule], None], + priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY, + ): + super().__init__(TransformationType.INSERT) + self.tranformation_fn = transformation_fn + self.priority = priority + + +class FXModelTransformer(ModelTransformer): + """ + Applies transformations upon Torch FX model. + """ + + # TODO: manage priorities of transformations + + def __init__(self, model: torch.fx.GraphModule): + super().__init__(model) + + self._command_transformation_ordered_pairs = [ + # TODO: Move the module insertion command to a transformation + (FXApplyTransformationCommand, self._apply_transformation), + (FXModuleInsertionCommand, self._apply_module_insertion), + (PTModelExtractionCommand, self._apply_model_extraction), + ] + + def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.GraphModule: + transformations = transformation_layout.transformations + aggregated_transformations = defaultdict(list) + for transformation in transformations: + aggregated_transformations[transformation.__class__].append(transformation) + + model = self._model + for transformation_cls, transformation_fn in self._command_transformation_ordered_pairs: + transformations = aggregated_transformations[transformation_cls] + if transformations: + model = transformation_fn(model, transformations) + + # Do not eliminate dead code as + # the dead code is coputing statistics :) + # model.graph.eliminate_dead_code() + model.recompile() + return model + + @staticmethod + def _apply_model_extraction( + model: torch.fx.GraphModule, + transformations: List[PTModelExtractionCommand], + ) -> torch.fx.GraphModule: + transformation = transformations[-1] + assert len(transformation.input_node_names) == 1 + assert transformation.input_node_names == transformation.output_node_names + node_name = transformation.input_node_names[0] + + tags = ["before", "extracted", "after"] + i = 0 + for node in model.graph.nodes: + if node.name == node_name: + node.tag = tags[1] + weights = [node.all_input_nodes[1]] + while weights: + w_node = weights.pop() + assert w_node.tag in tags[0:2] + w_node.tag = tags[1] + weights.extend(w_node.all_input_nodes) + i = 2 + continue + node.tag = tags[i] + + splitted_gm = 
split_by_tags(model, tags) + return splitted_gm.extracted + + @staticmethod + def _apply_module_insertion( + model: torch.fx.GraphModule, + transformations: List[FXModuleInsertionCommand], + ) -> torch.fx.GraphModule: + """ + Applies insertion of PTSharedFnInsertionCommand commands. For each command method inserts + a torch module to the torch.fx.GraphModule and inserts call hooks for each command target points. + + :param model: Model to apply transformations. + :param transformations: List of the bias correction transformations. + :param device: Target device for the insertion functions. Applies only to + functions which are subclassed from torch.nn.Module. Do nothing in case device is None. + :return: A modified torch.fx.GraphModule. + """ + for transformation in transformations: + # Set fn to the model as an attribute + module_to_insert = transformation.module_to_insert + module_name_in_model = ( + ";".join( + "_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value))) + for tp in transformation.target_points + ) + + "_" + + str(id(module_to_insert)) + ) + assert not hasattr(model, module_name_in_model) + setattr(model, module_name_in_model, module_to_insert) + # Insert call_module nodes to the model + for target_point in transformation.target_points: + FXModelTransformer._create_call_module_node(model.graph, target_point, module_name_in_model) + return model + + @staticmethod + def get_graph_node_by_name(graph, name): + for node in graph.nodes: + if node.name == name: + return node + raise RuntimeError(f"Node with name {name} is not found") + + @staticmethod + def _get_target_node(graph: torch.fx.Graph, target_point: PTTargetPoint): + target_type = target_point.target_type + target_node = FXModelTransformer.get_graph_node_by_name(graph, target_point.target_node_name) + if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]: + target_node = target_node.all_input_nodes[target_point.input_port_id] + elif target_type == TargetType.OPERATOR_POST_HOOK: + pass + else: + raise RuntimeError(f"Unsupported target type: {target_type} for target_point: {target_point}") + return target_node + + @staticmethod + def _create_call_module_node(graph: torch.fx.Graph, target_point: PTTargetPoint, module_name: str): + target_node = FXModelTransformer._get_target_node(graph, target_point) + with graph.inserting_after(target_node): + graph.create_node("call_module", module_name, (target_node,), {}, name=module_name + "_graph_node") + + @staticmethod + def _apply_transformation( + model: torch.fx.GraphModule, + transformations: List[FXApplyTransformationCommand], + ) -> torch.fx.GraphModule: + for transformation in transformations: + transformation.tranformation_fn(model) + return model diff --git a/nncf/experimental/torch_fx/nncf_graph_builder.py b/nncf/experimental/torch_fx/nncf_graph_builder.py new file mode 100644 index 00000000000..9990ee3bf2f --- /dev/null +++ b/nncf/experimental/torch_fx/nncf_graph_builder.py @@ -0,0 +1,167 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from itertools import chain +from typing import Tuple + +import torch.fx +from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_ + +import nncf.torch.graph.operator_metatypes as om +from nncf.common.graph import NNCFGraph +from nncf.common.graph import NNCFNode +from nncf.common.graph.layer_attributes import Dtype +from nncf.common.graph.operator_metatypes import UnknownMetatype +from nncf.common.logging import nncf_logger +from nncf.experimental.torch_fx.transformations import separate_conv_and_bias +from nncf.experimental.torch_fx.transformations import separate_linear_and_bias +from nncf.experimental.torch_fx.transformations import view_to_reshape +from nncf.torch.graph.graph import PTNNCFGraph +from nncf.torch.graph.operator_metatypes import PT_OPERATOR_METATYPES + + +class GraphConverter: + """ + Builds the NNCFGraph from an OpenVINO model. + """ + + @staticmethod + def _get_leaf_node(module: torch.nn.Module, node: torch.fx.Node) -> torch.nn.Module: + py_obj = module + assert isinstance(node.target, str) + atoms = node.target.split(".") + for atom in atoms: + if not hasattr(py_obj, atom): + raise RuntimeError(str(py_obj) + " does not have attribute " + atom + "!") + py_obj = getattr(py_obj, atom) + return py_obj + + @staticmethod + def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMetatype]: + if node.op == "placeholder": + node_type = "input" + node_metatype = om.PTInputNoopMetatype + elif node.op == "output": + node_type = "output" + node_metatype = om.PTOutputNoopMetatype + elif node.op == "get_attr": + node_type = "get_attr" + node_metatype = om.PTConstNoopMetatype + elif node.op in ("call_function",): + if hasattr(node.target, "overloadpacket"): + node_type = str(node.target.overloadpacket).split(".")[1] + elif node.target.__name__ == "getitem": + node_type = "__getitem__" + else: + # TODO: get correct nodes types from this nodes as well + node_type = str(node.target) + node_metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type) + else: + node_type = node.op + node_metatype = UnknownMetatype + if node_metatype is UnknownMetatype: + nncf_logger.info(f"Unknown metatype for node: {node}") + return node_type, node_metatype + + @staticmethod + def create_nncf_graph(model: torch.fx.GraphModule) -> NNCFGraph: + """ + Creates NNCFGraph from GraphModule. + All nodes from model which have valid metatype are added to NNCFGraph. + Then, corresponding edges are added to the NNCFGraph with shape, type, output and input port ids. + + :param model: torch fx GraphModule. + :return: NNCFGraph. + """ + + _fuse_conv_bn_(model) + # BN fuses to conv bias, conv+bias joined op + # needs to be splited for nncf + separate_linear_and_bias(model) + separate_conv_and_bias(model) + view_to_reshape(model) + + nncf_graph = PTNNCFGraph() + + for source_node in model.graph.nodes: + + node_type, node_metatype = GraphConverter._get_node_type_and_metatype(source_node) + + nncf_node = nncf_graph.add_nncf_node( + node_name=source_node.name, + node_type=node_type, + node_metatype=node_metatype, # layer_attributes, + ) + + def get_module_params_or_buffers(): + for pname, ptensor in chain(leaf_module.named_parameters(), leaf_module.named_buffers()): + pname1 = source_node.name + "." 
+ pname + nncf_param_node = nncf_graph.add_nncf_node( + pname1, + "parameter" if isinstance(ptensor, torch.nn.Parameter) else "buffer", + om.PTConstNoopMetatype, + ) + # TODO: Use valid tensor_shape, input_port_id, output_port_id + nncf_graph.add_edge_between_nncf_nodes( + nncf_param_node, nncf_node, tensor_shape=[1, 1, 1, 1], input_port_id=0, output_port_id=0 + ) + + if source_node.op == "call_module": + leaf_module = GraphConverter._get_leaf_node(model, source_node) + + if not isinstance(leaf_module, torch.fx.GraphModule): + get_module_params_or_buffers() + + for source_node in model.graph.nodes: + + source_nncf_node = nncf_graph.get_node_by_name(source_node.name) + for idx, dist_node in enumerate(source_node.users): + dist_node_id = nncf_graph.get_node_by_name(dist_node.name).node_id + input_port_id, output_port_id, tensor_shape = GraphConverter.get_edge_params( + model, source_node, source_nncf_node, dist_node, idx + ) + + nncf_graph.add_edge_between_nncf_nodes( + source_nncf_node.node_id, + dist_node_id, + tensor_shape=tensor_shape, + input_port_id=input_port_id, + output_port_id=output_port_id, + dtype=Dtype.FLOAT, + ) + + return nncf_graph + + @staticmethod + def get_edge_params( + model, source_node: torch.fx.Node, source_nncf_node: NNCFNode, dist_node: torch.fx.Node, output_idx: int + ): + output_port_id = 0 + if source_node.op in ("get_attr",): + tensor_shape = tuple(getattr(model, source_node.target).shape) + elif "val" in source_node.meta: + if source_nncf_node.metatype is om.PTBatchNormMetatype: + tensor = source_node.meta["val"][0] + elif source_nncf_node.metatype is om.PTSplitMetatype: + tensor = source_node.meta["val"][output_idx] + # Assume every split outputs corresponds to an unique output_port_id + output_port_id = output_idx + else: + tensor = source_node.meta["val"] + tensor_shape = tuple(tensor.shape) + else: + nncf_logger.info( + f"Edge shape between {source_node.name} and {dist_node.name} is unknown. Using [1,1,1,1] instead." + ) + tensor_shape = [1, 1, 1, 1] + + input_port_id = dist_node.all_input_nodes.index(source_node) + return input_port_id, output_port_id, tensor_shape diff --git a/nncf/experimental/torch_fx/quantization/__init__.py b/nncf/experimental/torch_fx/quantization/__init__.py new file mode 100644 index 00000000000..2e49d63977d --- /dev/null +++ b/nncf/experimental/torch_fx/quantization/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nncf/experimental/torch_fx/quantization/quantize_model.py b/nncf/experimental/torch_fx/quantization/quantize_model.py new file mode 100644 index 00000000000..0f40800fb49 --- /dev/null +++ b/nncf/experimental/torch_fx/quantization/quantize_model.py @@ -0,0 +1,105 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import Optional + +import torch +import torch.fx +from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass +from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ +from torch.ao.quantization.pt2e.qat_utils import _fold_conv_bn_qat +from torch.ao.quantization.pt2e.utils import _disallow_eval_train +from torch.fx import GraphModule +from torch.fx.passes.infra.pass_manager import PassManager + +import nncf +from nncf.common.factory import NNCFGraphFactory +from nncf.common.quantization.structs import QuantizationPreset +from nncf.common.quantization.structs import QuantizationScheme +from nncf.data import Dataset +from nncf.experimental.torch_fx.transformations import merge_conv_and_bias +from nncf.parameters import ModelType +from nncf.parameters import QuantizationMode +from nncf.parameters import TargetDevice +from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters +from nncf.quantization.advanced_parameters import QuantizationParameters +from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization +from nncf.scopes import IgnoredScope + +DEFAULT_RANGE_TYPE = "mean_min_max" + + +def quantize_impl( + model: torch.fx.GraphModule, + calibration_dataset: Dataset, + mode: Optional[QuantizationMode] = None, + preset: Optional[QuantizationPreset] = None, + target_device: TargetDevice = TargetDevice.ANY, + subset_size: int = 300, + fast_bias_correction: bool = True, + model_type: Optional[ModelType] = None, + ignored_scope: Optional[IgnoredScope] = None, + advanced_parameters: Optional[AdvancedQuantizationParameters] = None, +) -> torch.nn.Module: + """ + Implementation of the `quantize()` method for the Torch FX backend. 
+ """ + if fast_bias_correction is False: + raise ValueError(f"fast_bias_correction={fast_bias_correction} is not supported") + if target_device == TargetDevice.CPU_SPR: + raise nncf.InternalError("target_device == CPU_SPR is not supported") + if mode is not None: + raise ValueError(f"mode={mode} is not supported") + + original_graph_meta = model.meta + + copied_model = deepcopy(model) + + if advanced_parameters is None: + advanced_parameters = AdvancedQuantizationParameters() + # torch.fx supports only assymetric activations quantization + # force to use only this type of quantization + activations_quantization_params = advanced_parameters.activations_quantization_params + if activations_quantization_params is None: + activations_quantization_params = QuantizationParameters() + + activations_quantization_params.mode = QuantizationScheme.ASYMMETRIC + advanced_parameters.activations_quantization_params = activations_quantization_params + + quantization_algorithm = PostTrainingQuantization( + preset=preset, + target_device=target_device, + subset_size=subset_size, + fast_bias_correction=fast_bias_correction, + model_type=model_type, + ignored_scope=ignored_scope, + advanced_parameters=advanced_parameters, + ) + nncf_graph = NNCFGraphFactory.create(copied_model) + quantized_model = quantization_algorithm.apply(copied_model, nncf_graph, dataset=calibration_dataset) + merge_conv_and_bias(quantized_model) + + # Magic. Without this call compiled model + # is not preformant + quantized_model = GraphModule(quantized_model, quantized_model.graph) + + quantized_model = _fold_conv_bn_qat(quantized_model) + pm = PassManager([DuplicateDQPass()]) + + quantized_model = pm(quantized_model).graph_module + pm = PassManager([PortNodeMetaForQDQ()]) + quantized_model = pm(quantized_model).graph_module + + quantized_model.meta.update(original_graph_meta) + quantized_model = _disallow_eval_train(quantized_model) + + return quantized_model diff --git a/nncf/experimental/torch_fx/statistics/__init__.py b/nncf/experimental/torch_fx/statistics/__init__.py new file mode 100644 index 00000000000..2e49d63977d --- /dev/null +++ b/nncf/experimental/torch_fx/statistics/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nncf/experimental/torch_fx/statistics/aggregator.py b/nncf/experimental/torch_fx/statistics/aggregator.py new file mode 100644 index 00000000000..1497b41aa44 --- /dev/null +++ b/nncf/experimental/torch_fx/statistics/aggregator.py @@ -0,0 +1,101 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +import numpy as np +import torch + +from nncf.common.factory import TModel +from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.transformations.commands import TransformationPriority +from nncf.common.graph.transformations.layout import TransformationLayout +from nncf.common.tensor_statistics.aggregator import StatisticPointsContainer +from nncf.common.tensor_statistics.aggregator import StatisticsAggregator +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.torch_fx.model_transformer import FXModuleInsertionCommand +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.return_types import maybe_get_values_from_torch_return_type +from nncf.torch.tensor import PTNNCFTensor + + +class TensorCollectorModule(torch.nn.Module): + """ + torch.nn.Module which calls given collector in forward + """ + + def __init__(self, collector: TensorCollector): + super().__init__() + self._collector = collector + + def forward(self, x: torch.Tensor): + """ + Register inputs hook function. + + :parameter x: tensor to register in hook. + :return: tensor to register in hook. + """ + x_unwrapped = maybe_get_values_from_torch_return_type(x) + self._collector.register_input_for_all_reducers(PTNNCFTensor(x_unwrapped)) + return x + + +class FXStatisticsAggregator(StatisticsAggregator): + HOOKS_GROUP_NAME = "statistics_hooks" + + def collect_statistics(self, model: NNCFNetwork, graph: NNCFGraph) -> None: + with torch.no_grad(): + super().collect_statistics(model, graph) + # All statistics are collected as a dead code, + # so eliminate dead core removed statistcs collector + # from the target model. No additional code required + # for that, horay! 
+ model.graph.eliminate_dead_code() + model.recompile() + + def _register_statistics( + self, outputs: Dict[str, PTNNCFTensor], statistic_points: StatisticPointsContainer + ) -> None: + return + + def _get_transformation_layout_extra_outputs( + self, statistic_points: StatisticPointsContainer + ) -> TransformationLayout: + transformation_layout = TransformationLayout() + transformation_commands = [] + + for _statistic_points in statistic_points.values(): + for _statistic_point in _statistic_points: + for collectors in _statistic_point.algorithm_to_tensor_collectors.values(): + for collector in collectors: + transformation_commands.append( + FXModuleInsertionCommand( + [_statistic_point.target_point], + TensorCollectorModule(collector), + TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION, + ) + ) + + for transformation_command in transformation_commands: + transformation_layout.register(transformation_command) + + return transformation_layout + + @staticmethod + def _get_merged_statistic_points( + statistic_points: StatisticPointsContainer, model: TModel, graph: NNCFGraph + ) -> StatisticPointsContainer: + # TODO: mirgate to experimental statistic collector and use common merging algorithm + return statistic_points + + @staticmethod + def _process_outputs(outputs: Dict[str, np.ndarray]) -> Dict[str, PTNNCFTensor]: + return outputs diff --git a/nncf/experimental/torch_fx/transformations.py b/nncf/experimental/torch_fx/transformations.py new file mode 100644 index 00000000000..d572c06b120 --- /dev/null +++ b/nncf/experimental/torch_fx/transformations.py @@ -0,0 +1,350 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
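The FXStatisticsAggregator above relies on a property of torch.fx code generation: a call_module node is executed even when nothing consumes its output, so each TensorCollector can be wrapped in a module, attached as a "dead" observer node for calibration, and later stripped with graph.eliminate_dead_code(). A minimal sketch of that mechanism, using an illustrative recorder module and toy model that are assumptions for this example, not part of the patch:

import torch
import torch.fx


class Recorder(torch.nn.Module):
    # Stand-in for TensorCollectorModule: records the activation and returns it unchanged.
    def __init__(self):
        super().__init__()
        self.seen = []

    def forward(self, x):
        self.seen.append(x.detach().clone())
        return x


class TwoLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.fc(x))


gm = torch.fx.symbolic_trace(TwoLayer())
recorder = Recorder()
setattr(gm, "fc_observer", recorder)

# Insert a call_module node right after the linear layer; its output has no users,
# so from the graph's point of view it is dead code, yet the generated forward still runs it.
for node in gm.graph.nodes:
    if node.op == "call_module" and node.target == "fc":
        with gm.graph.inserting_after(node):
            gm.graph.create_node("call_module", "fc_observer", (node,), {}, name="fc_observer_node")
        break
gm.recompile()

gm(torch.randn(2, 4))            # calibration pass: the recorder sees the linear output
assert len(recorder.seen) == 1

gm.graph.eliminate_dead_code()   # drops the observer node once statistics are collected
gm.recompile()

The same behaviour explains why FXModelTransformer.transform deliberately skips eliminate_dead_code(): running it there would also strip the statistic-collecting nodes before calibration has happened.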
+ +import math +from typing import Callable, List, Optional + +import torch +import torch.fx +from torch.ao.quantization.fx.utils import create_getattr_from_value +from torch.ao.quantization.pt2e.utils import _get_tensor_constant_from_node +from torch.ao.quantization.pt2e.utils import _is_conv +from torch.quantization.fake_quantize import FakeQuantize + +from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.transformations.commands import TargetType +from nncf.experimental.torch_fx.model_transformer import FXModelTransformer +from nncf.torch.graph.transformations.commands import PTTargetPoint + + +def fake_quantize_insertion_tranformation_builder(quantizer: FakeQuantize, target_points: List[PTTargetPoint]): + def fake_quantize_insertion_transformation(model: torch.fx.GraphModule): + module_attr_name = _set_module_to_the_graph_module(model, quantizer, target_points) + graph = model.graph + for target_point in target_points: + target_node = FXModelTransformer._get_target_node(model.graph, target_point) + with graph.inserting_after(target_node): + fq_node = graph.create_node( + "call_module", module_attr_name, (target_node,), {}, name=module_attr_name + "_quantizer" + ) + for user in list(target_node.users): + if user is fq_node: + continue + user.replace_input_with(target_node, fq_node) + + return fake_quantize_insertion_transformation + + +def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor): + def bias_update_transformation(model: torch.fx.GraphModule): + graph = model.graph + target_node_name = node.node_name + graph_node = FXModelTransformer.get_graph_node_by_name(graph, target_node_name) + bias_node = next(iter(graph_node.users)) + with graph.inserting_before(bias_node): + new_constant = create_getattr_from_value(model, graph, target_node_name + "_shifted_bias", value) + args = list(bias_node.args) + args[1] = new_constant + bias_node.args = tuple(args) + graph.eliminate_dead_code() + + return bias_update_transformation + + +def qdq_insertion_tranformation_builder(quantizer: FakeQuantize, target_points: List[PTTargetPoint]): + def qdq_insertion_tranformation(model: torch.fx.GraphModule): + if any(tp.target_type != TargetType.OPERATION_WITH_WEIGHTS for tp in target_points) and len(target_points) > 1: + raise RuntimeError + for target_point in target_points: + target_node = FXModelTransformer._get_target_node(model.graph, target_point) + insert_one_qdq(model, target_node, quantizer, target_point) + + return qdq_insertion_tranformation + + +def insert_one_qdq( + model: torch.fx.GraphModule, target_node: torch.fx.Node, quantizer: FakeQuantize, target_point: PTTargetPoint +): + # Copied from torch.ao.quantization.quantize_pt2e.convert_pt2e + # 1. 
extract information for inserting q/dq node from activation_post_process + node_type = "call_function" + quantize_op: Optional[Callable] = None + # scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator] + dtype = torch.int8 if quantizer.quant_min < 0 else torch.uint8 + if quantizer.is_per_channel: + qparams = { + "_scale_": quantizer.scale, + "_zero_point_": quantizer.zero_point, + "_axis_": quantizer.ch_axis, + "_quant_min_": quantizer.quant_min, + "_quant_max_": quantizer.quant_max, + "_dtype_": dtype, + } + quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default + else: + qparams = { + "_scale_": float(quantizer.scale), + "_zero_point_": int(quantizer.zero_point), + "_quant_min_": quantizer.quant_min, + "_quant_max_": quantizer.quant_max, + "_dtype_": dtype, + } + quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default + + # 2. replace activation_post_process node with quantize and dequantize + graph = model.graph + # TODO: use metatype to get correct input_port_id + # Do not quantize already quantized nodes + # inserting_before handle only order in the graph generated code. + # so, inserting quantize-dequantize and all constant nodes before the usage of the nodes + with graph.inserting_before(target_node): + quantize_op_inputs = [target_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ["_scale_", "_zero_point_"] and (not isinstance(value_or_node, (float, int))): + # For scale and zero_point values we register them as buffers in the root module. + # However, note that when the values are not tensors, as in the case of + # per_tensor quantization, they will be treated as literals. + # However, registering them as a node seems to cause issue with dynamo + # tracing where it may consider tensor overload as opposed to default. + # With extra check of scale and zero_point being scalar, it makes + # sure that the default overload can be used. + # TODO: maybe need more complex attr name here + qparam_node = create_getattr_from_value(model, graph, target_node.name + key, value_or_node) + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we store + # them as literals in the graph. + quantize_op_inputs.append(value_or_node) + with graph.inserting_after(target_node): + quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) + # use the same qparams from quantize op + dq_inputs = [quantized_node] + quantize_op_inputs[1:] + user_dq_nodes = [] + with graph.inserting_after(quantized_node): + for user in target_node.users: + if user is quantized_node: + continue + user_dq_nodes.append((user, graph.call_function(dequantize_op, tuple(dq_inputs), {}))) + + for user, dq_node in user_dq_nodes: + user.replace_input_with(target_node, dq_node) + + +def _set_module_to_the_graph_module( + model: torch.fx.GraphModule, module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint] +) -> str: + """ + Sets given module to the given torch.fx.GraphModule with unique name. 
+ """ + module_to_insert = module_to_insert + module_name_in_model = ( + ";".join( + "_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value))) for tp in target_points + ) + + "_" + + str(id(module_to_insert)) + ) + assert not hasattr(model, module_name_in_model) + setattr(model, module_name_in_model, module_to_insert) + return module_name_in_model + + +def _is_linear(n: torch.fx.Node): + return n.op == "call_function" and n.target in [torch.ops.aten.linear.default] + + +def separate_linear_and_bias(model: torch.fx.GraphModule): + """ + Separates one joined linear+bias node to two nodes: conv and bias. + Needed as nncf does not expect joined conv + """ + add_node_target = torch.ops.aten.add_.Tensor + for n in model.graph.nodes: + if not _is_linear(n): + continue + if len(n.args) < 3 or n.args[2] is None: + continue + linear_node = n + linear_bias_node = linear_node.args[2] + conv_bias_value = _get_tensor_constant_from_node(linear_bias_node, model) + args = list(n.args) + args[2] = None + linear_node.args = tuple(args) + with model.graph.inserting_after(linear_node): + new_linear_bias_node = create_getattr_from_value( + model, + model.graph, + linear_bias_node.name + "_", + conv_bias_value, + ) + with model.graph.inserting_after(new_linear_bias_node): + add_node = model.graph.create_node( + "call_function", add_node_target, (linear_node, new_linear_bias_node), {} + ) + for user in list(linear_node.users): + if user is add_node: + continue + user.replace_input_with(linear_node, add_node) + if "val" in linear_node.meta: + add_node.meta["val"] = linear_node.meta["val"] + model.graph.eliminate_dead_code() + model.recompile() + + +def view_to_reshape(model: torch.fx.GraphModule): + for n in model.graph.nodes: + if not (n.op == "call_function" and n.target in [torch.ops.aten.view.default]): + continue + with model.graph.inserting_after(n): + reshape = model.graph.create_node("call_function", torch.ops.aten.reshape.default, tuple(n.args), {}) + reshape.meta = n.meta + + for user in list(n.users): + user.replace_input_with(n, reshape) + + model.graph.eliminate_dead_code() + model.recompile() + + +def separate_conv_and_bias(model: torch.fx.GraphModule): + """ + Separates one joined conv+bias node to two nodes: conv and bias. + Needed as nncf does not expect joined conv + """ + add_node_target = torch.ops.aten.add_.Tensor + for n in model.graph.nodes: + if not _is_conv(n): + continue + if len(n.args) < 3 or n.args[2] is None: + continue + conv_node = n + dims = len(_get_tensor_constant_from_node(conv_node.args[1], model).shape) + conv_bias_node = conv_node.args[2] + conv_bias_value = _get_tensor_constant_from_node(conv_bias_node, model) + args = list(n.args) + args[2] = None + conv_node.args = tuple(args) + with model.graph.inserting_after(conv_node): + new_conv_bias_node = create_getattr_from_value( + model, + model.graph, + conv_bias_node.name + "_", + conv_bias_value.reshape( + ( + 1, + -1, + ) + + (1,) * (dims - 2) + ), + ) + with model.graph.inserting_after(new_conv_bias_node): + add_node = model.graph.create_node("call_function", add_node_target, (conv_node, new_conv_bias_node), {}) + for user in list(conv_node.users): + if user is add_node: + continue + user.replace_input_with(conv_node, add_node) + + if "val" in conv_node.meta: + add_node.meta["val"] = conv_node.meta["val"] + model.graph.eliminate_dead_code() + model.recompile() + + +def merge_conv_and_bias(model: torch.fx.GraphModule): + """ + Separates one joined conv+bias node to two nodes: conv and bias. 
+ Needed as nncf does not expect joined conv + """ + add_node_targets = (torch.ops.aten.add_.Tensor,) + for n in model.graph.nodes: + if not _is_conv(n): + continue + if len(n.args) > 2 and n.args[2] is not None: + continue + bias_node = next(iter(n.users)) + if len(n.users) > 1 or bias_node.target not in add_node_targets: + continue + conv_node = n + const_node = None + for node in bias_node.all_input_nodes: + if node is not conv_node: + const_node = node + break + assert const_node is not None + bias_value = _get_tensor_constant_from_node(const_node, model).squeeze() + with model.graph.inserting_before(conv_node): + new_bias_node = create_getattr_from_value(model, model.graph, const_node.name + "_", bias_value) + args = list(conv_node.args) + args[2] = new_bias_node + conv_node.args = tuple(args) + for user in list(bias_node.users): + user.replace_input_with(bias_node, conv_node) + + model.graph.eliminate_dead_code() + model.recompile() + + +def _is_scaled_dot_product_attention(n: torch.fx.Node): + return n.op == "call_function" and n.target in [torch.ops.aten.scaled_dot_product_attention.default] + + +def _unfold_sdp(model: torch.fx.GraphModule, node: torch.fx.Node): + transpose_target = torch.ops.aten.transpose.int + matmul_target = torch.ops.aten.matmul.default + mul_target = torch.ops.aten.multiply.Scalar + softmax_target = torch.ops.aten.softmax.int + + query, key, value = node.args + q, k, v = (n.meta["val"] for n in node.args) + n = query.meta["val"].shape[-1] + scale_factor = 1 / math.sqrt(n) + + with model.graph.inserting_before(node): + k_transposed = model.graph.create_node("call_function", transpose_target, (key, -2, -1), {}) + k = k.transpose(-2, -1) + k_transposed.meta["val"] = torch.clone(k) + + sa = model.graph.create_node("call_function", matmul_target, (query, k_transposed), {}) + attn_value = q @ k + sa.meta["val"] = torch.clone(attn_value) + + sa_scaled = model.graph.create_node("call_function", mul_target, (sa, float(scale_factor)), {}) + sa_scaled.meta["val"] = torch.clone(attn_value) + + softmax = model.graph.create_node("call_function", softmax_target, (sa_scaled, -1), {}) + softmax.meta["val"] = torch.clone(attn_value) + + result = model.graph.create_node("call_function", matmul_target, (softmax, value), {}) + r = attn_value @ v + result.meta["val"] = torch.clone(r) + + for user in list(node.users): + user.replace_input_with(node, result) + model.graph.eliminate_dead_code() + + +@staticmethod +def unfold_scaled_dot_product_attention(model: torch.fx.GraphModule): + for n in model.graph.nodes: + if not _is_scaled_dot_product_attention(n): + continue + args = n.args + if len(args) > 3: + raise NotImplementedError( + f"Unfolding of scaled dot product attention node {n}" " with more than 3 inputs is not implemented yet" + ) + _unfold_sdp(model, n) + model.graph.eliminate_dead_code() + model.recompile() diff --git a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py index f29d41a3e7c..6b3115a46df 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py +++ b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py @@ -93,7 +93,7 @@ def __init__( @property def available_backends(self) -> List[BackendType]: - return [BackendType.ONNX, BackendType.OPENVINO, BackendType.TORCH] + return [BackendType.ONNX, BackendType.OPENVINO, BackendType.TORCH, BackendType.TORCH_FX] def _set_backend_entity(self, model: TModel) -> None: """ @@ -116,6 +116,12 @@ def _set_backend_entity(self, 
model: TModel) -> None: from nncf.quantization.algorithms.fast_bias_correction.torch_backend import PTFastBiasCorrectionAlgoBackend self._backend_entity = PTFastBiasCorrectionAlgoBackend() + elif model_backend == BackendType.TORCH_FX: + from nncf.quantization.algorithms.fast_bias_correction.torch_fx_backend import ( + FXFastBiasCorrectionAlgoBackend, + ) + + self._backend_entity = FXFastBiasCorrectionAlgoBackend() else: raise nncf.UnsupportedBackendError( "Cannot return backend-specific entity because {} is not supported!".format(model_backend.value) diff --git a/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py b/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py new file mode 100644 index 00000000000..c42fc5f3c7b --- /dev/null +++ b/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py @@ -0,0 +1,116 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.fx +from torch.ao.quantization.pt2e.utils import _get_tensor_constant_from_node + +import nncf.torch.graph.operator_metatypes as om +from nncf.common.graph import NNCFGraph +from nncf.common.graph import NNCFNode +from nncf.common.graph.definitions import NNCFGraphNodeType +from nncf.common.graph.transformations.commands import TargetType +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.tensor import Tensor +from nncf.experimental.torch_fx.model_transformer import FXApplyTransformationCommand +from nncf.experimental.torch_fx.transformations import bias_update_transformation_builder +from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend +from nncf.torch.graph.transformations.commands import PTModelExtractionCommand +from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.tensor_statistics.collectors import get_mean_statistic_collector + + +class FXFastBiasCorrectionAlgoBackend(FastBiasCorrectionAlgoBackend): + TARGET_TYPE_TO_PT_INS_TYPE_MAP = { + TargetType.PRE_LAYER_OPERATION: TargetType.OPERATOR_PRE_HOOK, + TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK, + } + + @staticmethod + def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: + if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION: + port_id = None + if target_type in FXFastBiasCorrectionAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP: + target_type = FXFastBiasCorrectionAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP[target_type] + return PTTargetPoint(target_type, target_node_name, input_port_id=port_id) + + @staticmethod + def create_bias_correction_command( + node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph + ) -> FXApplyTransformationCommand: + return FXApplyTransformationCommand(bias_update_transformation_builder(node, 
bias_value.data)) + + @staticmethod + def model_extraction_command( + input_ids: List[Tuple[str, int]], output_ids: List[Tuple[str, int]] + ) -> PTModelExtractionCommand: + return PTModelExtractionCommand([input_ids[0][0]], [output_ids[0][0]]) + + @staticmethod + def mean_statistic_collector( + channel_axis: int, + inplace: bool, + num_samples: Optional[int] = None, + window_size: Optional[int] = None, + ) -> TensorCollector: + return get_mean_statistic_collector(num_samples, channel_axis, window_size) + + @staticmethod + def get_sub_input_output_names(subgraph: NNCFNetwork) -> Tuple[str, str]: + # Pytorch does not have name for extracted node + return None, None + + @staticmethod + def create_input_data(shape: Tuple[int], data: List[Tensor], input_name: str, channel_axis: int) -> torch.Tensor: + blob = torch.zeros(shape, dtype=data[0].data.dtype, device=data[0].data.device) + for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])): + index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim)) + blob[index] = data[j].data + return blob + + @staticmethod + def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: torch.fx.GraphModule) -> Tensor: + # TODO: make a node_name_vs_node map to speed up the process + from nncf.experimental.torch_fx.model_transformer import FXModelTransformer + + bias_node = nncf_graph.get_next_nodes(node)[0] + graph_bias_node = FXModelTransformer.get_graph_node_by_name(model.graph, bias_node.node_name) + return Tensor(_get_tensor_constant_from_node(graph_bias_node.all_input_nodes[1], model)) + + @staticmethod + def get_activation_port_ids_for_bias_node(node: NNCFNode) -> Tuple[int, int]: + return 0, 0 + + @staticmethod + def process_model_output(raw_data: Dict, output_name: str) -> Tensor: + return Tensor(raw_data) + + @staticmethod + def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: + weight_node = nncf_graph.get_previous_nodes(node)[1] + return weight_node.node_type == "dequantize_per_channel" + + @staticmethod + def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: + # Assumes that all biases were unfused + if node.metatype in (om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype, om.PTLinearMetatype): + next_nodes = nncf_graph.get_next_nodes(node) + if len(next_nodes) != 1: + return False + return next_nodes[0].metatype in (om.PTAddMetatype,) + + @staticmethod + def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[str, str]: + return node.node_name, node.node_name diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index 8ceeaa3902d..2472fe5196e 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -318,7 +318,7 @@ def _reset_cache(self): @property def available_backends(self) -> List[BackendType]: - return [BackendType.ONNX, BackendType.OPENVINO, BackendType.TORCH] + return [BackendType.ONNX, BackendType.OPENVINO, BackendType.TORCH, BackendType.TORCH_FX] def _get_quantizer_constraints( self, @@ -371,6 +371,10 @@ def _set_backend_entity(self, model: TModel) -> None: from nncf.quantization.algorithms.min_max.openvino_backend import OVMinMaxAlgoBackend self._backend_entity = OVMinMaxAlgoBackend() + elif model_backend == BackendType.TORCH_FX: + from nncf.quantization.algorithms.min_max.torch_fx_backend import FXMinMaxAlgoBackend + + self._backend_entity = FXMinMaxAlgoBackend() elif model_backend == 
BackendType.TORCH: from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend diff --git a/nncf/quantization/algorithms/min_max/torch_fx_backend.py b/nncf/quantization/algorithms/min_max/torch_fx_backend.py new file mode 100644 index 00000000000..053f6cb923f --- /dev/null +++ b/nncf/quantization/algorithms/min_max/torch_fx_backend.py @@ -0,0 +1,363 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Set, Tuple + +import torch + +import nncf +import nncf.torch.graph.operator_metatypes as om +from nncf.common.graph.definitions import NNCFGraphNodeType +from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.operator_metatypes import OperatorMetatype +from nncf.common.graph.transformations.commands import TargetType +from nncf.common.graph.transformations.commands import TransformationCommand +from nncf.common.hardware.config import HWConfig +from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode +from nncf.common.quantization.structs import QuantizerConfig +from nncf.experimental.common.tensor_statistics.collectors import AGGREGATORS_MAP +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.torch_fx.model_transformer import FXApplyTransformationCommand +from nncf.experimental.torch_fx.transformations import qdq_insertion_tranformation_builder +from nncf.parameters import ModelType +from nncf.parameters import TargetDevice +from nncf.quantization.advanced_parameters import StatisticsType +from nncf.quantization.algorithms.min_max.backend import MinMaxAlgoBackend +from nncf.quantization.fake_quantize import FakeConvertParameters +from nncf.quantization.fake_quantize import FakeQuantizeParameters +from nncf.quantization.range_estimator import AggregatorType +from nncf.quantization.range_estimator import RangeEstimatorParameters +from nncf.torch.graph.graph import PTNNCFGraph +from nncf.torch.graph.graph import PTTargetPoint +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand +from nncf.torch.hardware.config import PTHWConfig +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT +from nncf.torch.quantization.layers import QUANTIZATION_MODULES +from nncf.torch.quantization.layers import AsymmetricQuantizer +from nncf.torch.quantization.layers import BaseQuantizer +from nncf.torch.quantization.layers import PTQuantizerSpec +from nncf.torch.quantization.layers import get_scale_shape +from nncf.torch.quantization.strip import convert_to_torch_fakequantizer +from nncf.torch.tensor_statistics.collectors import PT_REDUCERS_MAP +from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor +from nncf.torch.tensor_statistics.statistics import PTMinMaxTensorStatistic + + +class FXMinMaxAlgoBackend(MinMaxAlgoBackend): + 
TARGET_TYPE_TO_PT_INS_TYPE_MAP = { + TargetType.PRE_LAYER_OPERATION: TargetType.OPERATOR_PRE_HOOK, + TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK, + } + + @property + def mat_mul_metatypes(self) -> List[OperatorMetatype]: + return [om.PTLinearMetatype, om.PTMatMulMetatype] + + @property + def post_processing_metatypes(self) -> List[OperatorMetatype]: + return [] + + @property + def shapeof_metatypes(self) -> List[OperatorMetatype]: + return [] + + @property + def dropout_metatypes(self) -> List[OperatorMetatype]: + return [om.PTDropoutMetatype] + + @property + def read_variable_metatypes(self) -> List[OperatorMetatype]: + return [] + + @property + def conv_metatypes(self) -> List[OperatorMetatype]: + return [om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype] + + @property + def overflow_fix_metatypes(self) -> List[OperatorMetatype]: + return [ + om.PTConv1dMetatype, + om.PTConv2dMetatype, + om.PTConv3dMetatype, + om.PTLinearMetatype, + om.PTConvTranspose1dMetatype, + om.PTConvTranspose2dMetatype, + om.PTConvTranspose3dMetatype, + ] + + @property + def add_metatypes(self) -> List[OperatorMetatype]: + return [om.PTAddMetatype] + + @property + def group_conv_metatypes(self) -> List[OperatorMetatype]: + return self.conv_metatypes + + @property + def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]: + return [] + + @property + def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]: + return {om.PTCatMetatype: self.overflow_fix_metatypes} + + @property + def hw_config(self) -> HWConfig: + return PTHWConfig + + @property + def quant_trait_op_dict(self) -> Dict[int, OperatorMetatype]: + return DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT + + @staticmethod + def get_start_nodes_for_activation_path_tracing(nncf_graph: PTNNCFGraph) -> List[NNCFNode]: + return nncf_graph.get_input_nodes() + + @staticmethod + def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: + if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION: + port_id = None + if target_type in FXMinMaxAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP: + target_type = FXMinMaxAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP[target_type] + return PTTargetPoint(target_type, target_node_name, input_port_id=port_id) + + @staticmethod + def create_convert_insertion_command( + target_point: PTTargetPoint, + parameters: FakeConvertParameters, + ) -> TransformationCommand: + raise nncf.InternalError("FakeConvert insertion not implemented in PyTorch backend!") + + @staticmethod + def unify_statistics(statistics: List[PTMinMaxTensorStatistic]) -> PTMinMaxTensorStatistic: + max_values, min_values = [], [] + for statistic in statistics: + max_values.append(statistic.max_values.flatten()) + min_values.append(statistic.min_values.flatten()) + max_values = torch.amax(torch.stack(max_values), dim=0) + min_values = torch.amin(torch.stack(min_values), dim=0) + return PTMinMaxTensorStatistic(min_values=min_values, max_values=max_values) + + @staticmethod + def get_target_point_shape(nncf_graph: PTNNCFGraph, node: NNCFNode, target_point: PTTargetPoint) -> Tuple[int, ...]: + return nncf_graph.get_input_shape_for_insertion_point(target_point) + + @staticmethod + def get_weight_quantization_axes(node: NNCFNode, target_point: PTTargetPoint) -> Tuple[int]: + # TODO: support transpose conv and other cases + return (0,) + + @staticmethod + def get_statistic_collector( + range_estimator_params: RangeEstimatorParameters, + 
use_abs_max: bool, + reduction_axes: Optional[Tuple[int, ...]], + aggregation_axes: Optional[Tuple[int, ...]], + inplace: bool, + num_samples: Optional[int] = None, + ) -> TensorCollector: + collector = TensorCollector(PTMinMaxTensorStatistic) + for params, container_key in zip( + [range_estimator_params.min, range_estimator_params.max], + [PTMinMaxTensorStatistic.MIN_STAT, PTMinMaxTensorStatistic.MAX_STAT], + ): + if params.statistics_type not in PT_REDUCERS_MAP: + raise nncf.InternalError( + f"Statistic type: {params.statistics_type} is not supported for Torch PTQ backend yet." + ) + + if params.aggregator_type not in AGGREGATORS_MAP: + raise nncf.InternalError( + f"Aggregator type: {params.aggregator_type} is not supported for Torch PTQ backend yet." + ) + + statistic_type = params.statistics_type + if statistic_type in [StatisticsType.QUANTILE, StatisticsType.ABS_QUANTILE]: + # TODO(dlyakhov): merge two quantile aggregators in one + if container_key == PTMinMaxTensorStatistic.MIN_STAT: + quantile = params.quantile_outlier_prob + else: + quantile = 1 - params.quantile_outlier_prob + reducer = PT_REDUCERS_MAP[statistic_type](reduction_axes=reduction_axes, quantile=[quantile]) + else: + if use_abs_max and statistic_type == StatisticsType.MAX: + statistic_type = StatisticsType.ABS_MAX + reducer = PT_REDUCERS_MAP[statistic_type](reduction_axes=reduction_axes) + + kwargs = { + "num_samples": num_samples, + "aggregation_axes": aggregation_axes, + "tensor_processor": PTNNCFCollectorTensorProcessor, + } + if params.aggregator_type in [AggregatorType.MEAN_NO_OUTLIERS, AggregatorType.MEDIAN_NO_OUTLIERS]: + kwargs.update({"quantile": params.quantile_outlier_prob}) + aggregator = AGGREGATORS_MAP[params.aggregator_type](**kwargs) + + collector.register_statistic_branch(container_key, reducer, aggregator) + return collector + + @staticmethod + def get_weight_tensor_port_ids(node: NNCFNode) -> List[Optional[int]]: + return node.metatype.weight_port_ids + + @staticmethod + def get_weight_name(nncf_graph: NNCFGraph, target_point: PTTargetPoint) -> str: + weighted_node = nncf_graph.get_node_by_name(target_point.target_node_name) + weight = nncf_graph.get_previous_nodes(weighted_node)[target_point.input_port_id] + return weight.node_name + + @staticmethod + def should_quantize_weight(weight_name: str, quantized_weight_names: Set[str]) -> bool: + # If the nodes share one weight tensor, we should have only one quantizer on that + return weight_name not in quantized_weight_names + + @staticmethod + def get_weight_config(config: QuantizerConfig, model: NNCFNetwork) -> QuantizerConfig: + return config + + @staticmethod + def _get_input_scale_shape( + nncf_graph: NNCFGraph, target_point: PTTargetPoint, per_channel: bool + ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]: + is_weights = target_point.is_weight_target_point() + if is_weights: + # TODO: support transpose conv/ make channel_idx common + channel_idx = 0 + else: + channel_idx = 1 # channel dim for activations + + input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point) + scale_shape = tuple( + get_scale_shape(input_shape, is_weights=is_weights, per_channel=per_channel, channel_idx=channel_idx) + ) + + return input_shape, scale_shape, channel_idx + + @staticmethod + def _create_quantizer( + quantizer_config: QuantizerConfig, + scale_shape: Tuple, + parameters: FakeQuantizeParameters, + target_type: TargetType, + ) -> BaseQuantizer: + mode = quantizer_config.mode + quantizer_cls = QUANTIZATION_MODULES.get(mode) + narrow_range = 
target_type == TargetType.OPERATION_WITH_WEIGHTS and mode == QuantizationMode.SYMMETRIC + quantizer_spec = PTQuantizerSpec.from_config( + quantizer_config, + narrow_range=narrow_range, + scale_shape=scale_shape, + half_range=False, + logarithm_scale=False, + is_quantized_on_export=False, + compression_lr_multiplier=None, + ) + quantizer = quantizer_cls(quantizer_spec) + + # Fill it with minmax + FXMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters, quantizer_spec.scale_shape) + # Convert to the torch fake quantizer + torch_fq = convert_to_torch_fakequantizer(quantizer) + return torch_fq + + @staticmethod + def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters, scale_shape) -> None: + if isinstance(quantizer, AsymmetricQuantizer): + quantizer.input_low = torch.nn.Parameter(parameters.input_low.data.reshape(scale_shape)) + input_range = parameters.input_high - parameters.input_low + # Subtract eps from the input_range to make quantizer parameters equal to + # original parameters on the forward call. + quantizer.input_range = torch.nn.Parameter((input_range.data - quantizer.eps).reshape(scale_shape)) + else: + quantizer.signed = bool(torch.any(parameters.input_low.data < 0)) + # Subtract eps from the scale to make quantizer parameters equal to + # original parameters on the forward call. + quantizer.scale = torch.nn.Parameter((parameters.input_high.data - quantizer.eps).reshape(scale_shape)) + + @staticmethod + def create_quantizer_insertion_command( + nncf_graph: NNCFGraph, + target_point: PTTargetPoint, + quantizer_config: QuantizerConfig, + parameters: FakeQuantizeParameters, + ) -> FXApplyTransformationCommand: + _, scale_shape, _ = FXMinMaxAlgoBackend._get_input_scale_shape( + nncf_graph, target_point, quantizer_config.per_channel + ) + + quantizer = FXMinMaxAlgoBackend._create_quantizer( + quantizer_config, scale_shape, parameters, target_point.target_type + ) + transformation = qdq_insertion_tranformation_builder(quantizer, [target_point]) + return FXApplyTransformationCommand(transformation) + + @staticmethod + def create_unified_scales_quantizers_insertion_commands( + nncf_graph: NNCFGraph, + target_points: List[PTTargetPoint], + quantizer_config: QuantizerConfig, + parameters: FakeQuantizeParameters, + ) -> List[PTSharedFnInsertionCommand]: + _, scale_shape, _ = FXMinMaxAlgoBackend._get_input_scale_shape( + nncf_graph, target_points[0], quantizer_config.per_channel + ) + + quantizer = FXMinMaxAlgoBackend._create_quantizer( + quantizer_config, scale_shape, parameters, target_points[0].target_type + ) + + transformations = [] + for tp in target_points: + transformation = qdq_insertion_tranformation_builder(quantizer, [tp]) + transformations.append(FXApplyTransformationCommand(transformation)) + return transformations + + @staticmethod + def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[OperatorMetatype]: + types = [] + if model_type == ModelType.TRANSFORMER: + types = [ + om.PTAddMetatype, + om.PTPowerMetatype, + om.PTSubMetatype, + om.PTAvgPool2dMetatype, + om.PTAvgPool3dMetatype, + om.PTMeanMetatype, + om.PTSumMetatype, + om.PTReduceL2, + om.PTDivMetatype, + om.PTMaxMetatype, + om.PTSqueezeMetatype, + om.PTLayerNormMetatype, + om.PTModuleLayerNormMetatype, + om.PTGroupNormMetatype, + om.PTModuleGroupNormMetatype, + # Batchnorm + om.PTBatchNormMetatype, + om.PTModuleBatchNormMetatype, + ] + if device != TargetDevice.CPU_SPR: + types.append(om.PTMulMetatype) + return types + + @staticmethod + def 
get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> List[str]: + return [] + + @staticmethod + def get_weight_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]: + retval = set() + for node in nncf_graph.get_all_nodes(): + if node.metatype in [om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype, om.PTLinearMetatype]: + retval.add(node) + return list(retval) diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index 6455eea8535..f91f1a08e03 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -209,7 +209,21 @@ def quantize( ignored_scope=ignored_scope, advanced_parameters=advanced_parameters, ) + if backend == BackendType.TORCH_FX: + from nncf.experimental.torch_fx.quantization.quantize_model import quantize_impl + return quantize_impl( + model=model, + calibration_dataset=calibration_dataset, + mode=mode, + preset=preset, + target_device=target_device, + subset_size=subset_size, + fast_bias_correction=fast_bias_correction, + model_type=model_type, + ignored_scope=ignored_scope, + advanced_parameters=advanced_parameters, + ) raise nncf.UnsupportedBackendError(f"Unsupported type of backend: {backend}") diff --git a/nncf/torch/dynamic_graph/patch_pytorch.py b/nncf/torch/dynamic_graph/patch_pytorch.py index 6f01e44f49b..fa7de64be0f 100644 --- a/nncf/torch/dynamic_graph/patch_pytorch.py +++ b/nncf/torch/dynamic_graph/patch_pytorch.py @@ -341,7 +341,7 @@ def patch_torch_operators(): functions_to_patch = {} for namespace in NamespaceTarget: - if namespace == NamespaceTarget.EXTERNAL: + if namespace in [NamespaceTarget.ATEN, NamespaceTarget.EXTERNAL]: continue functions_to_patch[namespace] = get_all_functions_from_namespace(namespace) diff --git a/nncf/torch/dynamic_graph/structs.py b/nncf/torch/dynamic_graph/structs.py index c767790a92c..d8cf563107f 100644 --- a/nncf/torch/dynamic_graph/structs.py +++ b/nncf/torch/dynamic_graph/structs.py @@ -22,6 +22,7 @@ class NamespaceTarget(Enum): TORCH_TENSOR = "torch.tensor" TORCH_NN_PARAMETER = "torch.nn.parameter" TORCH = "torch" + ATEN = "aten" EXTERNAL = "external_function" diff --git a/nncf/torch/graph/operator_metatypes.py b/nncf/torch/graph/operator_metatypes.py index db09f1edab5..6787842a098 100644 --- a/nncf/torch/graph/operator_metatypes.py +++ b/nncf/torch/graph/operator_metatypes.py @@ -55,6 +55,7 @@ class PTOperatorMetatype(OperatorMetatype): NamespaceTarget.TORCH_NN_FUNCTIONAL: [], NamespaceTarget.TORCH_TENSOR: [], NamespaceTarget.TORCH: [], + NamespaceTarget.ATEN: [], } subtypes: List[Type["PTOperatorMetatype"]] = [] @@ -527,7 +528,7 @@ class PTGELUMetatype(PTOperatorMetatype): @PT_OPERATOR_METATYPES.register() class PTSILUMetatype(PTOperatorMetatype): name = "SiluOp" - module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["silu"]} + module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["silu"], NamespaceTarget.ATEN: ["silu_"]} @PT_OPERATOR_METATYPES.register() @@ -546,6 +547,7 @@ class PTAddMetatype(PTOperatorMetatype): module_to_function_names = { NamespaceTarget.TORCH_TENSOR: ["add", "__add__", "__iadd__", "__radd__"], NamespaceTarget.TORCH: ["add"], + NamespaceTarget.ATEN: ["add_"], } hw_config_names = [HWConfigOpName.ADD] num_expected_input_edges = 2 @@ -557,6 +559,7 @@ class PTSubMetatype(PTOperatorMetatype): module_to_function_names = { NamespaceTarget.TORCH_TENSOR: ["sub", "__sub__", "__isub__", "__rsub__"], NamespaceTarget.TORCH: ["sub"], + NamespaceTarget.ATEN: ["sub_"], } hw_config_names = [HWConfigOpName.SUBTRACT] 
num_expected_input_edges = 2 @@ -690,13 +693,19 @@ class PTThresholdMetatype(PTOperatorMetatype): @PT_OPERATOR_METATYPES.register(is_subtype=True) class PTModuleBatchNormMetatype(PTModuleOperatorSubtype): name = "BatchNormOp" - module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"]} + module_to_function_names = { + NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"], + NamespaceTarget.ATEN: ["_native_batch_norm_legit_no_training"], + } @PT_OPERATOR_METATYPES.register() class PTBatchNormMetatype(PTOperatorMetatype): name = "BatchNormOp" - module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"]} + module_to_function_names = { + NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"], + NamespaceTarget.ATEN: ["_native_batch_norm_legit_no_training"], + } subtypes = [PTModuleBatchNormMetatype] weight_port_ids = [3] bias_port_id = 4 @@ -825,10 +834,17 @@ class PTGatherMetatype(PTOperatorMetatype): name = "GatherOp" module_to_function_names = { NamespaceTarget.TORCH_TENSOR: ["index_select", "__getitem__"], - NamespaceTarget.TORCH: ["gather", "index_select", "where"], + NamespaceTarget.TORCH: ["gather", "index_select", "select", "where"], + NamespaceTarget.ATEN: ["slice"], } +@PT_OPERATOR_METATYPES.register() +class PTScaledDotProductAttention(PTOperatorMetatype): + name = "scaled_dot_product_attention" + module_to_function_names = {NamespaceTarget.ATEN: ["scaled_dot_product_attention"]} + + @PT_OPERATOR_METATYPES.register() class PTScatterMetatype(PTOperatorMetatype): name = "ScatterOp" @@ -840,7 +856,7 @@ class PTReshapeMetatype(PTOperatorMetatype): name = "ReshapeOp" module_to_function_names = { NamespaceTarget.TORCH_TENSOR: ["reshape", "view", "flatten", "unsqueeze"], - NamespaceTarget.TORCH: ["flatten", "unsqueeze"], + NamespaceTarget.TORCH: ["flatten", "unflatten", "unsqueeze"], } hw_config_names = [HWConfigOpName.RESHAPE, HWConfigOpName.UNSQUEEZE, HWConfigOpName.FLATTEN] @@ -862,6 +878,7 @@ class PTSplitMetatype(PTOperatorMetatype): NamespaceTarget.TORCH_NN_FUNCTIONAL: [], NamespaceTarget.TORCH_TENSOR: ["split", "chunk", "unbind"], NamespaceTarget.TORCH: ["split", "chunk", "unbind"], + NamespaceTarget.ATEN: ["split_with_sizes"], } hw_config_names = [HWConfigOpName.SPLIT, HWConfigOpName.CHUNK] @@ -1027,7 +1044,10 @@ class PTSqrtMetatype(PTOperatorMetatype): @PT_OPERATOR_METATYPES.register() class PTInterpolateMetatype(PTOperatorMetatype): name = "InterpolateOp" - module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["interpolate"]} + module_to_function_names = { + NamespaceTarget.TORCH_NN_FUNCTIONAL: ["interpolate"], + NamespaceTarget.ATEN: ["upsample_nearest2d", "upsample_nearest_exact2d"], + } hw_config_names = [HWConfigOpName.INTERPOLATE] num_expected_input_edges = 1 diff --git a/nncf/torch/graph/pattern_operations.py b/nncf/torch/graph/pattern_operations.py index d9957871d87..dc0e5b43af2 100644 --- a/nncf/torch/graph/pattern_operations.py +++ b/nncf/torch/graph/pattern_operations.py @@ -67,7 +67,7 @@ ) ARITHMETIC_OPERATIONS = { - GraphPattern.METATYPE_ATTR: ["__iadd__", "__add__", "__mul__", "__rmul__", "__truediv__"], + GraphPattern.METATYPE_ATTR: ["__iadd__", "__add__", "__mul__", "__rmul__", "__truediv__", "add_"], GraphPattern.LABEL_ATTR: "ARITHMETIC", } diff --git a/tests/post_training/pipelines/lm_weight_compression.py b/tests/post_training/pipelines/lm_weight_compression.py index fcab0a20f88..de8eeebee1f 100644 --- a/tests/post_training/pipelines/lm_weight_compression.py +++ 
b/tests/post_training/pipelines/lm_weight_compression.py @@ -19,7 +19,6 @@ import numpy as np import openvino as ov import torch -from datasets import load_dataset from memory_profiler import memory_usage from optimum.exporters.openvino.convert import export_from_model from optimum.intel.openvino import OVModelForCausalLM @@ -28,6 +27,7 @@ from whowhatbench import Evaluator import nncf +from datasets import load_dataset from tests.post_training.pipelines.base import BackendType from tests.post_training.pipelines.base import BaseTestPipeline from tests.post_training.pipelines.base import StatsFromOutput diff --git a/tests/torch/ptq/test_calculation_quantizer_params.py b/tests/torch/ptq/test_calculation_quantizer_params.py index 234c05e6de8..06c0ad1b64c 100644 --- a/tests/torch/ptq/test_calculation_quantizer_params.py +++ b/tests/torch/ptq/test_calculation_quantizer_params.py @@ -24,16 +24,16 @@ from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode from nncf.common.quantization.structs import QuantizerConfig from nncf.common.quantization.structs import QuantizerGroup -from nncf.experimental.tensor import Tensor -from nncf.experimental.tensor import functions as fn +from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend from nncf.quantization.fake_quantize import FakeQuantizeParameters from nncf.quantization.fake_quantize import calculate_quantizer_parameters from nncf.quantization.fake_quantize import get_quantizer_narrow_range +from nncf.tensor import Tensor +from nncf.tensor import functions as fns from nncf.torch.model_creation import wrap_model from nncf.torch.statistics.aggregator import PTStatisticsAggregator -from nncf.torch.tensor_statistics.statistics import PTMinMaxTensorStatistic from tests.post_training.test_templates.test_calculate_quantizer_parameters import TemplateTestFQParams from tests.torch.helpers import get_all_inputs_for_graph_node from tests.torch.helpers import get_nodes_by_type @@ -210,8 +210,8 @@ def test_quantizer_params_asym(case_to_test: CaseSymParams): ) quantizer = PTMinMaxAlgoBackend._create_quantizer(qconfig, scale_shape, fq_params, target_type) assert quantizer.levels == fq_params.levels - assert fn.allclose(quantizer.input_low.data, case_to_test.ref_inp_low) - assert fn.allclose(quantizer.input_range.data, case_to_test.ref_inp_range) + assert fns.allclose(quantizer.input_low.data, case_to_test.ref_inp_low) + assert fns.allclose(quantizer.input_range.data, case_to_test.ref_inp_range) class LinearTestModel(nn.Module): @@ -268,7 +268,9 @@ def calculate_statistics(data, mode, qgroup, half_range=False): else: max_values = np.amax(data, axes) - statistics = PTMinMaxTensorStatistic(min_values=torch.tensor(min_values), max_values=torch.tensor(max_values)) + statistics = MinMaxTensorStatistic( + min_values=Tensor(torch.tensor(min_values)), max_values=Tensor(torch.tensor(max_values)) + ) signedness_to_force = True if qgroup == QuantizerGroup.WEIGHTS else None qconfig = QuantizerConfig(num_bits=8, mode=mode, per_channel=per_ch, signedness_to_force=signedness_to_force) narrow_range = get_quantizer_narrow_range(qconfig, qgroup) @@ -340,11 +342,10 @@ def test_quantizer_parameters_export(tmp_path: Path, _seed): for name, param in fq_params.items(): assert name in torch_ptq_params - assert fn.allclose(param["input_low"], 
torch_ptq_params[name]["input_low"]) - assert fn.allclose(param["input_high"], torch_ptq_params[name]["input_high"]) + assert fns.allclose(param["input_low"], torch_ptq_params[name]["input_low"]) + assert fns.allclose(param["input_high"], torch_ptq_params[name]["input_high"]) class TestFQParams(TemplateTestFQParams): - @property - def tensor_statistic(self): - return PTMinMaxTensorStatistic + def to_nncf_tensor(self, t): + return Tensor(torch.tensor(t)) diff --git a/tests/torch/ptq/test_graphs.py b/tests/torch/ptq/test_graphs.py index 93281435104..aa427735b53 100644 --- a/tests/torch/ptq/test_graphs.py +++ b/tests/torch/ptq/test_graphs.py @@ -15,6 +15,7 @@ import pytest import torch +from nncf import Dataset from nncf.parameters import TargetDevice from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization @@ -22,7 +23,7 @@ from nncf.torch.layers import NNCF_RNN from nncf.torch.layers import LSTMCellNNCF from tests.post_training.test_templates.helpers import EmbeddingModel -from tests.post_training.test_templates.helpers import get_static_dataset +from tests.post_training.test_templates.helpers import ScaledDotProductAttentionModel from tests.torch import test_models from tests.torch.quantization.test_algo_quantization import SharedLayersModel from tests.torch.test_compressed_graph import ModelDesc @@ -49,6 +50,14 @@ def get_model_name(description): TEST_MODELS_DESC = [ (ModelDesc("embedding_model", EmbeddingModel, [1, 10]), {}), + ( + ModelDesc( + "scaled_dot_product_attention_model", + ScaledDotProductAttentionModel, + {"query": [1, 8, 16], "key": [1, 8, 16], "value": [1, 8, 16]}, + ), + {}, + ), (ModelDesc("shared_model", SharedLayersModel, [1, 1, 5, 6]), {}), (ModelDesc("alexnet", test_models.AlexNet, [1, 3, 32, 32]), {}), (ModelDesc("lenet", test_models.LeNet, [1, 3, 32, 32]), {}), @@ -96,18 +105,21 @@ def get_model_name(description): def test_min_max_classification_quantized_graphs(desc: ModelDesc, quantization_parameters, graph_dir, mocker): model = desc.model_builder() - nncf_network = wrap_model(model, torch.ones(desc.input_sample_sizes), trace_parameters=True) + if isinstance(desc.input_sample_sizes, dict): + example_input = {} + for name, size in desc.input_sample_sizes.items(): + example_input[name] = torch.ones(size) + else: + example_input = torch.ones(desc.input_sample_sizes) + + nncf_network = wrap_model(model, example_input, trace_parameters=True) quantization_parameters["advanced_parameters"] = AdvancedQuantizationParameters(disable_bias_correction=True) quantization_parameters["subset_size"] = 1 quantization_algorithm = PostTrainingQuantization(**quantization_parameters) - def transform_fn(input_) -> torch.Tensor: - return torch.tensor(input_[0]) - quantized_model = quantization_algorithm.apply( nncf_network, nncf_network.nncf.get_graph(), - dataset=get_static_dataset(desc.input_sample_sizes, transform_fn, None), + dataset=Dataset([example_input]), ) - check_graph(quantized_model.nncf.get_graph(), desc.dot_filename(), graph_dir) diff --git a/tests/torch/ptq/test_reducers_and_aggregators.py b/tests/torch/ptq/test_reducers_and_aggregators.py index 84cb20fb9ea..c657b222802 100644 --- a/tests/torch/ptq/test_reducers_and_aggregators.py +++ b/tests/torch/ptq/test_reducers_and_aggregators.py @@ -19,7 +19,8 @@ import nncf from nncf.common.graph.layer_attributes import Dtype from nncf.experimental.common.tensor_statistics.collectors import TensorCollector -from 
nncf.torch.tensor import PTNNCFTensor +from nncf.tensor import Tensor +from nncf.tensor import functions as fns from nncf.torch.tensor_statistics.algo import create_register_input_hook from nncf.torch.tensor_statistics.collectors import PTAbsMaxReducer from nncf.torch.tensor_statistics.collectors import PTAbsQuantileReducer @@ -28,15 +29,11 @@ from nncf.torch.tensor_statistics.collectors import PTMeanPerChanelReducer from nncf.torch.tensor_statistics.collectors import PTMeanReducer from nncf.torch.tensor_statistics.collectors import PTMinReducer -from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor from nncf.torch.tensor_statistics.collectors import PTQuantileReducer -from tests.common.experimental.test_reducers_and_aggregators import TemplateTestReducersAggreagtors +from tests.common.experimental.test_reducers_and_aggregators import TemplateTestReducersAggregators -class BaseTestReducersAggregators(TemplateTestReducersAggreagtors, ABC): - @pytest.fixture - def tensor_processor(self): - return PTNNCFCollectorTensorProcessor +class BaseTestReducersAggregators(TemplateTestReducersAggregators, ABC): def _get_torch_tensor(self, x: np.ndarray, dtype: Optional[Dtype] = None): torch_tensor = torch.tensor(x) @@ -80,7 +77,7 @@ def cast_tensor(self, tensor, dtype: Dtype): class TestCPUReducersAggregators(BaseTestReducersAggregators): def get_nncf_tensor(self, x: np.array, dtype: Optional[Dtype] = None): - return PTNNCFTensor(self._get_torch_tensor(x, dtype=dtype).cpu()) + return Tensor(self._get_torch_tensor(x, dtype=dtype).cpu()) def all_close(self, val: torch.Tensor, ref) -> bool: assert not val.is_cuda @@ -91,23 +88,23 @@ def all_close(self, val: torch.Tensor, ref) -> bool: @pytest.mark.skipif(not torch.cuda.is_available(), reason="Cuda is not available in current environment") class TestCudaReducersAggregators(BaseTestReducersAggregators): def get_nncf_tensor(self, x: np.array, dtype: Optional[Dtype] = None): - return PTNNCFTensor(self._get_torch_tensor(x, dtype=dtype).cuda()) + return Tensor(self._get_torch_tensor(x, dtype=dtype).cuda()) def all_close(self, val: torch.Tensor, ref) -> bool: assert val.is_cuda return super().all_close(val, ref) -@pytest.mark.parametrize("size,ref", [(16_000_000, 1_600_000.8750), (17_000_000, 1_700_000.7500)]) +@pytest.mark.parametrize("size,ref", [(16_000_000, 1_600_000.8), (17_000_000, 1_700_000.8)]) def test_quantile_percentile_function(use_cuda, size, ref): if use_cuda and not torch.cuda.is_available(): pytest.skip("Cuda is not available in current environment") device = "cuda" if use_cuda else "cpu" - tensor = PTNNCFTensor(torch.arange(1, size, 1).float().to(device)) - res_quantile = PTNNCFCollectorTensorProcessor.quantile(tensor, [0.1], axis=0) - res_percentile = PTNNCFCollectorTensorProcessor.percentile(tensor, [10], axis=0) - assert len(res_quantile) == len(res_percentile) == 1 - for tensor in [res_quantile[0].tensor, res_percentile[0].tensor]: + tensor = Tensor(torch.arange(1, size, 1).float().to(device)) + res_quantile = fns.quantile(tensor, [0.1], axis=0) + res_percentile = fns.percentile(tensor, [10], axis=0) + assert res_quantile.shape[0] == res_percentile.shape[0] == 1 + for tensor in [res_quantile[0].data, res_percentile[0].data]: assert tensor == ref assert tensor.is_cuda == (device == "cuda") @@ -117,10 +114,10 @@ def test_median_function(use_cuda, size, ref): if use_cuda and not torch.cuda.is_available(): pytest.skip("Cuda is not available in current environment") device = "cuda" if use_cuda else "cpu" - tensor = 
PTNNCFTensor(torch.arange(1, size, 1).float().to(device)) - res = PTNNCFCollectorTensorProcessor.median(tensor, axis=0) - assert res.tensor == ref - assert res.tensor.is_cuda == (device == "cuda") + tensor = Tensor(torch.arange(1, size, 1).float().to(device)) + res = fns.median(tensor, axis=0) + assert res.data == ref + assert res.data.is_cuda == (device == "cuda") def test_create_register_input_hook_with_return_type(mocker): @@ -133,7 +130,5 @@ def test_create_register_input_hook_with_return_type(mocker): mocker = collector.register_input_for_all_reducers mocker.assert_called_once() attr = mocker.call_args_list[0][0][0] - assert isinstance(attr, PTNNCFTensor) - assert attr.tensor == torch.tensor( - 1, - ) + assert isinstance(attr, Tensor) + assert attr.data == torch.tensor(1) diff --git a/tests/torch/ptq/test_statistic_collector.py b/tests/torch/ptq/test_statistic_collector.py index ffc5a828625..50930a289ec 100644 --- a/tests/torch/ptq/test_statistic_collector.py +++ b/tests/torch/ptq/test_statistic_collector.py @@ -9,50 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Type import numpy as np -import pytest import torch -from nncf.common.tensor import NNCFTensor -from nncf.common.tensor_statistics.statistics import MeanTensorStatistic -from nncf.common.tensor_statistics.statistics import MedianMADTensorStatistic -from nncf.common.tensor_statistics.statistics import MinMaxTensorStatistic -from nncf.common.tensor_statistics.statistics import PercentileTensorStatistic -from nncf.common.tensor_statistics.statistics import RawTensorStatistic -from nncf.torch.tensor import PTNNCFTensor -from nncf.torch.tensor_statistics.statistics import PTMeanTensorStatistic -from nncf.torch.tensor_statistics.statistics import PTMedianMADTensorStatistic -from nncf.torch.tensor_statistics.statistics import PTMinMaxTensorStatistic -from nncf.torch.tensor_statistics.statistics import PTPercentileTensorStatistic +from nncf.tensor import Tensor from tests.common.experimental.test_statistic_collector import TemplateTestStatisticCollector class TestPTStatisticCollector(TemplateTestStatisticCollector): - def get_nncf_tensor(self, value: np.ndarray) -> NNCFTensor: - return PTNNCFTensor(torch.tensor(value)) - - @pytest.fixture - def min_max_statistic_cls(self) -> Type[MinMaxTensorStatistic]: - return PTMinMaxTensorStatistic - - @pytest.fixture - def mean_statistic_cls(self) -> Type[MeanTensorStatistic]: - return PTMeanTensorStatistic - - @pytest.fixture - def median_mad_statistic_cls(self) -> Type[MedianMADTensorStatistic]: - return PTMedianMADTensorStatistic - - @pytest.fixture - def percentile_statistic_cls(self) -> Type[PercentileTensorStatistic]: - return PTPercentileTensorStatistic - - @pytest.fixture - def raw_statistic_cls(self) -> Type[RawTensorStatistic]: - raise NotImplementedError() - - @pytest.mark.skip - def test_raw_max_stat_building(self, raw_statistic_cls: RawTensorStatistic): - pass + def get_nncf_tensor(self, value: np.ndarray) -> Tensor: + return Tensor(torch.tensor(value)) diff --git a/tests/torch/ptq/test_tensor_collector_batch_size.py b/tests/torch/ptq/test_tensor_collector_batch_size.py index 5beff90e67a..6c4b6aae6e6 100644 --- a/tests/torch/ptq/test_tensor_collector_batch_size.py +++ b/tests/torch/ptq/test_tensor_collector_batch_size.py @@ -14,26 +14,11 @@ import torch from nncf.experimental.common.tensor_statistics.collectors import AGGREGATORS_MAP -from nncf.torch.tensor import PTNNCFTensor from 
nncf.torch.tensor_statistics.collectors import PT_REDUCERS_MAP -from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor -from nncf.torch.tensor_statistics.statistics import PTMinMaxTensorStatistic from tests.common.experimental.test_tensor_collector_batch_size import TemplateTestTensorCollectorBatchSize class TestTensorCollectorBatchSize(TemplateTestTensorCollectorBatchSize): - @staticmethod - def get_tensor_statistics_class(): - return PTMinMaxTensorStatistic - - @staticmethod - def get_tensor_processor(): - return PTNNCFCollectorTensorProcessor() - - @staticmethod - def get_nncf_tensor_class(): - return PTNNCFTensor - @pytest.fixture(params=PT_REDUCERS_MAP.values()) def reducers(self, request) -> bool: return request.param diff --git a/tests/torch/ptq/test_weights_compression.py b/tests/torch/ptq/test_weights_compression.py index 8cb5e00932f..30c704e5435 100644 --- a/tests/torch/ptq/test_weights_compression.py +++ b/tests/torch/ptq/test_weights_compression.py @@ -17,7 +17,8 @@ from nncf import SensitivityMetric from nncf.quantization import compress_weights from nncf.torch import wrap_model -from nncf.torch.quantization.layers import WeightsDecompressor +from nncf.torch.quantization.layers import AsymmetricWeightsDecompressor +from nncf.torch.quantization.layers import SymmetricWeightsDecompressor DATA_BASED_SENSITIVITY_METRICS = ( SensitivityMetric.HESSIAN_INPUT_ACTIVATION, @@ -28,7 +29,7 @@ ALL_SENSITIVITY_METRICS = DATA_BASED_SENSITIVITY_METRICS + (SensitivityMetric.WEIGHT_QUANTIZATION_ERROR,) -SUPPORTED_MODES = (CompressWeightsMode.INT8, CompressWeightsMode.INT8_ASYM) +SUPPORTED_MODES = (CompressWeightsMode.INT8, CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM) UNSUPPORTED_MODES = ( CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM, @@ -106,12 +107,14 @@ def forward(self, input_): return x -def test_compress_weights(): +@pytest.mark.parametrize("mode", (CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM)) +def test_compress_weights(mode): model = ShortTransformer(5, 10) + dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8 input_ids = torch.randint(0, 10, (5,)) wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True) - compressed_model = compress_weights(wrapped_model) + compressed_model = compress_weights(wrapped_model, mode=mode) n_compressed_weights = 0 n_target_modules = 0 @@ -119,22 +122,26 @@ def test_compress_weights(): for _, module in compressed_model.named_children(): if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)): n_target_modules += 1 - if module.weight.dtype in [torch.uint8, torch.int8]: + if module.weight.dtype == dtype: n_compressed_weights += 1 assert n_compressed_weights == n_target_modules -def test_compress_weights_functional_model(): +@pytest.mark.parametrize("mode", (CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM)) +def test_compress_weights_functional_model(mode): model = FunctionalModel() + decompressor_type = ( + SymmetricWeightsDecompressor if mode == CompressWeightsMode.INT8_SYM else AsymmetricWeightsDecompressor + ) input_ids = torch.randint(0, 10, [1, 3, 300, 300]) wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True) - compressed_model = compress_weights(wrapped_model) + compressed_model = compress_weights(wrapped_model, mode=mode) n_compressed_weights = 0 for layer in compressed_model.nncf.external_op.values(): - if isinstance(layer, WeightsDecompressor): + if isinstance(layer, 
decompressor_type): n_compressed_weights += 1 assert n_compressed_weights == 4 @@ -158,12 +165,14 @@ def test_compress_weights_conv(): assert n_compressed_weights == n_target_modules -def test_compress_shared_weights(mocker): +@pytest.mark.parametrize("mode", (CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM)) +def test_compress_shared_weights(mocker, mode): model = ShortTransformer(5, 10, share_weights=True) + dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8 input_ids = torch.randint(0, 10, (5,)) wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True) - compressed_model = compress_weights(wrapped_model) + compressed_model = compress_weights(wrapped_model, mode=mode) n_compressed_weights = 0 n_target_modules = 0 @@ -171,7 +180,7 @@ def test_compress_shared_weights(mocker): for _, module in compressed_model.named_children(): if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)): n_target_modules += 1 - if module.weight.dtype in [torch.uint8, torch.int8]: + if module.weight.dtype == dtype: n_compressed_weights += 1 assert n_compressed_weights == n_target_modules @@ -201,8 +210,9 @@ def forward(self, input): {"all_layers": True}, {"all_layers": False}, *({"sensitivity_metric": metric} for metric in ALL_SENSITIVITY_METRICS), - {"dataset": "anything"}, - {"ignored_scope": "anything"}, + {"gptq": True}, + {"awq": True}, + {"scale_estimation": True}, ), ) def test_raise_error_with_unsupported_params_for_int8(mode, params): @@ -214,7 +224,7 @@ def test_raise_error_with_unsupported_params_for_int8(mode, params): @pytest.mark.parametrize("mode", UNSUPPORTED_MODES) -def test_raise_error_with_not_int8_asym(mode): +def test_raise_error_with_not_int8(mode): dummy_torch_model = EmptyModel() dummy_input = torch.Tensor() wrapped_model = wrap_model(dummy_torch_model, example_input=dummy_input, trace_parameters=True) diff --git a/tests/torch/sparsity/movement/helpers/run_recipe.py b/tests/torch/sparsity/movement/helpers/run_recipe.py index 77b3140a967..383552932d5 100644 --- a/tests/torch/sparsity/movement/helpers/run_recipe.py +++ b/tests/torch/sparsity/movement/helpers/run_recipe.py @@ -20,7 +20,6 @@ import torch.nn import torch.nn.functional as F import torch.utils.data -from datasets import Dataset from transformers import AutoModelForAudioClassification from transformers import AutoModelForImageClassification from transformers import AutoModelForSequenceClassification @@ -34,6 +33,7 @@ from transformers import SwinConfig from transformers import Wav2Vec2Config +from datasets import Dataset from nncf import NNCFConfig from nncf.experimental.torch.sparsity.movement.scheduler import MovementSchedulerParams from nncf.torch.dynamic_graph.io_handling import FillerInputElement diff --git a/tests/torch/sparsity/movement/helpers/trainer.py b/tests/torch/sparsity/movement/helpers/trainer.py index 89ffeb6c865..2af37c5b2f4 100644 --- a/tests/torch/sparsity/movement/helpers/trainer.py +++ b/tests/torch/sparsity/movement/helpers/trainer.py @@ -14,7 +14,6 @@ import numpy as np import torch -from datasets import Dataset # pylint: disable=no-name-in-module from transformers import TrainingArguments from transformers.trainer import Trainer from transformers.trainer_callback import TrainerCallback @@ -22,6 +21,7 @@ from transformers.trainer_callback import TrainerState from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR +from datasets import Dataset # pylint: disable=no-name-in-module from nncf.api.compression import 
CompressionAlgorithmController from nncf.common.compression import BaseCompressionAlgorithmController from nncf.common.utils.tensorboard import prepare_for_tensorboard diff --git a/tests/torch/sparsity/movement/test_model_saving.py b/tests/torch/sparsity/movement/test_model_saving.py index 901104fcd7e..5b3c463ec8f 100644 --- a/tests/torch/sparsity/movement/test_model_saving.py +++ b/tests/torch/sparsity/movement/test_model_saving.py @@ -18,7 +18,6 @@ import pytest import torch from addict import Dict -from datasets import Dataset from onnx import numpy_helper from openvino._offline_transformations import apply_fused_names_cleanup from openvino._offline_transformations import apply_moc_transformations @@ -29,6 +28,7 @@ from scipy.special import softmax from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR +from datasets import Dataset from nncf.torch import create_compressed_model from nncf.torch.checkpoint_loading import load_state from tests.torch.helpers import PTTensorListComparator diff --git a/tests/torch/sparsity/movement/training_scripts/run_glue.py b/tests/torch/sparsity/movement/training_scripts/run_glue.py index 360832a5bb7..d0f5b14269e 100644 --- a/tests/torch/sparsity/movement/training_scripts/run_glue.py +++ b/tests/torch/sparsity/movement/training_scripts/run_glue.py @@ -12,12 +12,13 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple -import datasets import evaluate import jstyleson import numpy as np from transformers.training_args import ParallelMode +import datasets + # isort: off from nncf import NNCFConfig from nncf.api.compression import CompressionAlgorithmController diff --git a/tests/torch/test_statistics_aggregator.py b/tests/torch/test_statistics_aggregator.py index 96b72d48ed0..0c1075d9bf1 100644 --- a/tests/torch/test_statistics_aggregator.py +++ b/tests/torch/test_statistics_aggregator.py @@ -26,7 +26,7 @@ from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase from nncf.quantization.algorithms.fast_bias_correction.torch_backend import PTFastBiasCorrectionAlgoBackend -from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend +from nncf.quantization.algorithms.min_max.torch_backend import FXMinMaxAlgoBackend from nncf.quantization.range_estimator import RangeEstimatorParametersSet from nncf.torch.dynamic_graph.patch_pytorch import register_operator from nncf.torch.graph.graph import PTTargetPoint @@ -60,8 +60,8 @@ def get_nncf_network(self): class TestStatisticsAggregator(TemplateTestStatisticsAggregator): @staticmethod - def get_min_max_algo_backend_cls() -> Type[PTMinMaxAlgoBackend]: - return PTMinMaxAlgoBackend + def get_min_max_algo_backend_cls() -> Type[FXMinMaxAlgoBackend]: + return FXMinMaxAlgoBackend def get_bias_correction_algo_backend_cls(self) -> None: pytest.skip("PTBiasCorrectionAlgoBackend is not implemented") @@ -98,7 +98,7 @@ def get_target_point(target_type: TargetType): if target_type == TargetType.OPERATION_WITH_WEIGHTS: target_node_name = CONV_NODE_NAME port_id = 1 - return PTMinMaxAlgoBackend.target_point(target_type, target_node_name, port_id) + return FXMinMaxAlgoBackend.target_point(target_type, target_node_name, port_id) def get_target_point_cls(self): return PTTargetPoint @@ -277,7 +277,7 @@ def fn(x): target_point = self.get_target_point(test_parameters.target_type) model = self.__add_fn_to_model(model, target_point, fn) - nested_target_point = 
PTMinMaxAlgoBackend.target_point(nested_target_type, nested_target_node_name, 0) + nested_target_point = FXMinMaxAlgoBackend.target_point(nested_target_type, nested_target_node_name, 0) model = self.__add_fn_to_model(model, nested_target_point, fn) # Check hook inserted correctly diff --git a/tests/torch_fx/__init__.py b/tests/torch_fx/__init__.py new file mode 100644 index 00000000000..2e49d63977d --- /dev/null +++ b/tests/torch_fx/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/torch_fx/helpers.py b/tests/torch_fx/helpers.py new file mode 100644 index 00000000000..8bbc721e0fa --- /dev/null +++ b/tests/torch_fx/helpers.py @@ -0,0 +1,105 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import torch +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +import torchvision.datasets as datasets +import torchvision.transforms as transforms +from fastdownload import FastDownload + + +class TinyImagenetDatasetManager: + DATASET_URL = "http://cs231n.stanford.edu/tiny-imagenet-200.zip" + DATASET_PATH = "~/.cache/nncf/tests/datasets" + + def __init__(self, image_size: int, batch_size: int) -> None: + self.image_size = image_size + self.batch_size = batch_size + + @staticmethod + def download_dataset() -> Path: + downloader = FastDownload(base=TinyImagenetDatasetManager.DATASET_PATH, archive="downloaded", data="extracted") + return downloader.get(TinyImagenetDatasetManager.DATASET_URL) + + @staticmethod + def prepare_tiny_imagenet_200(dataset_dir: Path): + # Format validation set the same way as train set is formatted. 
+ val_data_dir = dataset_dir / "val" + val_images_dir = val_data_dir / "images" + if not val_images_dir.exists(): + return + + val_annotations_file = val_data_dir / "val_annotations.txt" + with open(val_annotations_file, "r") as f: + val_annotation_data = map(lambda line: line.split("\t")[:2], f.readlines()) + for image_filename, image_label in val_annotation_data: + from_image_filepath = val_images_dir / image_filename + to_image_dir = val_data_dir / image_label + if not to_image_dir.exists(): + to_image_dir.mkdir() + to_image_filepath = to_image_dir / image_filename + from_image_filepath.rename(to_image_filepath) + val_annotations_file.unlink() + val_images_dir.rmdir() + + def create_data_loaders(self): + dataset_path = TinyImagenetDatasetManager.download_dataset() + + TinyImagenetDatasetManager.prepare_tiny_imagenet_200(dataset_path) + print(f"Successfully downloaded and prepared dataset at: {dataset_path}") + + train_dir = dataset_path / "train" + val_dir = dataset_path / "val" + + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + train_dataset = datasets.ImageFolder( + train_dir, + transforms.Compose( + [ + transforms.Resize(self.image_size), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ] + ), + ) + val_dataset = datasets.ImageFolder( + val_dir, + transforms.Compose( + [ + transforms.Resize(self.image_size), + transforms.ToTensor(), + normalize, + ] + ), + ) + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=0, pin_memory=True, sampler=None + ) + + val_loader = torch.utils.data.DataLoader( + val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=0, pin_memory=True + ) + + # Creating separate dataloader with batch size = 1 + # as dataloaders with batches > 1 are not supported yet. + calibration_dataset = torch.utils.data.DataLoader( + val_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True + ) + + return train_loader, val_loader, calibration_dataset diff --git a/tests/torch_fx/requirements.txt b/tests/torch_fx/requirements.txt new file mode 100644 index 00000000000..99ee43ce754 --- /dev/null +++ b/tests/torch_fx/requirements.txt @@ -0,0 +1 @@ +fastdownload==0.0.7 \ No newline at end of file diff --git a/tests/torch_fx/test_sanity.py b/tests/torch_fx/test_sanity.py new file mode 100644 index 00000000000..197c2f95472 --- /dev/null +++ b/tests/torch_fx/test_sanity.py @@ -0,0 +1,141 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from dataclasses import dataclass +from typing import Tuple + +import numpy as np +import openvino.torch # noqa +import pytest +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +import torchvision.models as models +from torch._export import capture_pre_autograd_graph + +import nncf +from nncf.common.logging.track_progress import track +from nncf.torch.dynamic_graph.patch_pytorch import disable_patching +from tests.torch_fx.helpers import TinyImagenetDatasetManager + +IMAGE_SIZE = 64 +BATCH_SIZE = 128 + + +@dataclass +class SanitySampleCase: + model_id: str + checkpoint_url: str + top1_int8_ref: float + ref_num_q: int + ref_num_dq: int + + +MODELS = ( + SanitySampleCase( + "resnet18", + "https://storage.openvinotoolkit.org/repositories/nncf/openvino_notebook_ckpts/302_resnet18_fp32_v1.pth", + 55.23, + 51, + 58, + ), +) + + +def get_model(model_id: str, checkpoint_url: str, device: torch.device) -> torch.nn.Module: + num_classes = 200 # 200 is for Tiny ImageNet, default is 1000 for ImageNet + model = getattr(models, model_id)(weights=None) + # Update the last FC layer for Tiny ImageNet number of classes. + model.fc = nn.Linear(in_features=512, out_features=num_classes, bias=True) + model.to(device) + checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device("cpu"), progress=False) + model.load_state_dict(checkpoint["state_dict"]) + return model + + +def validate(val_loader: torch.utils.data.DataLoader, model: torch.nn.Module, device: torch.device) -> float: + top1_sum = 0.0 + with torch.no_grad(): + for images, target in track(val_loader, total=len(val_loader), description="Validation:"): + images = images.to(device) + target = target.to(device) + + # Compute output. + output = model(images) + + # Measure accuracy and record loss. + [acc1] = accuracy(output, target, topk=(1,)) + top1_sum += acc1.item() + + num_samples = len(val_loader) + top1_avg = top1_sum / num_samples + return top1_avg + + +def accuracy(output: torch.Tensor, target: torch.tensor, topk: Tuple[int, ...] 
= (1,)): + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def count_q_dq(model: torch.fx.GraphModule): + q, dq = 0, 0 + for node in model.graph.nodes: + if node.op == "call_function" and hasattr(node.target, "overloadpacket"): + node_type = str(node.target.overloadpacket).split(".")[1] + if node_type in ["quantize_per_tensor", "quantize_per_channel"]: + q += 1 + elif node_type in ["dequantize_per_tensor", "dequantize_per_channel"]: + dq += 1 + return q, dq + + +@pytest.mark.parametrize("test_case", MODELS) +def test_sanity(test_case: SanitySampleCase): + with disable_patching(): + device = torch.device("cpu") + model = get_model(test_case.model_id, test_case.checkpoint_url, device) + _, val_dataloader, calibration_dataset = TinyImagenetDatasetManager( + IMAGE_SIZE, BATCH_SIZE + ).create_data_loaders() + + def transform_fn(data_item): + return data_item[0].to(device) + + calibration_dataset = nncf.Dataset(calibration_dataset, transform_fn) + + with torch.no_grad(): + ex_input = next(iter(calibration_dataset.get_inference_data())) + model.eval() + exported_model = capture_pre_autograd_graph(model, args=(ex_input,)) + quantized_model = nncf.quantize(exported_model, calibration_dataset) + quantized_model = torch.compile(quantized_model, backend="openvino") + + top1_int8 = validate(val_dataloader, quantized_model, device) + assert np.isclose(top1_int8, test_case.top1_int8_ref, atol=1e-2) + + num_q, num_dq = count_q_dq(quantized_model) + assert num_q == test_case.ref_num_q + assert num_dq == test_case.ref_num_dq diff --git a/torch_compile_ex_release.py b/torch_compile_ex_release.py new file mode 100644 index 00000000000..7bd0addf02e --- /dev/null +++ b/torch_compile_ex_release.py @@ -0,0 +1,217 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Enable torch inductor freezing feature first +import os + +os.environ["TORCHINDUCTOR_FREEZING"] = "1" + + +import argparse +import copy +import time +from collections import defaultdict + +import openvino.torch # noqa +import torch + +# Optional: using the C++ wrapper instead of default Python wrapper +import torch._inductor.config as config +import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq +import torchvision.models as models +from torch._export import capture_pre_autograd_graph +from torch.ao.quantization.quantize_pt2e import convert_pt2e +from torch.ao.quantization.quantize_pt2e import prepare_pt2e +from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer +from torch.fx.passes.graph_drawer import FxGraphDrawer + +from nncf.experimental.torch_fx.model_transformer import QPARAMPerChannel +from nncf.experimental.torch_fx.model_transformer import QPARAMSPerTensor +from nncf.experimental.torch_fx.model_transformer import insert_qdq_to_model +from nncf.experimental.torch_fx.nncf_graph_builder import GraphConverter # noqa + + +def get_exported_model_from_nn_module(module, example_inputs): + with torch.no_grad(): + return capture_pre_autograd_graph(module, example_inputs) + + +NNCF_IMPL = True + + +def get_qsetup(exported_model, example_inputs): + quantizer = X86InductorQuantizer() + quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) + + prepared_model = prepare_pt2e(exported_model, quantizer) + prepared_model(*example_inputs) + converted_model = convert_pt2e(prepared_model) + g = FxGraphDrawer(converted_model, "resnet18_int8") + g.get_dot_graph().write_svg("resnet18_int8_compiled.svg") + qsetup = defaultdict(lambda: dict()) + + for node in converted_model.graph.nodes: + if "dequantize" in node.name: + quantize = node.all_input_nodes[0] + # place = "activations" + # if len(quantize.all_input_nodes) > 1: + # place = "weights" + if "per_tensor" in node.name: + params = QPARAMSPerTensor(*node.args[1:]) + else: + params = [] + for i in range(1, 3): + name = node.args[i].target + params.append(getattr(converted_model, name)) + params = QPARAMPerChannel(*(params + list(node.args[3:]))) + + target_node_name = quantize.all_input_nodes[0].name + qsetup[target_node_name] = params + return qsetup + + +def quantize(model, example_inputs): + if NNCF_IMPL: + # Use NNCF here on exported model + # to create a quantized model which is compatible with + # convert_pt2e function + pass + # 1. Convert torch.graph to NNCFGraph. + # # 2. Analyze nncf graph for SQ/CA + # # 3. Collect statistics + # # 4. Update params + # 5. Analyze nncf graph for quantization + # 6. Insert observers + # 7. prepared_model(*example_inputs) + # 8. 
convert_pt2e(prepared_model) + import nncf + + calibration_dataset = nncf.Dataset(example_inputs) + exported_model = get_exported_model_from_nn_module(model, example_inputs) + quantized_model = nncf.quantize(exported_model, calibration_dataset) + g = FxGraphDrawer(quantized_model, "resnet18_quantized_native_nncf") + g.get_dot_graph().write_svg("resnet18_quantized_native_nncf.svg") + return quantized_model + + else: + # g = FxGraphDrawer(exported_model, "resnet18") + # g.get_dot_graph().write_svg("resnet18_compiled.svg") + + # MOCK NNCF QUANTIZATION + exported_model = get_exported_model_from_nn_module(model, example_inputs) + qsetup = get_qsetup(exported_model, example_inputs) + exported_model = get_exported_model_from_nn_module(model, example_inputs) + exported_model = insert_qdq_to_model(exported_model, qsetup) + g = FxGraphDrawer(exported_model, "resnet18_int8") + g.get_dot_graph().write_svg("resnet18_int8_compiled_manually.svg") + return exported_model + + return None # converted_model + + +config.cpp_wrapper = True + + +def measure_time(model, example_inputs, num_iters): + with torch.no_grad(): + model(*example_inputs) + total_time = 0 + for i in range(0, num_iters): + start_time = time.time() + model(*example_inputs) + total_time += time.time() - start_time + average_time = (total_time / num_iters) * 1000 + return average_time + + +def get_dummy_dataset(): + traced_bs = 1 + x = torch.randn(traced_bs, 3, 224, 224).contiguous(memory_format=torch.channels_last) + example_inputs = (x,) + return example_inputs + + +def main_nncf(model_name, num_iters): + model = models.__dict__[model_name](pretrained=True) + model = model.eval() + + example_inputs = get_dummy_dataset() + import nncf + + calibration_dataset = nncf.Dataset(example_inputs) + quantized_model = nncf.quantize(model, calibration_dataset) + + import openvino as ov + + ov_model = ov.convert_model(quantized_model.cpu(), example_input=example_inputs[0]) + ov.serialize(ov_model, "./model_cache_nncf/model.xml") + + +def main(model_name, num_iters): + model = models.__dict__[model_name](pretrained=True) + model = model.eval() + + example_inputs = get_dummy_dataset() + + converted_model = quantize(copy.deepcopy(model), example_inputs) + + print("original model execution time: ", measure_time(model, example_inputs, num_iters)) + + native_optimized_model_fp32 = torch.compile(model) + print( + "Torch Inductor FP32 model execution time: ", + measure_time(native_optimized_model_fp32, example_inputs, num_iters), + ) + + native_optimized_model_int8 = torch.compile(converted_model) + print( + "Torch Inductor INT8 model execution time: ", + measure_time(native_optimized_model_int8, example_inputs, num_iters), + ) + + ov_optimized_model_fp32 = torch.compile(model, backend="openvino") + print( + "Torch.compile OpenVINO FP32 model execution time: ", + measure_time(ov_optimized_model_fp32, example_inputs, num_iters), + ) + + ov_optimized_model_int8 = torch.compile( + converted_model, backend="openvino", options={"model_caching": True, "cache_dir": "./model_cache"} + ) + print( + "Torch.compile OpenVINO INT8 model execution time: ", + measure_time(ov_optimized_model_int8, example_inputs, num_iters), + ) + + import intel_extension_for_pytorch # noqa + + ipex_optimized_model_fp32 = torch.compile(model, backend="ipex") + print( + "Torch.compile IPEX FP32 model execution time: ", + measure_time(ipex_optimized_model_fp32, example_inputs, num_iters), + ) + + ipex_optimized_model_int8 = torch.compile(converted_model, backend="ipex") + print( + "Torch.compile 
IPEX INT8 model execution time: ", + measure_time(ipex_optimized_model_int8, example_inputs, num_iters), + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num_iters", help="number of inference iterations", type=int, default=100) + parser.add_argument("--model", help="torchvision model name", type=str, default="resnet18") + args = parser.parse_args() + model_name = args.model + num_iters = args.num_iters + main(model_name, num_iters) + # main_nncf(model_name, num_iters) diff --git a/yolo_fx_bad_metrics_repro.py b/yolo_fx_bad_metrics_repro.py new file mode 100644 index 00000000000..b5c05d6bbcb --- /dev/null +++ b/yolo_fx_bad_metrics_repro.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Tuple + +import numpy as np +import torch +from tqdm import tqdm +from ultralytics.data.utils import check_det_dataset +from ultralytics.engine.validator import BaseValidator as Validator +from ultralytics.models.yolo import YOLO +from ultralytics.utils.torch_utils import de_parallel + + +def print_statistics(stats: np.ndarray, total_images: int, total_objects: int) -> None: + mp, mr, map50, mean_ap = ( + stats["metrics/precision(B)"], + stats["metrics/recall(B)"], + stats["metrics/mAP50(B)"], + stats["metrics/mAP50-95(B)"], + ) + s = ("%20s" + "%12s" * 6) % ("Class", "Images", "Labels", "Precision", "Recall", "mAP@.5", "mAP@.5:.95") + print(s) + pf = "%20s" + "%12i" * 2 + "%12.3g" * 4 # print format + print(pf % ("all", total_images, total_objects, mp, mr, map50, mean_ap)) + + +def prepare_validation(model: YOLO, data: str) -> Tuple[Validator, torch.utils.data.DataLoader]: + # custom = {"rect": True, "batch": 1} # method defaults + # rect: false forces to resize all input pictures to one size + custom = {"rect": False, "batch": 1} # method defaults + args = {**model.overrides, **custom, "mode": "val"} # highest priority args on the right + + validator = model._smart_load("validator")(args=args, _callbacks=model.callbacks) + stride = 32 # default stride + validator.stride = stride # used in get_dataloader() for padding + validator.data = check_det_dataset(data) + validator.init_metrics(de_parallel(model)) + + data_loader = validator.get_dataloader(validator.data.get(validator.args.split), validator.args.batch) + return validator, data_loader + + +def validate(model, data_loader: torch.utils.data.DataLoader, validator: Validator) -> Tuple[Dict, int, int]: + with torch.no_grad(): + for batch in data_loader: + batch = validator.preprocess(batch) + preds = model(batch["img"]) + preds = validator.postprocess(preds) + validator.update_metrics(preds, batch) + stats = validator.get_stats() + return stats, validator.seen, validator.nt_per_class.sum() + + +def main(torch_fx): + # ultralytics @ git+https://github.com/THU-MIG/yolov10.git@2c36ab0f108efdd17c7e290564bb845ccb6844d8 + # pip install git+https://github.com/THU-MIG/yolov10.git + # pip install huggingface-hub + # yolo_model = 
YOLO("yolov10n.pt") + + yolo_model = YOLO("yolov8n") + + model_type = "torch" + model = yolo_model.model + if torch_fx: + model = torch.compile(model) + model_type = "FX" + print(f"FP32 {model_type} model validation results:") + validator, data_loader = prepare_validation(yolo_model, "coco128.yaml") + stats, total_images, total_objects = validate(model, tqdm(data_loader), validator) + print_statistics(stats, total_images, total_objects) + + +if __name__ == "__main__": + print("Torch model:") + main(torch_fx=False) + print("Torch FX model:") + main(torch_fx=True)
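
Usage note (illustrative, not part of the patch): the YOLO repro above only runs the FP32 model, optionally under torch.compile. The sketch below shows how the same model could be pushed through the experimental TORCH_FX quantization path added by this patch; it mirrors tests/torch_fx/test_sanity.py, and the transform_fn/batch layout (a dict with an "img" tensor, no validator preprocessing) is an assumption rather than a verified recipe.

import nncf
import torch
from torch._export import capture_pre_autograd_graph

from nncf.torch.dynamic_graph.patch_pytorch import disable_patching


def quantize_yolo_fx(model: torch.nn.Module, data_loader) -> torch.nn.Module:
    # Hypothetical transform: assumes each batch is a dict holding an "img" tensor,
    # as produced by the ultralytics validator dataloader used in the repro.
    def transform_fn(batch):
        return batch["img"]

    calibration_dataset = nncf.Dataset(data_loader, transform_fn)
    example_input = next(iter(calibration_dataset.get_inference_data()))

    with disable_patching(), torch.no_grad():
        model.eval()
        # Capture the pre-autograd FX graph expected by the TORCH_FX backend.
        exported_model = capture_pre_autograd_graph(model, args=(example_input,))
        # nncf.quantize() dispatches to the experimental torch_fx quantize_impl for FX graphs.
        quantized_model = nncf.quantize(exported_model, calibration_dataset)
        # Compile the quantized graph with the OpenVINO backend, as in the sanity test.
        return torch.compile(quantized_model, backend="openvino")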