#
# Federico Brancasi <[email protected]>
#
- import tarfile
- import urllib.request
- from pathlib import Path

import brevitas.nn as qnn
import pytest
import torch
import torch.nn as nn
- import torchvision
- import torchvision.transforms as transforms
- from brevitas.graph.calibrate import calibration_mode
+ import torchvision.models as models
from brevitas.graph.per_input import AdaptiveAvgPoolToAvgPool
from brevitas.graph.quantize import preprocess_for_quantize, quantize
from brevitas.quant import (
    Int32Bias,
    Uint8ActPerTensorFloat,
)
- from torch.utils.data import DataLoader, Subset
- from torchvision.datasets import ImageFolder
- from tqdm import tqdm

from DeepQuant import brevitasToTrueQuant


- def evaluateModel(model, dataLoader, evalDevice, name="Model"):
-     model.eval()
-     correctTop1 = 0
-     correctTop5 = 0
-     total = 0
-
-     with torch.no_grad():
-         for inputs, targets in tqdm(dataLoader, desc=f"Evaluating {name}"):
-             isTQ = "TQ" in name
-
-             if isTQ:
-                 # FBRANCASI: Process the batch one sample at a time for the TQ model
-                 for i in range(inputs.size(0)):
-                     singleInput = inputs[i : i + 1].to(evalDevice)
-                     singleOutput = model(singleInput)
-
-                     _, predicted = singleOutput.max(1)
-                     if predicted.item() == targets[i].item():
-                         correctTop1 += 1
-
-                     _, top5Pred = singleOutput.topk(5, dim=1, largest=True, sorted=True)
-                     if targets[i].item() in top5Pred[0].cpu().numpy():
-                         correctTop5 += 1
-
-                     total += 1
-             else:
-                 inputs = inputs.to(evalDevice)
-                 targets = targets.to(evalDevice)
-                 output = model(inputs)
-
-                 _, predicted = output.max(1)
-                 correctTop1 += (predicted == targets).sum().item()
-
-                 _, top5Pred = output.topk(5, dim=1, largest=True, sorted=True)
-                 for i in range(targets.size(0)):
-                     if targets[i] in top5Pred[i]:
-                         correctTop5 += 1
-
-                 total += targets.size(0)
-
-     top1Accuracy = 100.0 * correctTop1 / total
-     top5Accuracy = 100.0 * correctTop5 / total
-
-     print(
-         f"{name} - Top-1 Accuracy: {top1Accuracy:.2f}% ({correctTop1}/{total}), "
-         f"Top-5 Accuracy: {top5Accuracy:.2f}%"
-     )
-
-     return top1Accuracy, top5Accuracy
-
-
- def calibrateModel(model, calibLoader):
-     model.eval()
-     with torch.no_grad(), calibration_mode(model):
-         for inputs, _ in tqdm(calibLoader, desc="Calibrating model"):
-             inputs = inputs.to("cpu")
-             model(inputs)
-     print("Calibration completed.")
-
-
- def prepareFQResNet18():
+ def prepareResnet18Model() -> nn.Module:
    """Prepare a fake-quantized (FQ) ResNet18 model."""
-     baseModel = torchvision.models.resnet18(
-         weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
-     )
-     baseModel = baseModel.eval().to("cpu")
+     baseModel = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
+
+     baseModel = baseModel.eval()

    computeLayerMap = {
        nn.Conv2d: (
@@ -126,16 +57,7 @@ def prepareFQResNet18():
        ),
    }

-     quantActMap = {
-         nn.ReLU: (
-             qnn.QuantReLU,
-             {
-                 "act_quant": Uint8ActPerTensorFloat,
-                 "return_quant_tensor": True,
-                 "bit_width": 8,
-             },
-         ),
-     }
+     quantActMap = {}

    quantIdentityMap = {
        "signed": (
@@ -156,133 +78,25 @@ def prepareFQResNet18():
        ),
    }

-     dummyInput = torch.ones(1, 3, 224, 224).to("cpu")
-
-     print("Preprocessing model for quantization...")
    baseModel = preprocess_for_quantize(
        baseModel, equalize_iters=20, equalize_scale_computation="range"
    )
+     baseModel = AdaptiveAvgPoolToAvgPool().apply(baseModel, torch.ones(1, 3, 224, 224))

