1111
1212import os
1313import warnings
14+ from argparse import ArgumentParser
1415from pathlib import Path
1516
1617import openvino as ov
1718import torch
18- import torch .optim
19- import torch .utils .data
20- import torch .utils .data .distributed
21- import torchvision .datasets as datasets
22- import torchvision .models as models
23- import torchvision .transforms as transforms
2419from fastdownload import FastDownload
2520from rich .progress import track
2621from torch import nn
2722from torch .jit import TracerWarning
2823from torch .utils .data import DataLoader
24+ from torchvision import datasets
25+ from torchvision import transforms
26+ from torchvision .models import resnet18
2927
3028import nncf
31- import nncf .parameters
32- import nncf .torch
33- import nncf .torch .function_hook
34- import nncf .torch .function_hook .prune
35- import nncf .torch .function_hook .prune .prune_model
3629from nncf .parameters import PruneMode
37- from nncf .torch .function_hook .prune .magnitude .schedulers import MultiStepMagnitudePruningScheduler
30+ from nncf .torch .function_hook .pruning .magnitude .schedulers import MultiStepMagnitudePruningScheduler
31+ from nncf .torch .function_hook .pruning .rb .losses import RBLoss
32+ from nncf .torch .function_hook .pruning .rb .schedulers import MultiStepRBPruningScheduler
3833
# Suppress TracerWarning and UserWarning messages for the whole script run.
warnings.filterwarnings("ignore", category=TracerWarning)
warnings.filterwarnings("ignore", category=UserWarning)

BASE_MODEL_NAME = "resnet18"
# Side length (pixels) used for the resize transform and the example input tensor.
IMAGE_SIZE = 64
BATCH_SIZE = 128


