Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 34 additions & 12 deletions examples/quantization_aware_training/torch/anomalib/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,40 @@
import re
import subprocess
import tarfile
import warnings
from copy import deepcopy
from pathlib import Path
from urllib.request import urlretrieve

import torch
from anomalib import TaskType
from anomalib.data import MVTec
from anomalib.data.image import mvtec
from anomalib.data import MVTecAD
from anomalib.data.utils import DownloadInfo
from anomalib.data.utils import download
from anomalib.deploy import ExportType
from anomalib.engine import Engine
from anomalib.models import Stfpm

import nncf

warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)

HOME_PATH = Path.home()
DATASET_PATH = HOME_PATH / ".cache" / "nncf" / "datasets" / "mvtec"
CHECKPOINT_PATH = HOME_PATH / ".cache" / "nncf" / "models" / "anomalib"
ROOT = Path(__file__).parent.resolve()
FP32_RESULTS_ROOT = ROOT / "results" / "fp32"
INT8_RESULTS_ROOT = ROOT / "results" / "int8"
CHECKPOINT_URL = "https://storage.openvinotoolkit.org/repositories/nncf/examples/torch/anomalib/stfpm_mvtec.ckpt"
CHECKPOINT_URL = "https://storage.openvinotoolkit.org/repositories/nncf/examples/torch/anomalib/stfpm_mvtec_v2.ckpt"
USE_PRETRAINED = True

# Can be replaced to "from anomalib.data.datamodules.image.mvtecad import DOWNLOAD_INFO" on bump anomalib version
DOWNLOAD_INFO = DownloadInfo(
name="mvtecad",
url="https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f283/"
"download/420938113-1629960298/mvtec_anomaly_detection.tar.xz",
hashsum="cf4313b13603bec67abb49ca959488f7eedce2a9f7795ec54446c649ac98cd3d",
)


def download_and_extract(root: Path, info: download.DownloadInfo) -> None:
root.mkdir(parents=True, exist_ok=True)
Expand All @@ -53,10 +63,12 @@ def download_and_extract(root: Path, info: download.DownloadInfo) -> None:
downloaded_file_path.unlink()


def create_dataset(root: Path) -> MVTec:
def create_dataset(root: Path) -> MVTecAD:
if not root.exists():
download_and_extract(root, mvtec.DOWNLOAD_INFO)
return MVTec(root)
download_and_extract(root, DOWNLOAD_INFO)
data = MVTecAD(root, category="bottle")
data.setup()
return data


def run_benchmark(model_path: Path, shape: list[int]) -> float:
Expand Down Expand Up @@ -98,7 +110,7 @@ def main():
datamodule = create_dataset(root=DATASET_PATH)

# Create an engine for the original model
engine = Engine(task=TaskType.SEGMENTATION, default_root_dir=FP32_RESULTS_ROOT, devices=1)
engine = Engine(default_root_dir=FP32_RESULTS_ROOT, devices=1)
if USE_PRETRAINED:
# Load the pretrained checkpoint
CHECKPOINT_PATH.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -132,7 +144,7 @@ def transform_fn(data_item):
quantized_model.model = quantized_inference_model

# Create engine for the quantized model
engine = Engine(task=TaskType.SEGMENTATION, default_root_dir=INT8_RESULTS_ROOT, max_epochs=1, devices=1)
engine = Engine(default_root_dir=INT8_RESULTS_ROOT, max_epochs=1, devices=1)

# Validate the quantized model
print("Test results for INT8 model after PTQ:")
Expand All @@ -141,7 +153,7 @@ def transform_fn(data_item):
###############################################################################
# Step 3: Fine tune the quantized model
print(os.linesep + "[Step 3] Fine tune the quantized model")

quantized_model.train()
engine.fit(model=quantized_model, datamodule=datamodule)
print("Test results for INT8 model after QAT:")
int8_test_results = engine.test(model=quantized_model, datamodule=datamodule)
Expand All @@ -151,12 +163,22 @@ def transform_fn(data_item):
print(os.linesep + "[Step 4] Export models")

# Export FP32 model to OpenVINO™ IR
fp32_ir_path = engine.export(model=model, export_type=ExportType.OPENVINO, export_root=FP32_RESULTS_ROOT)
fp32_ir_path = engine.export(
model=model,
export_type=ExportType.OPENVINO,
export_root=FP32_RESULTS_ROOT,
onnx_kwargs={"dynamo": False},
)
print(f"Original model path: {fp32_ir_path}")
fp32_size = get_model_size(fp32_ir_path)

# Export INT8 model to OpenVINO™ IR
int8_ir_path = engine.export(model=quantized_model, export_type=ExportType.OPENVINO, export_root=INT8_RESULTS_ROOT)
int8_ir_path = engine.export(
model=quantized_model,
export_type=ExportType.OPENVINO,
export_root=INT8_RESULTS_ROOT,
onnx_kwargs={"dynamo": False},
)
print(f"Quantized model path: {int8_ir_path}")
int8_size = get_model_size(int8_ir_path)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
anomalib[core,openvino]==1.0.1
matplotlib<3.10.0
numpy==1.26.4
safetensors<=0.5.3
anomalib==2.2.0
torch==2.9.0
openvino==2025.3.0
requests==2.32.5
matplotlib==3.10.7
numpy==2.2.6
safetensors==0.6.2
onnx==1.17.0
timm==1.0.22
3 changes: 1 addition & 2 deletions tests/cross_fw/examples/example_scope.json
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,7 @@
"fp32_model_size": 21.37990665435791,
"int8_model_size": 5.677968978881836,
"model_compression_rate": 3.7654144877995197
},
"xfail": "https://github.com/open-edge-platform/anomalib/issues/3121"
}
},
"fp8_llm_quantization": {
"backend": "openvino",
Expand Down
5 changes: 2 additions & 3 deletions tests/cross_fw/examples/run_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,14 +329,13 @@ def set_torch_cuda_seed(seed: int = 42):


def quantization_aware_training_torch_anomalib(data: Union[str, None]):
from anomalib.data.image import mvtec

from examples.quantization_aware_training.torch.anomalib.main import DATASET_PATH as dataset_path
from examples.quantization_aware_training.torch.anomalib.main import DOWNLOAD_INFO
from examples.quantization_aware_training.torch.anomalib.main import main as anomalib_main

if data is not None and not dataset_path.exists():
dataset_path.mkdir(parents=True, exist_ok=True)
tar_file_path = Path(data) / mvtec.DOWNLOAD_INFO.url.split("/")[-1]
tar_file_path = Path(data) / DOWNLOAD_INFO.url.split("/")[-1]
with tarfile.open(tar_file_path) as tar_file:
tar_file.extractall(dataset_path)

Expand Down