Commit 856f249

Add ViTB32 Test

1 parent: b2c6fe2

4 files changed: +328 −5 lines changed

DeepQuant/CustomForwards/MultiHeadAttention.py

Lines changed: 0 additions & 2 deletions
@@ -110,7 +110,6 @@ def mhaForwardBatchFirst(
     attn_output = _mhaForwardImpl(
         self, query, key, value, need_transpose_in=True, need_transpose_out=True
     )
-    # PyTorch always returns a tuple, even when need_weights=False
     return (attn_output, None)


@@ -126,7 +125,6 @@ def mhaForwardSeqFirst(
     attn_output = _mhaForwardImpl(
         self, query, key, value, need_transpose_in=False, need_transpose_out=False
     )
-    # PyTorch always returns a tuple, even when need_weights=False
     return (attn_output, None)
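Note: the two deleted comments restated a contract that the return statements still honor: nn.MultiheadAttention.forward always returns a 2-tuple, and the attention-weights slot is simply None when need_weights=False. A minimal sketch of that stock PyTorch behavior (module sizes here are illustrative, not from this repo):

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
x = torch.randn(2, 10, 64)  # (batch, seq, embed)

# Even with need_weights=False the call yields a 2-tuple;
# the second element is None rather than being dropped.
out, weights = mha(x, x, x, need_weights=False)
assert out.shape == (2, 10, 64)
assert weights is None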

DeepQuant/Pipeline/Injection.py

Lines changed: 0 additions & 1 deletion
@@ -51,7 +51,6 @@ def injectCustomForwards(
     output = fxModel(exampleInput)

     if checkEquivalence:
-        # Handle case where output might be a tuple (e.g., from MHA)
         outputToCompare = output[0] if isinstance(output, tuple) else output
         if torch.allclose(referenceOutput, outputToCompare, atol=1e-5):
             if debug:
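The surviving lines are the usual FX equivalence check: unwrap a possible tuple output, then compare against the reference within a small absolute tolerance. The same pattern in isolation, with torch.fx.symbolic_trace standing in for DeepQuant's injected module (a sketch, not the project's API):

import torch
import torch.fx
import torch.nn as nn

refModel = nn.Linear(8, 8).eval()
fxModel = torch.fx.symbolic_trace(refModel)  # transformed copy to validate

exampleInput = torch.randn(1, 8)
with torch.no_grad():
    referenceOutput = refModel(exampleInput)
    output = fxModel(exampleInput)

# Unwrap tuple outputs (e.g., from MHA) before comparing.
outputToCompare = output[0] if isinstance(output, tuple) else output
assert torch.allclose(referenceOutput, outputToCompare, atol=1e-5)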

Tests/TestVitB32.py

Lines changed: 17 additions & 2 deletions
@@ -1,3 +1,9 @@
+# Copyright 2025 ETH Zurich and University of Bologna.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Federico Brancasi <[email protected]>
+
 import brevitas.nn as qnn
 import pytest
 import torch
@@ -113,12 +119,21 @@ def prepare_vit_b_32(model: nn.Module) -> nn.Module:

 @pytest.mark.ModelTests
 def deepQuantTestViT():
+    torch.manual_seed(42)
+    sampleInput = torch.randn(1, 3, 224, 224)

     vit_model = models.vit_b_32(weights=models.ViT_B_32_Weights.IMAGENET1K_V1)
-
     vit_model.eval()

+    print(f"\nTesting ViT-B/32 model with input shape: {sampleInput.shape}")
+
     quantized_vit = prepare_vit_b_32(vit_model)

-    sampleInput = torch.randn(1, 3, 224, 224)
+    with torch.no_grad():
+        output = quantized_vit(sampleInput)
+        if isinstance(output, tuple):
+            output = output[0]
+        print(f"Output shape: {output.shape}")
+        print(f"Output range: [{output.min().item():.3f}, {output.max().item():.3f}]")
+
     brevitasToTrueQuant(quantized_vit, sampleInput, debug=True, checkEquivalence=False)
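The test now follows a common smoke-test pattern: seed the RNG, run one no_grad forward pass, unwrap a possible tuple, and sanity-check shape and value range before conversion. The same pattern in isolation, with an unquantized torchvision model standing in for quantized_vit (weights=None avoids a download; shapes assume ImageNet-1k):

import torch
import torchvision.models as models

torch.manual_seed(42)  # reproducible random input
sampleInput = torch.randn(1, 3, 224, 224)

model = models.vit_b_32(weights=None).eval()  # random init, no download

with torch.no_grad():
    output = model(sampleInput)
    if isinstance(output, tuple):
        output = output[0]

assert output.shape == (1, 1000)  # ImageNet-1k logits
print(f"Output range: [{output.min().item():.3f}, {output.max().item():.3f}]")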

Tests/TestVitB32Pretrained.py

Lines changed: 311 additions & 0 deletions
@@ -0,0 +1,311 @@
+# Copyright 2025 ETH Zurich and University of Bologna.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Federico Brancasi <[email protected]>
+
+import tarfile
+import urllib.request
+from pathlib import Path
+
+import brevitas.nn as qnn
+import pytest
+import torch
+import torch.nn as nn
+import torchvision
+import torchvision.transforms as transforms
+from brevitas.graph.calibrate import calibration_mode
+from brevitas.graph.per_input import AdaptiveAvgPoolToAvgPool
+from brevitas.graph.quantize import preprocess_for_quantize, quantize
+from brevitas.quant import (
+    Int8ActPerTensorFloat,
+    Int8WeightPerTensorFloat,
+    Int32Bias,
+    Uint8ActPerTensorFloat,
+)
+from torch.utils.data import DataLoader, Subset
+from torchvision.datasets import ImageFolder
+from tqdm import tqdm
+
+from DeepQuant import brevitasToTrueQuant
+
+
def evaluateModel(model, dataLoader, evalDevice, name="Model"):
34+
model.eval()
35+
correctTop1 = 0
36+
correctTop5 = 0
37+
total = 0
38+
39+
with torch.no_grad():
40+
for inputs, targets in tqdm(dataLoader, desc=f"Evaluating {name}"):
41+
isTQ = "TQ" in name
42+
43+
if isTQ:
44+
# FBRANCASI: Process different batches for the TQ model
45+
for i in range(inputs.size(0)):
46+
singleInput = inputs[i : i + 1].to(evalDevice)
47+
singleOutput = model(singleInput)
48+
if isinstance(singleOutput, tuple):
49+
singleOutput = singleOutput[0]
50+
51+
_, predicted = singleOutput.max(1)
52+
if predicted.item() == targets[i].item():
53+
correctTop1 += 1
54+
55+
_, top5Pred = singleOutput.topk(5, dim=1, largest=True, sorted=True)
56+
if targets[i].item() in top5Pred[0].cpu().numpy():
57+
correctTop5 += 1
58+
59+
total += 1
60+
else:
61+
inputs = inputs.to(evalDevice)
62+
targets = targets.to(evalDevice)
63+
output = model(inputs)
64+
if isinstance(output, tuple):
65+
output = output[0]
66+
67+
_, predicted = output.max(1)
68+
correctTop1 += (predicted == targets).sum().item()
69+
70+
_, top5Pred = output.topk(5, dim=1, largest=True, sorted=True)
71+
for i in range(targets.size(0)):
72+
if targets[i] in top5Pred[i]:
73+
correctTop5 += 1
74+
75+
total += targets.size(0)
76+
77+
top1Accuracy = 100.0 * correctTop1 / total
78+
top5Accuracy = 100.0 * correctTop5 / total
79+
80+
print(
81+
f"{name} - Top-1 Accuracy: {top1Accuracy:.2f}% ({correctTop1}/{total}), "
82+
f"Top-5 Accuracy: {top5Accuracy:.2f}%"
83+
)
84+
85+
return top1Accuracy, top5Accuracy
86+
87+
88+
+def calibrateModel(model, calibLoader):
+    model.eval()
+    with torch.no_grad(), calibration_mode(model):
+        for inputs, _ in tqdm(calibLoader, desc="Calibrating model"):
+            inputs = inputs.to("cpu")
+            output = model(inputs)
+            if isinstance(output, tuple):
+                output = output[0]
+    print("Calibration completed.")
+
+
+def prepareFQVitB32():
+    """Prepare a fake-quantized (FQ) ViT-B/32 model."""
+    baseModel = torchvision.models.vit_b_32(
+        weights=torchvision.models.ViT_B_32_Weights.IMAGENET1K_V1
+    )
+    baseModel = baseModel.eval().to("cpu")
+
+    computeLayerMap = {
+        nn.Conv2d: (
+            qnn.QuantConv2d,
+            {
+                "input_quant": Int8ActPerTensorFloat,
+                "weight_quant": Int8WeightPerTensorFloat,
+                "output_quant": Int8ActPerTensorFloat,
+                "bias_quant": Int32Bias,
+                "bias": True,
+                "return_quant_tensor": True,
+                "output_bit_width": 8,
+                "weight_bit_width": 8,
+            },
+        ),
+        nn.MultiheadAttention: (
+            qnn.QuantMultiheadAttention,
+            {
+                "in_proj_input_quant": Int8ActPerTensorFloat,
+                "in_proj_weight_quant": Int8WeightPerTensorFloat,
+                "in_proj_bias_quant": Int32Bias,
+                "attn_output_weights_quant": Uint8ActPerTensorFloat,
+                "q_scaled_quant": Int8ActPerTensorFloat,
+                "k_transposed_quant": Int8ActPerTensorFloat,
+                "v_quant": Int8ActPerTensorFloat,
+                "out_proj_input_quant": Int8ActPerTensorFloat,
+                "out_proj_weight_quant": Int8WeightPerTensorFloat,
+                "out_proj_bias_quant": Int32Bias,
+                "out_proj_output_quant": Int8ActPerTensorFloat,
+                "return_quant_tensor": True,
+            },
+        ),
+        nn.Linear: (
+            qnn.QuantLinear,
+            {
+                "input_quant": Int8ActPerTensorFloat,
+                "weight_quant": Int8WeightPerTensorFloat,
+                "output_quant": Int8ActPerTensorFloat,
+                "bias_quant": Int32Bias,
+                "bias": True,
+                "return_quant_tensor": True,
+                "output_bit_width": 8,
+                "weight_bit_width": 8,
+            },
+        ),
+    }
+
+    quantActMap = {
+        nn.GELU: (
+            qnn.QuantReLU,  # FBRANCASI: Approximating GELU with QuantReLU
+            {
+                "act_quant": Uint8ActPerTensorFloat,
+                "return_quant_tensor": True,
+                "bit_width": 8,
+            },
+        ),
+    }
+
+    quantIdentityMap = {
+        "signed": (
+            qnn.QuantIdentity,
+            {
+                "act_quant": Int8ActPerTensorFloat,
+                "return_quant_tensor": True,
+                "bit_width": 8,
+            },
+        ),
+        "unsigned": (
+            qnn.QuantIdentity,
+            {
+                "act_quant": Uint8ActPerTensorFloat,
+                "return_quant_tensor": True,
+                "bit_width": 8,
+            },
+        ),
+    }
+
+    dummyInput = torch.ones(1, 3, 224, 224).to("cpu")
+
+    print("Preprocessing model for quantization...")
+    baseModel = preprocess_for_quantize(
+        baseModel, equalize_iters=20, equalize_scale_computation="range"
+    )
+
+    print("Converting AdaptiveAvgPool to AvgPool...")
+    baseModel = AdaptiveAvgPoolToAvgPool().apply(baseModel, dummyInput)
+
+    print("Quantizing model...")
+    FQModel = quantize(
+        graph_model=baseModel,
+        compute_layer_map=computeLayerMap,
+        quant_act_map=quantActMap,
+        quant_identity_map=quantIdentityMap,
+    )
+
+    return FQModel
+
+
+@pytest.mark.ModelTests
+def deepQuantTestVitB32Pretrained() -> None:
+    HOME = Path.home()
+    BASE = HOME / "Documents" / "ImagenetV2"
+    TAR_URL = (
+        "https://huggingface.co/datasets/vaishaal/ImageNetV2/resolve/main/"
+        "imagenetv2-matched-frequency.tar.gz"
+    )
+    TAR_PATH = BASE / "imagenetv2-matched-frequency.tar.gz"
+    EXTRACT_DIR = BASE / "imagenetv2-matched-frequency-format-val"
+
+    if not TAR_PATH.exists():
+        BASE.mkdir(parents=True, exist_ok=True)
+        print(f"Downloading ImageNetV2 from {TAR_URL}...")
+        urllib.request.urlretrieve(TAR_URL, TAR_PATH)
+
+    if not EXTRACT_DIR.exists():
+        print(f"Extracting to {EXTRACT_DIR}...")
+        with tarfile.open(TAR_PATH, "r:*") as tar:
+            for member in tqdm(tar.getmembers(), desc="Extracting files"):
+                tar.extract(member, BASE)
+        print("Extraction completed.")
+
+    transformsVal = transforms.Compose(
+        [
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+        ]
+    )
+
+    dataset = ImageFolder(root=str(EXTRACT_DIR), transform=transformsVal)
+    dataset.classes = sorted(dataset.classes, key=lambda x: int(x))
+    dataset.class_to_idx = {cls: i for i, cls in enumerate(dataset.classes)}
+
+    newSamples = []
+    for path, _ in dataset.samples:
+        clsName = Path(path).parent.name
+        newLabel = dataset.class_to_idx[clsName]
+        newSamples.append((path, newLabel))
+    dataset.samples = newSamples
+    dataset.targets = [s[1] for s in newSamples]
+
+    # FBRANCASI: Optional: reduce the number of examples for faster validation
+    DATASET_LIMIT = 256
+    dataset = Subset(dataset, list(range(DATASET_LIMIT)))
+    print(f"Validation dataset size set to {len(dataset)} images.")
+
+    calibLoader = DataLoader(
+        Subset(dataset, list(range(256))), batch_size=32, shuffle=False, pin_memory=True
+    )
+    valLoader = DataLoader(dataset, batch_size=32, shuffle=False, pin_memory=True)
+
+    # FBRANCASI: I'm on a Mac, so MPS for me
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = torch.device("mps" if torch.backends.mps.is_available() else device)
+    print(f"Using device: {device}")
+
+    originalModel = torchvision.models.vit_b_32(
+        weights=torchvision.models.ViT_B_32_Weights.IMAGENET1K_V1
+    )
+    originalModel = originalModel.eval().to(device)
+    print("Original ViT-B/32 loaded.")
+
+    print("Evaluating original model...")
+    originalTop1, originalTop5 = evaluateModel(
+        originalModel, valLoader, device, "Original ViT-B/32"
+    )
+
+    print("Preparing and quantizing ViT-B/32...")
+    FQModel = prepareFQVitB32()
+
+    print("Calibrating FQ model...")
+    calibrateModel(FQModel, calibLoader)
+
+    print("Evaluating FQ model...")
+    # FBRANCASI: I'm on a Mac; MPS doesn't work with Brevitas
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    FQTop1, FQTop5 = evaluateModel(FQModel, valLoader, device, "FQ ViT-B/32")
+
+    sampleInputImg = torch.randn(1, 3, 224, 224).to("cpu")
+    TQModel = brevitasToTrueQuant(FQModel, sampleInputImg, debug=True)
+
+    numParameters = sum(p.numel() for p in TQModel.parameters())
+    print(f"Number of parameters: {numParameters:,}")
+
+    print("Evaluating TQ model...")
+    TQTop1, TQTop5 = evaluateModel(TQModel, valLoader, device, "TQ ViT-B/32")
+
+    print("\nComparison Summary:")
+    print(f"{'Model':<25} {'Top-1 Accuracy':<25} {'Top-5 Accuracy':<25}")
+    print("-" * 75)
+    print(f"{'Original ViT-B/32':<25} {originalTop1:<24.2f} {originalTop5:<24.2f}")
+    print(f"{'FQ ViT-B/32':<25} {FQTop1:<24.2f} {FQTop5:<24.2f}")
+    print(f"{'TQ ViT-B/32':<25} {TQTop1:<24.2f} {TQTop5:<24.2f}")
+    print(
+        f"{'FQ Drop':<25} {originalTop1 - FQTop1:<24.2f} {originalTop5 - FQTop5:<24.2f}"
+    )
+    print(
+        f"{'TQ Drop':<25} {originalTop1 - TQTop1:<24.2f} {originalTop5 - TQTop5:<24.2f}"
+    )
+
+    if abs(FQTop1 - TQTop1) > 5.0 or abs(FQTop5 - TQTop5) > 5.0:
+        print(
+            f"Warning: Large accuracy drop between FQ and TQ models. "
+            f"Top-1 difference: {abs(FQTop1 - TQTop1):.2f}%, "
+            f"Top-5 difference: {abs(FQTop5 - TQTop5):.2f}%"
+        )
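Both tests are gated behind the ModelTests marker. Assuming the repository's pytest configuration registers that marker and collects the deepQuant*-prefixed functions, a selective run could be driven from Python like this (a hypothetical invocation, not part of this commit):

import pytest

# -m selects tests marked @pytest.mark.ModelTests; -s keeps the prints visible.
pytest.main(["-m", "ModelTests", "-s", "Tests/TestVitB32Pretrained.py"])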
