Skip to content

Commit 2657957

Browse files
Update TestCCT
1 parent d580088 commit 2657957

File tree

2 files changed

+37
-19
lines changed

2 files changed

+37
-19
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ dist/
2525
*.gz
2626
*-ubyte
2727
*.pt
28+
.c*
2829
*.onnx
2930
*.npz
3031
onnx/*

Tests/TestCCTPretrained.py

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ def injectCustomForwards(
8282

8383
def evaluateModel(model, dataLoader, evalDevice, name="Model"):
8484
model.eval()
85-
correct = 0
85+
correctTop1 = 0
86+
correctTop5 = 0
8687
total = 0
8788

8889
with torch.no_grad():
@@ -97,7 +98,11 @@ def evaluateModel(model, dataLoader, evalDevice, name="Model"):
9798

9899
_, predicted = singleOutput.max(1)
99100
if predicted.item() == targets[i].item():
100-
correct += 1
101+
correctTop1 += 1
102+
103+
_, top5Pred = singleOutput.topk(5, dim=1, largest=True, sorted=True)
104+
if targets[i].item() in top5Pred[0].cpu().numpy():
105+
correctTop5 += 1
101106

102107
total += 1
103108
else:
@@ -106,12 +111,24 @@ def evaluateModel(model, dataLoader, evalDevice, name="Model"):
106111
output = model(inputs)
107112

108113
_, predicted = output.max(1)
109-
correct += (predicted == targets).sum().item()
114+
correctTop1 += (predicted == targets).sum().item()
115+
116+
_, top5Pred = output.topk(5, dim=1, largest=True, sorted=True)
117+
for i in range(targets.size(0)):
118+
if targets[i] in top5Pred[i]:
119+
correctTop5 += 1
120+
110121
total += targets.size(0)
111122

112-
accuracy = 100.0 * correct / total
113-
print(f"{name} - Accuracy: {accuracy:.2f}% ({correct}/{total})")
114-
return accuracy
123+
top1Accuracy = 100.0 * correctTop1 / total
124+
top5Accuracy = 100.0 * correctTop5 / total
125+
126+
print(
127+
f"{name} - Top-1 Accuracy: {top1Accuracy:.2f}% ({correctTop1}/{total}), "
128+
f"Top-5 Accuracy: {top5Accuracy:.2f}%"
129+
)
130+
131+
return top1Accuracy, top5Accuracy
115132

116133

117134
def calibrateModel(model, calibLoader):
@@ -302,7 +319,7 @@ def deepQuantTestCCT():
302319
print("Original CCT-2 loaded from checkpoint.")
303320

304321
print("Evaluating original model...")
305-
originalAccuracy = evaluateModel(originalModel, valLoader, device, "Original CCT-2")
322+
originalTop1, originalTop5 = evaluateModel(originalModel, valLoader, device, "Original CCT-2")
306323

307324
print("Preparing and quantizing CCT-2...")
308325
FQModel = prepareFQCCT(originalModel.to("cpu"))
@@ -313,7 +330,7 @@ def deepQuantTestCCT():
313330
print("Evaluating FQ model...")
314331
# FBRANCASI: Use CPU for brevitas models
315332
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
316-
FQAccuracy = evaluateModel(FQModel, valLoader, device, "FQ CCT-2")
333+
FQTop1, FQTop5 = evaluateModel(FQModel, valLoader, device, "FQ CCT-2")
317334

318335
sampleInput = torch.randn(1, 3, 32, 32).to("cpu")
319336

@@ -346,19 +363,19 @@ def deepQuantTestCCT():
346363
print(f"Number of parameters: {numParameters:,}")
347364

348365
print("Evaluating TQ model...")
349-
TQAccuracy = evaluateModel(TQModel, valLoader, device, "TQ CCT-2")
366+
TQTop1, TQTop5 = evaluateModel(TQModel, valLoader, device, "TQ CCT-2")
350367

351368
print("\nComparison Summary:")
352-
print(f"{'Model':<25} {'Accuracy':<25}")
353-
print("-" * 50)
354-
print(f"{'Original CCT-2':<25} {originalAccuracy:<24.2f}")
355-
print(f"{'FQ CCT-2':<25} {FQAccuracy:<24.2f}")
356-
print(f"{'TQ CCT-2':<25} {TQAccuracy:<24.2f}")
357-
print(f"{'FQ Drop':<25} {originalAccuracy - FQAccuracy:<24.2f}")
358-
print(f"{'TQ Drop':<25} {originalAccuracy - TQAccuracy:<24.2f}")
359-
360-
if abs(FQAccuracy - TQAccuracy) > 5.0:
369+
print(f"{'Model':<25} {'Top-1 Accuracy':<25} {'Top-5 Accuracy':<25}")
370+
print("-" * 75)
371+
print(f"{'Original CCT-2':<25} {originalTop1:<24.2f} {originalTop5:<24.2f}")
372+
print(f"{'FQ CCT-2':<25} {FQTop1:<24.2f} {FQTop5:<24.2f}")
373+
print(f"{'TQ CCT-2':<25} {TQTop1:<24.2f} {TQTop5:<24.2f}")
374+
print(f"{'FQ Drop':<25} {originalTop1 - FQTop1:<24.2f} {originalTop5 - FQTop5:<24.2f}")
375+
print(f"{'TQ Drop':<25} {originalTop1 - TQTop1:<24.2f} {originalTop5 - TQTop5:<24.2f}")
376+
377+
if abs(FQTop1 - TQTop1) > 5.0:
361378
print(
362379
f"Warning: Large accuracy drop between FQ and TQ models. "
363-
f"Difference: {abs(FQAccuracy - TQAccuracy):.2f}%"
380+
f"Difference: {abs(FQTop1 - TQTop1):.2f}%"
364381
)

0 commit comments

Comments
 (0)