Skip to content

Commit 531fa57

Browse files
Modify CCT Test
1 parent 9e489b1 commit 531fa57

File tree

3 files changed

+296
-0
lines changed

3 files changed

+296
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,4 @@ dist/
3030
*.npz
3131
onnx/*
3232
Dataset/*
33+
Data/*

Tests/TestCCT.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def prepareCCT(model) -> nn.Module:
8484

8585
quant_name = f"{node.name}_reshape_fix"
8686
model.add_module(quant_name, quant_identity)
87+
# mark this QuantIdentity as “reshape fix”
88+
quant_identity._is_reshape_fix = True
8789

8890
with model.graph.inserting_after(node):
8991
quant_node = model.graph.call_module(quant_name, args=(node,))
@@ -178,3 +180,7 @@ def deepQuantTestCCT():
178180
output = quantizedModel(sampleInput)
179181
print(f"Output shape: {output.shape}")
180182
print(f"Output range: [{output.min().item():.3f}, {output.max().item():.3f}]")
183+
184+
from DeepQuant import brevitasToTrueQuant
185+
186+
brevitasToTrueQuant(quantizedModel, sampleInput, debug=True)

Tests/TestCCTPretrained.py

Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
# Copyright 2025 ETH Zurich and University of Bologna.
2+
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Federico Brancasi <[email protected]>
6+
7+
import brevitas.nn as qnn
8+
import pytest
9+
import torch
10+
import torch.nn as nn
11+
import torchvision
12+
import torchvision.transforms as transforms
13+
from brevitas.fx.brevitas_tracer import symbolic_trace
14+
from brevitas.graph.calibrate import calibration_mode
15+
from brevitas.graph.quantize import preprocess_for_quantize, quantize
16+
from brevitas.graph.utils import replace_all_uses_except
17+
from brevitas.quant import (
18+
Int8ActPerTensorFloat,
19+
Int8WeightPerTensorFloat,
20+
Int32Bias,
21+
Uint8ActPerTensorFloat,
22+
)
23+
from torch.utils.data import DataLoader, Subset
24+
from tqdm import tqdm
25+
26+
from DeepQuant import brevitasToTrueQuant
27+
from Tests.Models.CCT import cct_2_3x2_32
28+
29+
30+
def evaluateModel(model, dataLoader, evalDevice, name="Model"):
31+
model.eval()
32+
correct = 0
33+
total = 0
34+
35+
with torch.no_grad():
36+
for inputs, targets in tqdm(dataLoader, desc=f"Evaluating {name}"):
37+
isTQ = "TQ" in name
38+
39+
if isTQ:
40+
# FBRANCASI: Process different batches for the TQ model
41+
for i in range(inputs.size(0)):
42+
singleInput = inputs[i : i + 1].to(evalDevice)
43+
singleOutput = model(singleInput)
44+
45+
_, predicted = singleOutput.max(1)
46+
if predicted.item() == targets[i].item():
47+
correct += 1
48+
49+
total += 1
50+
else:
51+
inputs = inputs.to(evalDevice)
52+
targets = targets.to(evalDevice)
53+
output = model(inputs)
54+
55+
_, predicted = output.max(1)
56+
correct += (predicted == targets).sum().item()
57+
total += targets.size(0)
58+
59+
accuracy = 100.0 * correct / total
60+
print(f"{name} - Accuracy: {accuracy:.2f}% ({correct}/{total})")
61+
return accuracy
62+
63+
64+
def calibrateModel(model, calibLoader):
65+
model.eval()
66+
with torch.no_grad(), calibration_mode(model):
67+
for inputs, _ in tqdm(calibLoader, desc="Calibrating model"):
68+
inputs = inputs.to("cpu")
69+
model(inputs)
70+
print("Calibration completed.")
71+
72+
73+
def prepareFQCCT(model) -> nn.Module:
74+
"""
75+
Prepare a quantized CCT model for testing with export support.
76+
"""
77+
78+
if not hasattr(model, "graph"):
79+
model = symbolic_trace(model)
80+
81+
print("=== FIXING QUANTIZATION ISSUES ===")
82+
83+
transpose_fixes = []
84+
qkv_fixes = []
85+
86+
# FBRANCASI: Fix 1, Find transpose -> add patterns
87+
for node in model.graph.nodes:
88+
if node.op == "call_method" and node.target == "transpose":
89+
for user in node.users:
90+
if (
91+
"add" in user.name
92+
or user.target in [torch.add]
93+
or (user.op == "call_method" and user.target in ["add", "add_"])
94+
):
95+
transpose_fixes.append((node, user))
96+
break
97+
98+
# FBRANCASI: Fix 2, Find QKV -> reshape patterns
99+
for node in model.graph.nodes:
100+
if node.op == "call_module" and "qkv" in node.target:
101+
for user in node.users:
102+
if user.op == "call_method" and user.target == "reshape":
103+
qkv_fixes.append((node, user))
104+
break
105+
106+
# FBRANCASI: Apply transpose fixes
107+
print(f"\nApplying {len(transpose_fixes)} transpose fixes...")
108+
for node, user in transpose_fixes:
109+
print(f" Fixing: {node.name} -> {user.name}")
110+
111+
quant_identity = qnn.QuantIdentity(
112+
act_quant=Int8ActPerTensorFloat, return_quant_tensor=True
113+
)
114+
115+
quant_name = f"{node.name}_quant_fix"
116+
model.add_module(quant_name, quant_identity)
117+
118+
with model.graph.inserting_after(node):
119+
quant_node = model.graph.call_module(quant_name, args=(node,))
120+
121+
# Replace uses
122+
replace_all_uses_except(node, quant_node, [quant_node])
123+
124+
# FBRANCASI: Apply QKV fixes
125+
print(f"\nApplying {len(qkv_fixes)} QKV fixes...")
126+
for node, reshape_user in qkv_fixes:
127+
print(f" Fixing: {node.name} -> {reshape_user.name}")
128+
129+
quant_identity = qnn.QuantIdentity(
130+
act_quant=Int8ActPerTensorFloat,
131+
return_quant_tensor=False, # FBRANCASI: return regular tensor for reshape
132+
)
133+
134+
quant_name = f"{node.name}_reshape_fix"
135+
model.add_module(quant_name, quant_identity)
136+
# mark this QuantIdentity as “reshape fix”
137+
quant_identity._is_reshape_fix = True
138+
139+
with model.graph.inserting_after(node):
140+
quant_node = model.graph.call_module(quant_name, args=(node,))
141+
142+
reshape_user.update_arg(0, quant_node)
143+
144+
model.recompile()
145+
model.graph.lint()
146+
147+
print("\n=== GRAPH MODIFICATION COMPLETE ===")
148+
149+
computeLayerMap = {
150+
nn.Conv2d: (
151+
qnn.QuantConv2d,
152+
{
153+
"input_quant": Int8ActPerTensorFloat,
154+
"weight_quant": Int8WeightPerTensorFloat,
155+
"output_quant": Int8ActPerTensorFloat,
156+
"bias_quant": Int32Bias,
157+
"bias": False,
158+
"return_quant_tensor": True,
159+
"output_bit_width": 8,
160+
"weight_bit_width": 4,
161+
},
162+
),
163+
nn.Linear: (
164+
qnn.QuantLinear,
165+
{
166+
"input_quant": Int8ActPerTensorFloat,
167+
"weight_quant": Int8WeightPerTensorFloat,
168+
"output_quant": Int8ActPerTensorFloat,
169+
"bias_quant": Int32Bias,
170+
"return_quant_tensor": True,
171+
"output_bit_width": 8,
172+
"weight_bit_width": 4,
173+
},
174+
),
175+
}
176+
177+
quantActMap = {}
178+
179+
quantIdentityMap = {
180+
"signed": (
181+
qnn.QuantIdentity,
182+
{
183+
"act_quant": Int8ActPerTensorFloat,
184+
"return_quant_tensor": True,
185+
"bit_width": 8,
186+
},
187+
),
188+
"unsigned": (
189+
qnn.QuantIdentity,
190+
{
191+
"act_quant": Uint8ActPerTensorFloat,
192+
"return_quant_tensor": True,
193+
"bit_width": 8,
194+
},
195+
),
196+
}
197+
198+
model = preprocess_for_quantize(
199+
model,
200+
equalize_iters=10,
201+
equalize_scale_computation="range",
202+
trace_model=False, # FBRANCASI: Already traced
203+
)
204+
205+
quantizedModel = quantize(
206+
graph_model=model,
207+
compute_layer_map=computeLayerMap,
208+
quant_act_map=quantActMap,
209+
quant_identity_map=quantIdentityMap,
210+
)
211+
212+
return quantizedModel
213+
214+
215+
@pytest.mark.ModelTests
216+
def deepQuantTestCCT():
217+
torch.manual_seed(42)
218+
219+
# FBRANCASI: Setup CIFAR-10 dataset
220+
transformsVal = transforms.Compose(
221+
[
222+
transforms.ToTensor(),
223+
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
224+
]
225+
)
226+
227+
dataset = torchvision.datasets.CIFAR10(
228+
root="./data", train=False, download=True, transform=transformsVal
229+
)
230+
231+
DATASET_LIMIT = 256
232+
dataset = Subset(dataset, list(range(DATASET_LIMIT)))
233+
print(f"Validation dataset size set to {len(dataset)} images.")
234+
235+
calibLoader = DataLoader(
236+
Subset(dataset, list(range(128))), batch_size=32, shuffle=False, pin_memory=True
237+
)
238+
valLoader = DataLoader(dataset, batch_size=32, shuffle=False, pin_memory=True)
239+
240+
# FBRANCASI: Device setup
241+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
242+
device = torch.device("mps" if torch.backends.mps.is_available() else device)
243+
print(f"Using device: {device}")
244+
245+
# FBRANCASI: Load original floating point model
246+
originalModel = cct_2_3x2_32()
247+
checkpointPath = "/Users/federicobrancasi/Documents/DeepQuant/Tests/Data/checkpoint_epoch_200_cct2_cifar10.pth"
248+
checkpoint = torch.load(checkpointPath, map_location="cpu")
249+
originalModel.load_state_dict(checkpoint["model_state_dict"])
250+
originalModel = originalModel.eval().to(device)
251+
print("Original CCT-2 loaded from checkpoint.")
252+
253+
print("Evaluating original model...")
254+
originalAccuracy = evaluateModel(originalModel, valLoader, device, "Original CCT-2")
255+
256+
print("Preparing and quantizing CCT-2...")
257+
FQModel = prepareFQCCT(originalModel.to("cpu"))
258+
259+
print("Calibrating FQ model...")
260+
calibrateModel(FQModel, calibLoader)
261+
262+
print("Evaluating FQ model...")
263+
# FBRANCASI: Use CPU for brevitas models
264+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
265+
FQAccuracy = evaluateModel(FQModel, valLoader, device, "FQ CCT-2")
266+
267+
sampleInput = torch.randn(1, 3, 32, 32).to("cpu")
268+
TQModel = brevitasToTrueQuant(FQModel, sampleInput, debug=True)
269+
270+
numParameters = sum(p.numel() for p in TQModel.parameters())
271+
print(f"Number of parameters: {numParameters:,}")
272+
273+
print("Evaluating TQ model...")
274+
TQAccuracy = evaluateModel(TQModel, valLoader, device, "TQ CCT-2")
275+
276+
print("\nComparison Summary:")
277+
print(f"{'Model':<25} {'Accuracy':<25}")
278+
print("-" * 50)
279+
print(f"{'Original CCT-2':<25} {originalAccuracy:<24.2f}")
280+
print(f"{'FQ CCT-2':<25} {FQAccuracy:<24.2f}")
281+
print(f"{'TQ CCT-2':<25} {TQAccuracy:<24.2f}")
282+
print(f"{'FQ Drop':<25} {originalAccuracy - FQAccuracy:<24.2f}")
283+
print(f"{'TQ Drop':<25} {originalAccuracy - TQAccuracy:<24.2f}")
284+
285+
if abs(FQAccuracy - TQAccuracy) > 5.0:
286+
print(
287+
f"Warning: Large accuracy drop between FQ and TQ models. "
288+
f"Difference: {abs(FQAccuracy - TQAccuracy):.2f}%"
289+
)

0 commit comments

Comments
 (0)