@@ -12,6 +12,7 @@
 import os
 import warnings
 from argparse import ArgumentParser
+from argparse import RawTextHelpFormatter
 from pathlib import Path
 
 import openvino as ov
@@ -48,13 +49,18 @@
 
 
 def get_argument_parser() -> ArgumentParser:
-    parser = ArgumentParser()
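+    # RawTextHelpFormatter preserves the explicit newlines in the --mode help text below.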
+    parser = ArgumentParser(formatter_class=RawTextHelpFormatter)
     parser.add_argument(
         "--mode",
         type=str,
-        choices=["magnitude", "rb"],
-        default="magnitude",
-        help="Pruning mode to use. Choices are: magnitude, rb. Default is magnitude.",
+        choices=["mag", "mag_bn", "rb"],
+        default="mag",
+        help=(
+            "Pruning mode to use. Choices are:\n"
+            " - mag: Magnitude-based pruning with fine-tuning (default).\n"
+            " - mag_bn: Magnitude-based pruning with BatchNorm adaptation instead of fine-tuning.\n"
+            " - rb: Regularization-based pruning with fine-tuning.\n"
+        ),
     )
     return parser
 
@@ -82,13 +88,12 @@ def get_resnet18_model(device: torch.device) -> nn.Module:
 def train_epoch(
     train_loader: DataLoader,
     model: nn.Module,
-    criterion: nn.Module,
     rb_loss: RBLoss,
     optimizer: torch.optim.Optimizer,
     device: torch.device,
 ):
-    # Switch to train mode.
     model.train()
+    criterion = nn.CrossEntropyLoss().to(device)
 
     for images, target in track(train_loader, total=len(train_loader), description="Fine tuning:"):
         images = images.to(device)
@@ -107,7 +112,6 @@ def train_epoch(
 
 @torch.no_grad()
 def validate(val_loader: torch.utils.data.DataLoader, model: torch.nn.Module, device: torch.device) -> float:
-    # Switch to evaluate mode.
     model.eval()
 
     correct = 0
@@ -201,14 +205,20 @@ def main() -> float:
 
     ###############################################################################
     # Step 2: Prune model
-    print(os.linesep + "[Step 2]: Prune model and specify training parameters")
+    print(os.linesep + "[Step 2] Prune model and specify training parameters")
 
-    if pruning_mode == "magnitude":
+    if pruning_mode == "mag_bn":
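+        # Remove 60% of the weights globally by magnitude; accuracy is recovered later via BatchNorm adaptation alone.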
+        pruned_model = nncf.prune(
+            model,
+            mode=PruneMode.UNSTRUCTURED_MAGNITUDE_GLOBAL,
+            ratio=0.6,
+            examples_inputs=example_input,
+        )
+    elif pruning_mode == "mag":
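+        # Prune 70% of the weights globally by magnitude; the scheduler below ramps the ratio 0.5 -> 0.7 over 2 fine-tuning epochs.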
         pruned_model = nncf.prune(
             model,
             mode=PruneMode.UNSTRUCTURED_MAGNITUDE_GLOBAL,
             ratio=0.7,
-            ignored_scope=nncf.IgnoredScope(),
             examples_inputs=example_input,
         )
         num_epochs = 2
@@ -217,11 +227,10 @@ def main() -> float:
             model=model, mode=PruneMode.UNSTRUCTURED_MAGNITUDE_GLOBAL, steps={0: 0.5, 1: 0.7}
         )
         optimizer = torch.optim.Adam(pruned_model.parameters(), lr=1e-5)
-    else:
+    elif pruning_mode == "rb":
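+        # Regularization-based pruning trains the sparsity masks together with the weights (via RBLoss and the mask parameter group below), hence the longer 30-epoch schedule.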
         pruned_model = nncf.prune(
             model,
             mode=PruneMode.UNSTRUCTURED_REGULARIZATION_BASED,
-            ignored_scope=nncf.IgnoredScope(),
             examples_inputs=example_input,
         )
         num_epochs = 30
@@ -237,32 +246,52 @@ def main() -> float:
                 {"params": mask_params, "lr": 1e-2, "weight_decay": 0.0},
             ]
         )
-
-    criterion = nn.CrossEntropyLoss().to(device)
+    else:
+        msg = f"Unsupported pruning mode: {pruning_mode}, please choose from ['mag', 'mag_bn', 'rb']"
+        raise ValueError(msg)
 
     ###############################################################################
     # Step 3: Fine tune
     print(os.linesep + "[Step 3] Fine tune with multi step pruning ratio scheduler")
 
-    for epoch in range(num_epochs):
-        print(os.linesep + f"Train epoch: {epoch}")
-        scheduler.step()
-        train_epoch(train_loader, pruned_model, criterion, rb_loss, optimizer, device=device)
+    if pruning_mode == "mag_bn":
+        acc1_before = validate(val_loader, pruned_model, device)
+        print(f"Accuracy@1 of pruned model before BatchNorm adaptation: {acc1_before:.3f}")
+
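+        # nncf.Dataset applies transform_fn to every batch: keep the images, drop the labels, move to the target device.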
+        def transform_fn(batch: tuple[torch.Tensor, int]) -> torch.Tensor:
+            inputs, _ = batch
+            return inputs.to(device=device)
+
+        calibration_dataset = nncf.Dataset(train_loader, transform_func=transform_fn)
+
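+        # Recompute BatchNorm running statistics over 200 calibration batches to restore accuracy without fine-tuning.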
+        pruned_model = nncf.batch_norm_adaptation(
+            pruned_model,
+            calibration_dataset=calibration_dataset,
+            num_iterations=200,
+        )
 
         acc1 = validate(val_loader, pruned_model, device)
-        print(f"Current pruning ratio: {scheduler.current_ratio:.3f}")
-        print(f"Accuracy@1 of pruned model after {epoch} epoch: {acc1:.3f}")
+        print(f"Accuracy@1 of pruned model after BatchNorm adaptation: {acc1:.3f}")
+    else:
+        for epoch in range(num_epochs):
+            print(os.linesep + f"Train epoch: {epoch}")
+            scheduler.step()
+            train_epoch(train_loader, pruned_model, rb_loss, optimizer, device=device)
+
+            acc1 = validate(val_loader, pruned_model, device)
+            print(f"Current pruning ratio: {scheduler.current_ratio:.3f}")
+            print(f"Accuracy@1 of pruned model after epoch {epoch}: {acc1:.3f}")
 
     ###############################################################################
     # Step 4: Print per tensor pruning statistics
-    print(os.linesep + "[Step 4]: Pruning statistics")
+    print(os.linesep + "[Step 4] Pruning statistics")
 
     pruning_stat = nncf.pruning_statistic(pruned_model)
     print(pruning_stat)
 
     ###############################################################################
     # Step 5: Export models
-    print(os.linesep + "[Step 5]: Export models")
+    print(os.linesep + "[Step 5] Export models")
     ir_path = ROOT / f"{BASE_MODEL_NAME}_pruned.xml"
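+    # Export the pruned model to OpenVINO IR with a static input shape; keep FP32 weights (no FP16 compression).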
     ov_model = ov.convert_model(pruned_model.cpu(), example_input=example_input.cpu(), input=tuple(example_input.shape))
     ov.save_model(ov_model, ir_path, compress_to_fp16=False)