@@ -82,7 +82,8 @@ def injectCustomForwards(
 
 def evaluateModel(model, dataLoader, evalDevice, name="Model"):
     model.eval()
-    correct = 0
+    correctTop1 = 0
+    correctTop5 = 0
     total = 0
 
     with torch.no_grad():
@@ -97,7 +98,11 @@ def evaluateModel(model, dataLoader, evalDevice, name="Model"):
 
                     _, predicted = singleOutput.max(1)
                     if predicted.item() == targets[i].item():
-                        correct += 1
+                        correctTop1 += 1
+
+                    _, top5Pred = singleOutput.topk(5, dim=1, largest=True, sorted=True)
+                    if targets[i].item() in top5Pred[0].cpu().numpy():
+                        correctTop5 += 1
 
                     total += 1
             else:
@@ -106,12 +111,24 @@ def evaluateModel(model, dataLoader, evalDevice, name="Model"):
                 output = model(inputs)
 
                 _, predicted = output.max(1)
-                correct += (predicted == targets).sum().item()
+                correctTop1 += (predicted == targets).sum().item()
+
+                _, top5Pred = output.topk(5, dim=1, largest=True, sorted=True)
+                for i in range(targets.size(0)):
+                    if targets[i] in top5Pred[i]:
+                        correctTop5 += 1
+
                 total += targets.size(0)
 
-    accuracy = 100.0 * correct / total
-    print(f"{name} - Accuracy: {accuracy:.2f}% ({correct}/{total})")
-    return accuracy
+    top1Accuracy = 100.0 * correctTop1 / total
+    top5Accuracy = 100.0 * correctTop5 / total
+
+    print(
+        f"{name} - Top-1 Accuracy: {top1Accuracy:.2f}% ({correctTop1}/{total}), "
+        f"Top-5 Accuracy: {top5Accuracy:.2f}%"
+    )
+
+    return top1Accuracy, top5Accuracy
 
 
 def calibrateModel(model, calibLoader):
@@ -302,7 +319,7 @@ def deepQuantTestCCT():
     print("Original CCT-2 loaded from checkpoint.")
 
     print("Evaluating original model...")
-    originalAccuracy = evaluateModel(originalModel, valLoader, device, "Original CCT-2")
+    originalTop1, originalTop5 = evaluateModel(originalModel, valLoader, device, "Original CCT-2")
 
     print("Preparing and quantizing CCT-2...")
     FQModel = prepareFQCCT(originalModel.to("cpu"))
@@ -313,7 +330,7 @@ def deepQuantTestCCT():
     print("Evaluating FQ model...")
     # FBRANCASI: Use CPU for brevitas models
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    FQAccuracy = evaluateModel(FQModel, valLoader, device, "FQ CCT-2")
+    FQTop1, FQTop5 = evaluateModel(FQModel, valLoader, device, "FQ CCT-2")
 
     sampleInput = torch.randn(1, 3, 32, 32).to("cpu")
 
@@ -346,19 +363,19 @@ def deepQuantTestCCT():
     print(f"Number of parameters: {numParameters:,}")
 
     print("Evaluating TQ model...")
-    TQAccuracy = evaluateModel(TQModel, valLoader, device, "TQ CCT-2")
+    TQTop1, TQTop5 = evaluateModel(TQModel, valLoader, device, "TQ CCT-2")
 
     print("\nComparison Summary:")
-    print(f"{'Model':<25} {'Accuracy':<25}")
-    print("-" * 50)
-    print(f"{'Original CCT-2':<25} {originalAccuracy:<24.2f}")
-    print(f"{'FQ CCT-2':<25} {FQAccuracy:<24.2f}")
-    print(f"{'TQ CCT-2':<25} {TQAccuracy:<24.2f}")
-    print(f"{'FQ Drop':<25} {originalAccuracy - FQAccuracy:<24.2f}")
-    print(f"{'TQ Drop':<25} {originalAccuracy - TQAccuracy:<24.2f}")
-
-    if abs(FQAccuracy - TQAccuracy) > 5.0:
+    print(f"{'Model':<25} {'Top-1 Accuracy':<25} {'Top-5 Accuracy':<25}")
+    print("-" * 75)
+    print(f"{'Original CCT-2':<25} {originalTop1:<24.2f} {originalTop5:<24.2f}")
+    print(f"{'FQ CCT-2':<25} {FQTop1:<24.2f} {FQTop5:<24.2f}")
+    print(f"{'TQ CCT-2':<25} {TQTop1:<24.2f} {TQTop5:<24.2f}")
+    print(f"{'FQ Drop':<25} {originalTop1 - FQTop1:<24.2f} {originalTop5 - FQTop5:<24.2f}")
+    print(f"{'TQ Drop':<25} {originalTop1 - TQTop1:<24.2f} {originalTop5 - TQTop5:<24.2f}")
+
+    if abs(FQTop1 - TQTop1) > 5.0:
         print(
             f"Warning: Large accuracy drop between FQ and TQ models. "
-            f"Difference: {abs(FQAccuracy - TQAccuracy):.2f}%"
+            f"Difference: {abs(FQTop1 - TQTop1):.2f}%"
         )
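
Side note on the new top-5 counting in the batched branch: the per-sample Python loop over `targets.size(0)` is correct, but the same count can be obtained with a single broadcast comparison against the top-k indices. A minimal sketch, not part of this patch, assuming batched logits of shape (batchSize, numClasses) and integer class targets; the helper name `topKCorrect` is hypothetical:

```python
import torch


def topKCorrect(output: torch.Tensor, targets: torch.Tensor, k: int = 5) -> int:
    # indices of the k largest logits per sample -> shape (batchSize, k)
    _, topkPred = output.topk(k, dim=1, largest=True, sorted=True)
    # broadcast-compare each target against its k predictions and count rows with a hit
    return topkPred.eq(targets.view(-1, 1)).any(dim=1).sum().item()


# usage inside the batched branch (sketch): correctTop5 += topKCorrect(output, targets, k=5)
```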