# Copyright 2025 ETH Zurich and University of Bologna.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
#
# Federico Brancasi <[email protected]>

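"""Accuracy comparison for ViT-B/32 on ImageNetV2: the original torchvision model
vs. its fake-quantized (FQ) Brevitas counterpart vs. the true-quantized (TQ) model
produced by DeepQuant."""
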
import tarfile
import urllib.request
from pathlib import Path

import brevitas.nn as qnn
import pytest
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from brevitas.graph.calibrate import calibration_mode
from brevitas.graph.per_input import AdaptiveAvgPoolToAvgPool
from brevitas.graph.quantize import preprocess_for_quantize, quantize
from brevitas.quant import (
    Int8ActPerTensorFloat,
    Int8WeightPerTensorFloat,
    Int32Bias,
    Uint8ActPerTensorFloat,
)
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder
from tqdm import tqdm

from DeepQuant import brevitasToTrueQuant


def evaluateModel(model, dataLoader, evalDevice, name="Model"):
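    """Compute Top-1 and Top-5 accuracy of `model` on `dataLoader`.

    Models whose name contains "TQ" are evaluated one image at a time; all
    other models are evaluated in full batches. Returns accuracies in percent.
    """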
    model.eval()
    correctTop1 = 0
    correctTop5 = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in tqdm(dataLoader, desc=f"Evaluating {name}"):
            isTQ = "TQ" in name

            if isTQ:
                # FBRANCASI: The TQ model is exported with a batch-1 sample
                # input, so evaluate it one image at a time
                for i in range(inputs.size(0)):
                    singleInput = inputs[i : i + 1].to(evalDevice)
                    singleOutput = model(singleInput)
                    if isinstance(singleOutput, tuple):
                        singleOutput = singleOutput[0]

                    _, predicted = singleOutput.max(1)
                    if predicted.item() == targets[i].item():
                        correctTop1 += 1

                    _, top5Pred = singleOutput.topk(5, dim=1, largest=True, sorted=True)
                    if targets[i].item() in top5Pred[0].cpu().numpy():
                        correctTop5 += 1

                    total += 1
            else:
                inputs = inputs.to(evalDevice)
                targets = targets.to(evalDevice)
                output = model(inputs)
                if isinstance(output, tuple):
                    output = output[0]

                _, predicted = output.max(1)
                correctTop1 += (predicted == targets).sum().item()

                _, top5Pred = output.topk(5, dim=1, largest=True, sorted=True)
                for i in range(targets.size(0)):
                    if targets[i] in top5Pred[i]:
                        correctTop5 += 1

                total += targets.size(0)

    top1Accuracy = 100.0 * correctTop1 / total
    top5Accuracy = 100.0 * correctTop5 / total

    print(
        f"{name} - Top-1 Accuracy: {top1Accuracy:.2f}% ({correctTop1}/{total}), "
        f"Top-5 Accuracy: {top5Accuracy:.2f}%"
    )

    return top1Accuracy, top5Accuracy


def calibrateModel(model, calibLoader):
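    """Run one pass over `calibLoader` in Brevitas calibration mode to collect
    the activation statistics used to initialize the quantizers' scales."""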
    model.eval()
    with torch.no_grad(), calibration_mode(model):
        for inputs, _ in tqdm(calibLoader, desc="Calibrating model"):
            inputs = inputs.to("cpu")
            output = model(inputs)
            if isinstance(output, tuple):
                output = output[0]
    print("Calibration completed.")


def prepareFQVitB32():
    """Prepare a fake-quantized (FQ) ViT-B/32 model."""
    baseModel = torchvision.models.vit_b_32(
        weights=torchvision.models.ViT_B_32_Weights.IMAGENET1K_V1
    )
    baseModel = baseModel.eval().to("cpu")

    computeLayerMap = {
        nn.Conv2d: (
            qnn.QuantConv2d,
            {
                "input_quant": Int8ActPerTensorFloat,
                "weight_quant": Int8WeightPerTensorFloat,
                "output_quant": Int8ActPerTensorFloat,
                "bias_quant": Int32Bias,
                "bias": True,
                "return_quant_tensor": True,
                "output_bit_width": 8,
                "weight_bit_width": 8,
            },
        ),
        nn.MultiheadAttention: (
            qnn.QuantMultiheadAttention,
            {
                "in_proj_input_quant": Int8ActPerTensorFloat,
                "in_proj_weight_quant": Int8WeightPerTensorFloat,
                "in_proj_bias_quant": Int32Bias,
                "attn_output_weights_quant": Uint8ActPerTensorFloat,
                "q_scaled_quant": Int8ActPerTensorFloat,
                "k_transposed_quant": Int8ActPerTensorFloat,
                "v_quant": Int8ActPerTensorFloat,
                "out_proj_input_quant": Int8ActPerTensorFloat,
                "out_proj_weight_quant": Int8WeightPerTensorFloat,
                "out_proj_bias_quant": Int32Bias,
                "out_proj_output_quant": Int8ActPerTensorFloat,
                "return_quant_tensor": True,
            },
        ),
        nn.Linear: (
            qnn.QuantLinear,
            {
                "input_quant": Int8ActPerTensorFloat,
                "weight_quant": Int8WeightPerTensorFloat,
                "output_quant": Int8ActPerTensorFloat,
                "bias_quant": Int32Bias,
                "bias": True,
                "return_quant_tensor": True,
                "output_bit_width": 8,
                "weight_bit_width": 8,
            },
        ),
    }

    quantActMap = {
        nn.GELU: (
            qnn.QuantReLU,  # FBRANCASI: Approximating GELU with QuantReLU
            {
                "act_quant": Uint8ActPerTensorFloat,
                "return_quant_tensor": True,
                "bit_width": 8,
            },
        ),
    }

    quantIdentityMap = {
        "signed": (
            qnn.QuantIdentity,
            {
                "act_quant": Int8ActPerTensorFloat,
                "return_quant_tensor": True,
                "bit_width": 8,
            },
        ),
        "unsigned": (
            qnn.QuantIdentity,
            {
                "act_quant": Uint8ActPerTensorFloat,
                "return_quant_tensor": True,
                "bit_width": 8,
            },
        ),
    }

    dummyInput = torch.ones(1, 3, 224, 224).to("cpu")

    print("Preprocessing model for quantization...")
    baseModel = preprocess_for_quantize(
        baseModel, equalize_iters=20, equalize_scale_computation="range"
    )

    print("Converting AdaptiveAvgPool to AvgPool...")
    baseModel = AdaptiveAvgPoolToAvgPool().apply(baseModel, dummyInput)

    print("Quantizing model...")
    FQModel = quantize(
        graph_model=baseModel,
        compute_layer_map=computeLayerMap,
        quant_act_map=quantActMap,
        quant_identity_map=quantIdentityMap,
    )

    return FQModel


@pytest.mark.ModelTests
def testDeepQuantVitB32Pretrained() -> None:
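    """Compare the original, FQ, and TQ ViT-B/32 on the ImageNetV2 validation set."""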
    HOME = Path.home()
    BASE = HOME / "Documents" / "ImagenetV2"
    TAR_URL = (
        "https://huggingface.co/datasets/vaishaal/ImageNetV2/resolve/main/"
        "imagenetv2-matched-frequency.tar.gz"
    )
    TAR_PATH = BASE / "imagenetv2-matched-frequency.tar.gz"
    EXTRACT_DIR = BASE / "imagenetv2-matched-frequency-format-val"

    if not TAR_PATH.exists():
        BASE.mkdir(parents=True, exist_ok=True)
        print(f"Downloading ImageNetV2 from {TAR_URL}...")
        urllib.request.urlretrieve(TAR_URL, TAR_PATH)

    if not EXTRACT_DIR.exists():
        print(f"Extracting to {EXTRACT_DIR}...")
        with tarfile.open(TAR_PATH, "r:*") as tar:
            for member in tqdm(tar.getmembers(), desc="Extracting files"):
                tar.extract(member, BASE)
        print("Extraction completed.")

    transformsVal = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    dataset = ImageFolder(root=str(EXTRACT_DIR), transform=transformsVal)
    dataset.classes = sorted(dataset.classes, key=lambda x: int(x))
    dataset.class_to_idx = {cls: i for i, cls in enumerate(dataset.classes)}

    newSamples = []
    for path, _ in dataset.samples:
        clsName = Path(path).parent.name
        newLabel = dataset.class_to_idx[clsName]
        newSamples.append((path, newLabel))
    dataset.samples = newSamples
    dataset.targets = [s[1] for s in newSamples]

    # FBRANCASI: Optional: limit the number of images for faster validation
    DATASET_LIMIT = 256
    dataset = Subset(dataset, list(range(DATASET_LIMIT)))
    print(f"Validation dataset size set to {len(dataset)} images.")

    calibLoader = DataLoader(
        Subset(dataset, list(range(256))), batch_size=32, shuffle=False, pin_memory=True
    )
    valLoader = DataLoader(dataset, batch_size=32, shuffle=False, pin_memory=True)

    # FBRANCASI: Running on macOS, so prefer MPS when it is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device("mps" if torch.backends.mps.is_available() else device)
    print(f"Using device: {device}")

    originalModel = torchvision.models.vit_b_32(
        weights=torchvision.models.ViT_B_32_Weights.IMAGENET1K_V1
    )
    originalModel = originalModel.eval().to(device)
    print("Original ViT-B/32 loaded.")

    print("Evaluating original model...")
    originalTop1, originalTop5 = evaluateModel(
        originalModel, valLoader, device, "Original ViT-B/32"
    )

    print("Preparing and quantizing ViT-B/32...")
    FQModel = prepareFQVitB32()

    print("Calibrating FQ model...")
    calibrateModel(FQModel, calibLoader)

    print("Evaluating FQ model...")
    # FBRANCASI: MPS does not work with Brevitas, so fall back to CUDA/CPU here
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    FQModel = FQModel.to(device)  # keep the model on the same device as the inputs
    FQTop1, FQTop5 = evaluateModel(FQModel, valLoader, device, "FQ ViT-B/32")

    # Trace the export on the same device as the FQ model
    sampleInputImg = torch.randn(1, 3, 224, 224).to(device)
    TQModel = brevitasToTrueQuant(FQModel, sampleInputImg, debug=True)

    numParameters = sum(p.numel() for p in TQModel.parameters())
    print(f"Number of parameters: {numParameters:,}")

    print("Evaluating TQ model...")
    TQTop1, TQTop5 = evaluateModel(TQModel, valLoader, device, "TQ ViT-B/32")

    print("\nComparison Summary:")
    print(f"{'Model':<25} {'Top-1 Accuracy':<25} {'Top-5 Accuracy':<25}")
    print("-" * 75)
    print(f"{'Original ViT-B/32':<25} {originalTop1:<24.2f} {originalTop5:<24.2f}")
    print(f"{'FQ ViT-B/32':<25} {FQTop1:<24.2f} {FQTop5:<24.2f}")
    print(f"{'TQ ViT-B/32':<25} {TQTop1:<24.2f} {TQTop5:<24.2f}")
    print(
        f"{'FQ Drop':<25} {originalTop1 - FQTop1:<24.2f} {originalTop5 - FQTop5:<24.2f}"
    )
    print(
        f"{'TQ Drop':<25} {originalTop1 - TQTop1:<24.2f} {originalTop5 - TQTop5:<24.2f}"
    )

    if abs(FQTop1 - TQTop1) > 5.0 or abs(FQTop5 - TQTop5) > 5.0:
        print(
            f"Warning: Large accuracy drop between FQ and TQ models. "
            f"Top-1 difference: {abs(FQTop1 - TQTop1):.2f}%, "
            f"Top-5 difference: {abs(FQTop5 - TQTop5):.2f}%"
        )