# Directory containing this script.
ROOT = Path(__file__).parent.resolve()
CHECKPOINT_URL = (
    "https://storage.openvinotoolkit.org/repositories/nncf/openvino_notebook_ckpts/302_resnet18_fp32_v1.pth"
)
DATASET_URL = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
# Per-user cache directory handed to the dataset downloader as its base path.
DATASET_PATH = Path.home() / ".cache" / "nncf" / "datasets"
5548
5649
def get_argument_parser() -> ArgumentParser:
    """Build the command-line parser for this example.

    Returns:
        A parser with a single optional ``--mode`` argument that selects the
        pruning algorithm: ``"magnitude"`` (the default) or ``"rb"``.
        Any other value is rejected by argparse.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--mode",
        choices=["magnitude", "rb"],
        default="magnitude",
        help="Pruning mode to use. Choices are: magnitude, rb. Default is magnitude.",
    )
    return parser
60+
61+
def download_dataset() -> Path:
    """Fetch the dataset at ``DATASET_URL`` through FastDownload.

    The downloader is rooted at the resolved ``DATASET_PATH`` and configured
    with ``"downloaded"`` as its archive location and ``"extracted"`` as its
    data location.

    Returns:
        Whatever path ``FastDownload.get`` reports for the dataset
        (presumably the extracted dataset directory -- confirm against the
        fastdownload documentation).
    """
    downloader = FastDownload(base=DATASET_PATH.resolve(), archive="downloaded", data="extracted")
    return downloader.get(DATASET_URL)
@@ -66,10 +71,10 @@ def load_checkpoint(model: nn.Module) -> tuple[nn.Module, float]:
6671
6772
def get_resnet18_model(device: torch.device, num_classes: int = 200) -> nn.Module:
    """Create a ResNet-18 without pretrained weights and a resized classifier head.

    Args:
        device: Device the model is moved to before it is returned.
        num_classes: Number of outputs of the final fully connected layer.
            Defaults to 200 (Tiny ImageNet); the stock head has 1000 outputs
            for ImageNet.

    Returns:
        The model, placed on ``device``.
    """
    model = resnet18(weights=None)
    # Replace the last FC layer so its output size matches the dataset.
    # The input width is read from the existing head rather than hard-coded.
    model.fc = nn.Linear(in_features=model.fc.in_features, out_features=num_classes, bias=True)
    model.to(device)
    return model
7580
@@ -78,6 +83,7 @@ def train_epoch(
7883 train_loader : DataLoader ,
7984 model : nn .Module ,
8085 criterion : nn .Module ,
86+ rb_loss : RBLoss ,
8187 optimizer : torch .optim .Optimizer ,
8288 device : torch .device ,
8389):
@@ -91,50 +97,34 @@ def train_epoch(
9197 # Compute output.
9298 output = model (images )
9399 loss = criterion (output , target )
94-
100+ if rb_loss is not None :
101+ loss += rb_loss ()
95102 # Compute gradient and do opt step.
96103 optimizer .zero_grad ()
97104 loss .backward ()
98105 optimizer .step ()
99106
100107
@torch.no_grad()
def validate(val_loader: DataLoader, model: nn.Module, device: torch.device) -> float:
    """Measure the top-1 accuracy of ``model`` over ``val_loader``.

    The model is switched to evaluation mode and is left in that mode when the
    function returns. Gradient tracking is disabled for the whole call.

    Args:
        val_loader: Loader yielding ``(images, target)`` batches.
        model: Model to evaluate; called as ``model(images)``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Top-1 accuracy in percent, computed as correct predictions over the
        total number of samples seen (not an average of per-batch accuracies).

    Raises:
        ValueError: If ``val_loader`` yields no samples, since accuracy is
            undefined in that case.
    """
    # Switch to evaluate mode.
    model.eval()

    correct = 0
    total = 0

    for images, target in track(val_loader, total=len(val_loader), description="Validation:"):
        images = images.to(device)
        target = target.to(device)

        output = model(images)

        # The predicted class is the index of the largest value along dim 1.
        preds = output.argmax(dim=1)
        correct += preds.eq(target).sum().item()
        total += target.size(0)

    if total == 0:
        msg = "Validation loader produced no samples; accuracy is undefined."
        raise ValueError(msg)

    return 100.0 * correct / total
138128
139129
140130def create_data_loaders () -> tuple [DataLoader , DataLoader ]:
@@ -151,23 +141,12 @@ def create_data_loaders() -> tuple[DataLoader, DataLoader]:
151141 train_dataset = datasets .ImageFolder (
152142 train_dir ,
153143 transforms .Compose (
154- [
155- transforms .Resize (IMAGE_SIZE ),
156- transforms .RandomHorizontalFlip (),
157- transforms .ToTensor (),
158- normalize ,
159- ]
144+ [transforms .Resize (IMAGE_SIZE ), transforms .RandomHorizontalFlip (), transforms .ToTensor (), normalize ]
160145 ),
161146 )
162147 val_dataset = datasets .ImageFolder (
163148 val_dir ,
164- transforms .Compose (
165- [
166- transforms .Resize (IMAGE_SIZE ),
167- transforms .ToTensor (),
168- normalize ,
169- ]
170- ),
149+ transforms .Compose ([transforms .Resize (IMAGE_SIZE ), transforms .ToTensor (), normalize ]),
171150 )
172151
173152 train_loader = DataLoader (
@@ -200,7 +179,10 @@ def prepare_tiny_imagenet_200(dataset_dir: Path) -> None:
200179 val_images_dir .rmdir ()
201180
202181
203- def main ():
182+ def main () -> float :
183+ args = get_argument_parser ().parse_args ()
184+ pruning_mode = args .mode
185+
204186 torch .manual_seed (0 )
205187 device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
206188 print (f"Using { device } device" )
@@ -212,51 +194,64 @@ def main():
212194 model = get_resnet18_model (device )
213195 model , acc1_fp32 = load_checkpoint (model )
214196
215- print (f"Accuracy@1 of original FP32 model: { acc1_fp32 } " )
197+ print (f"Accuracy@1 of original FP32 model: { acc1_fp32 :.2f } " )
216198
217199 train_loader , val_loader = create_data_loaders ()
218200 example_input = torch .rand (1 , 3 , IMAGE_SIZE , IMAGE_SIZE ).to (device )
219201
220202 ###############################################################################
221203 # Step 2: Prune model
222- print (os .linesep + "[Step 2] Prune model" )
223-
224- # Unstructured pruning with 70% sparsity ratio
225- pruned_model = nncf .prune (
226- model ,
227- mode = PruneMode .UNSTRUCTURED_MAGNITUDE_GLOBAL ,
228- ratio = 0.7 ,
229- ignored_scope = nncf .IgnoredScope (),
230- examples_inputs = example_input ,
231- )
232-
233- acc1_init = validate (val_loader , pruned_model , device )
234-
235- print (f"Accuracy@1 of pruned model with 0.7 pruning ratio without fine-tuning: { acc1_init :.3f} " )
236-
237- ###############################################################################
238- # Step 3: Fine tune with multi step sparsity scheduler
239- print (os .linesep + "[Step 3] Fine tune with multi step sparsity scheduler" )
204+ print (os .linesep + "[Step 2]: Prune model and specify training parameters" )
205+
206+ if pruning_mode == "magnitude" :
207+ pruned_model = nncf .prune (
208+ model ,
209+ mode = PruneMode .UNSTRUCTURED_MAGNITUDE_GLOBAL ,
210+ ratio = 0.7 ,
211+ ignored_scope = nncf .IgnoredScope (),
212+ examples_inputs = example_input ,
213+ )
214+ num_epochs = 2
215+ rb_loss = None
216+ scheduler = MultiStepMagnitudePruningScheduler (
217+ model = model , mode = PruneMode .UNSTRUCTURED_MAGNITUDE_GLOBAL , steps = {0 : 0.5 , 1 : 0.7 }
218+ )
219+ optimizer = torch .optim .Adam (pruned_model .parameters (), lr = 1e-5 )
220+ else :
221+ pruned_model = nncf .prune (
222+ model ,
223+ mode = PruneMode .UNSTRUCTURED_REGULARIZATION_BASED ,
224+ ignored_scope = nncf .IgnoredScope (),
225+ examples_inputs = example_input ,
226+ )
227+ num_epochs = 30
228+ rb_loss = RBLoss (pruned_model , target_ratio = 0.7 , p = 0.1 ).to (device )
229+ scheduler = MultiStepRBPruningScheduler (rb_loss , steps = {0 : 0.3 , 5 : 0.5 , 10 : 0.7 })
230+
231+ # Set higher lr for mask parameters to achieve the target pruning ratio faster
232+ mask_params = [p for n , p in pruned_model .named_parameters () if "mask" in n ]
233+ model_params = [p for n , p in pruned_model .named_parameters () if "mask" not in n ]
234+ optimizer = torch .optim .Adam (
235+ [
236+ {"params" : model_params , "lr" : 1e-5 },
237+ {"params" : mask_params , "lr" : 1e-2 , "weight_decay" : 0.0 },
238+ ]
239+ )
240240
241- # Define loss function (criterion) and optimizer.
242241 criterion = nn .CrossEntropyLoss ().to (device )
243- compression_lr = 1e-5
244- optimizer = torch .optim .Adam (pruned_model .parameters (), lr = compression_lr )
245242
246- # Create prune scheduler with multi steps strategy
247- pruning_scheduler = MultiStepMagnitudePruningScheduler (
248- pruned_model , mode = PruneMode .UNSTRUCTURED_MAGNITUDE_GLOBAL , steps = {0 : 0.6 , 1 : 0.7 }
249- )
243+ ###############################################################################
244+ # Step 3: Fine tune
245+ print (os .linesep + "[Step 3] Fine tune with multi step pruning ratio scheduler" )
250246
251- for epoch in range (2 ):
247+ for epoch in range (num_epochs ):
252248 print (os .linesep + f"Train epoch: { epoch } " )
249+ scheduler .step ()
250+ train_epoch (train_loader , pruned_model , criterion , rb_loss , optimizer , device = device )
253251
254- pruning_scheduler .step ()
255-
256- train_epoch (train_loader , pruned_model , criterion , optimizer , device = device )
257252 acc1 = validate (val_loader , pruned_model , device )
258- # Show statistics of pruning
259- print (f"Accuracy@1 of pruned model after { epoch } epoch ratio { pruning_scheduler . current_ratio } : { acc1 :.3f} " )
253+ print ( f"Current pruning ratio: { scheduler . current_ratio :.3f } " )
254+ print (f"Accuracy@1 of pruned model after { epoch } epoch: { acc1 :.3f} " )
260255
261256 ###############################################################################
262257 # Step 4: Export models
0 commit comments