-     print("Converting AdaptiveAvgPool to AvgPool...")
-     baseModel = AdaptiveAvgPoolToAvgPool().apply(baseModel, dummyInput)
-
-     print("Quantizing model...")
-     FQModel = quantize(
+     quantizedResnet = quantize(
        graph_model=baseModel,
        compute_layer_map=computeLayerMap,
        quant_act_map=quantActMap,
        quant_identity_map=quantIdentityMap,
    )

-     return FQModel
+     return quantizedResnet


@pytest.mark.ModelTests
def deepQuantTestResnet18() -> None:
-     HOME = Path.home()
-     BASE = HOME / "Documents" / "ImagenetV2"
-     TAR_URL = (
-         "https://huggingface.co/datasets/vaishaal/ImageNetV2/resolve/main/"
-         "imagenetv2-matched-frequency.tar.gz"
-     )
-     TAR_PATH = BASE / "imagenetv2-matched-frequency.tar.gz"
-     EXTRACT_DIR = BASE / "imagenetv2-matched-frequency-format-val"
-
-     if not TAR_PATH.exists():
-         BASE.mkdir(parents=True, exist_ok=True)
-         print(f"Downloading ImageNetV2 from {TAR_URL}...")
-         urllib.request.urlretrieve(TAR_URL, TAR_PATH)
-
-     if not EXTRACT_DIR.exists():
-         print(f"Extracting to {EXTRACT_DIR}...")
-         with tarfile.open(TAR_PATH, "r:*") as tar:
-             for member in tqdm(tar.getmembers(), desc="Extracting files"):
-                 tar.extract(member, BASE)
-         print("Extraction completed.")
-
-     transformsVal = transforms.Compose(
-         [
-             transforms.Resize(256),
-             transforms.CenterCrop(224),
-             transforms.ToTensor(),
-             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-         ]
-     )
-
-     dataset = ImageFolder(root=str(EXTRACT_DIR), transform=transformsVal)
-     dataset.classes = sorted(dataset.classes, key=lambda x: int(x))
-     dataset.class_to_idx = {cls: i for i, cls in enumerate(dataset.classes)}
-
-     newSamples = []
-     for path, _ in dataset.samples:
-         clsName = Path(path).parent.name
-         newLabel = dataset.class_to_idx[clsName]
-         newSamples.append((path, newLabel))
-     dataset.samples = newSamples
-     dataset.targets = [s[1] for s in newSamples]
-
-     # FBRANCASI: Optional: reduce the number of examples for faster validation
-     DATASET_LIMIT = 256
-     dataset = Subset(dataset, list(range(DATASET_LIMIT)))
-     print(f"Validation dataset size set to {len(dataset)} images.")
-
-     calibLoader = DataLoader(
-         Subset(dataset, list(range(256))), batch_size=32, shuffle=False, pin_memory=True
-     )
-     valLoader = DataLoader(dataset, batch_size=32, shuffle=False, pin_memory=True)
-
-     # FBRANCASI: I'm on a Mac, so MPS for me
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     device = torch.device("mps" if torch.backends.mps.is_available() else device)
-     print(f"Using device: {device}")
-
-     originalModel = torchvision.models.resnet18(
-         weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
-     )
-     originalModel = originalModel.eval().to(device)
-     print("Original ResNet18 loaded.")
-
-     print("Evaluating original model...")
-     originalTop1, originalTop5 = evaluateModel(
-         originalModel, valLoader, device, "Original ResNet18"
-     )
-
-     print("Preparing and quantizing ResNet18...")
-     FQModel = prepareFQResNet18()
-
-     print("Calibrating FQ model...")
-     calibrateModel(FQModel, calibLoader)
-
-     print("Evaluating FQ model...")
-     # FBRANCASI: I'm on a Mac; MPS doesn't work with Brevitas
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     FQTop1, FQTop5 = evaluateModel(FQModel, valLoader, device, "FQ ResNet18")
-
-     sampleInputImg = torch.randn(1, 3, 224, 224).to("cpu")
-     TQModel = brevitasToTrueQuant(FQModel, sampleInputImg, debug=True)
-
-     numParameters = sum(p.numel() for p in TQModel.parameters())
-     print(f"Number of parameters: {numParameters:,}")
-
-     print("Evaluating TQ model...")
-     TQTop1, TQTop5 = evaluateModel(TQModel, valLoader, device, "TQ ResNet18")
-
-     print("\nComparison Summary:")
-     print(f"{'Model':<25} {'Top-1 Accuracy':<25} {'Top-5 Accuracy':<25}")
-     print("-" * 75)
-     print(f"{'Original ResNet18':<25} {originalTop1:<24.2f} {originalTop5:<24.2f}")
-     print(f"{'FQ ResNet18':<25} {FQTop1:<24.2f} {FQTop5:<24.2f}")
-     print(f"{'TQ ResNet18':<25} {TQTop1:<24.2f} {TQTop5:<24.2f}")
-     print(
-         f"{'FQ Drop':<25} {originalTop1 - FQTop1:<24.2f} {originalTop5 - FQTop5:<24.2f}"
-     )
-     print(
-         f"{'TQ Drop':<25} {originalTop1 - TQTop1:<24.2f} {originalTop5 - TQTop5:<24.2f}"
-     )

-     if abs(FQTop1 - TQTop1) > 5.0 or abs(FQTop5 - TQTop5) > 5.0:
-         print(
-             f"Warning: Large accuracy drop between FQ and TQ models. "
-             f"Top-1 difference: {abs(FQTop1 - TQTop1):.2f}%, "
-             f"Top-5 difference: {abs(FQTop5 - TQTop5):.2f}%"
-         )
+     torch.manual_seed(42)
+     quantizedModel = prepareResnet18Model()
+     sampleInput = torch.randn(1, 3, 224, 224)
+     brevitasToTrueQuant(quantizedModel, sampleInput, debug=True)