 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9- """Finetune a pretrained fp32 with int8 quantization aware training(QAT )"""
9+ """Finetune a pretrained fp32 with int8 post train quantization(calibration )"""
 import argparse
 import collections
-import multiprocessing as mp
 import numbers
 import os
-import bisect
 import time
 
+# pylint: disable=import-error
+import models
+
 import megengine as mge
 import megengine.data as data
 import megengine.data.transform as T
 import megengine.distributed as dist
 import megengine.functional as F
-import megengine.jit as jit
-import megengine.optimizer as optim
 import megengine.quantization as Q
-
-import config
-import models
+from megengine.quantization.quantize import enable_observer, quantize, quantize_qat
 
 logger = mge.get_logger(__name__)
-# from imagenet_nori_dataset import ImageNetNoriDataset
-from megengine.quantization.quantize import enable_observer, quantize, quantize_qat
+
 
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("-a", "--arch", default="resnet18", type=str)
     parser.add_argument("-d", "--data", default=None, type=str)
     parser.add_argument("-s", "--save", default="/data/models", type=str)
-    parser.add_argument("-c", "--checkpoint", default=None, type=str,
-                        help="pretrained model to finetune")
-
-    parser.add_argument("-m", "--mode", default="qat", type=str,
-                        choices=["normal", "qat", "quantized", "calibration"],
-                        help="Quantization Mode\n"
-                             "normal: no quantization, using float32\n"
-                             "qat: quantization aware training, simulate int8\n"
-                             "calibration: calibration\n"
-                             "quantized: convert model to int8 quantized, inference only")
+    parser.add_argument(
+        "-c",
+        "--checkpoint",
+        default=None,
+        type=str,
+        help="pretrained model to finetune",
+    )
 
     parser.add_argument("-n", "--ngpus", default=None, type=int)
     parser.add_argument("-w", "--workers", default=4, type=int)
     parser.add_argument("--report-freq", default=50, type=int)
     args = parser.parse_args()
 
-    world_size = mge.get_device_count("gpu") if args.ngpus is None else args.ngpus
-
-    if world_size > 1:
-        # start distributed training, dispatch sub-processes
-        mp.set_start_method("spawn")
-        processes = []
-        for rank in range(world_size):
-            p = mp.Process(target=worker, args=(rank, world_size, args))
-            p.start()
-            processes.append(p)
-
-        for p in processes:
-            p.join()
-    else:
-        worker(0, 1, args)
+    world_size = (
+        dist.helper.get_device_count_by_fork("gpu")
+        if args.ngpus is None
+        else args.ngpus
+    )
+    world_size = 1 if world_size == 0 else world_size
+    if world_size != 1:
+        logger.warning(
+            "Calibration only supports single GPU now, %d provided", world_size
+        )
+    proc_func = dist.launcher(worker) if world_size > 1 else worker
+    proc_func(world_size, args)
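
The launcher-based dispatch above replaces the old multiprocessing bookkeeping: dist.launcher wraps the worker and forks one process per visible GPU, while the single-device path calls the worker inline (e.g. python3 calibration.py -a resnet18 -d /path/to/imagenet -c pretrained.pkl, assuming this file is named calibration.py). A minimal sketch of the same pattern, with a hypothetical stand-in worker:

    import megengine.distributed as dist

    def demo_worker(world_size, args):
        # In a launched subprocess the rank comes from the process group,
        # not from an explicit argument as in the removed mp.Process code.
        rank = dist.get_rank() if world_size > 1 else 0
        print("worker {} / {}".format(rank, world_size))

    def dispatch(world_size, args):
        # One subprocess per device when world_size > 1, inline otherwise.
        proc_func = dist.launcher(demo_worker) if world_size > 1 else demo_worker
        proc_func(world_size, args)
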
 
 
 def get_parameters(model, cfg):
     if isinstance(cfg.WEIGHT_DECAY, numbers.Number):
-        return {"params": model.parameters(requires_grad=True),
-                "weight_decay": cfg.WEIGHT_DECAY}
+        return {
+            "params": model.parameters(requires_grad=True),
+            "weight_decay": cfg.WEIGHT_DECAY,
+        }
 
     groups = collections.defaultdict(list)  # weight_decay -> List[param]
     for pname, p in model.named_parameters(requires_grad=True):
         wd = cfg.WEIGHT_DECAY(pname, p)
         groups[wd].append(p)
     groups = [
-        {"params": params, "weight_decay": wd}
-        for wd, params in groups.items()
+        {"params": params, "weight_decay": wd} for wd, params in groups.items()
     ]  # List[{param, weight_decay}]
     return groups
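
get_parameters accepts either a plain number or a per-parameter policy: when cfg.WEIGHT_DECAY is callable, parameters are bucketed by the decay value it returns, producing one optimizer group per distinct value. A hedged usage sketch (the SimpleNamespace config and the no-decay-for-biases rule are illustrative assumptions, not part of this commit):

    from types import SimpleNamespace

    # Illustrative policy: biases get no weight decay, everything else 1e-4.
    cfg = SimpleNamespace(
        WEIGHT_DECAY=lambda pname, p: 0.0 if pname.endswith("bias") else 1e-4
    )
    model = models.__dict__["resnet18"]()  # any MegEngine module works
    param_groups = get_parameters(model, cfg)
    # -> two groups: {"params": [...], "weight_decay": 0.0} and
    #    {"params": [...], "weight_decay": 1e-4}
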
 
 
-def worker(rank, world_size, args):
+def worker(world_size, args):
     # pylint: disable=too-many-statements
 
+    rank = dist.get_rank()
     if world_size > 1:
         # Initialize distributed process group
         logger.info("init distributed process group {} / {}".format(rank, world_size))
-        dist.init_process_group(
-            master_ip="localhost",
-            master_port=23456,
-            world_size=world_size,
-            rank=rank,
-            dev=rank,
-        )
 
-    save_dir = os.path.join(args.save, args.arch + "." + args.mode)
+    save_dir = os.path.join(args.save, args.arch + "." + "calibration")
     if not os.path.exists(save_dir):
         os.makedirs(save_dir, exist_ok=True)
     mge.set_log_file(os.path.join(save_dir, "log.txt"))
 
     model = models.__dict__[args.arch]()
-    cfg = config.get_finetune_config(args.arch)
 
-    cfg.LEARNING_RATE *= world_size  # scale learning rate in distributed training
-    total_batch_size = cfg.BATCH_SIZE * world_size
-    steps_per_epoch = 1280000 // total_batch_size
-    total_steps = steps_per_epoch * cfg.EPOCHS
-
     # load calibration model
     assert args.checkpoint
     logger.info("Load pretrained weights from %s", args.checkpoint)
@@ -121,70 +100,64 @@ def worker(rank, world_size, args):
 
     # Build valid datasets
     valid_dataset = data.dataset.ImageNet(args.data, train=False)
-    # valid_dataset = ImageNetNoriDataset(args.data)
     valid_sampler = data.SequentialSampler(
         valid_dataset, batch_size=100, drop_last=False
     )
     valid_queue = data.DataLoader(
         valid_dataset,
         sampler=valid_sampler,
         transform=T.Compose(
-            [
-                T.Resize(256),
-                T.CenterCrop(224),
-                T.Normalize(mean=128),
-                T.ToMode("CHW"),
-            ]
+            [T.Resize(256), T.CenterCrop(224), T.Normalize(mean=128), T.ToMode("CHW")]
         ),
         num_workers=args.workers,
     )
 
     # calibration
     model.fc.disable_quantize()
     model = quantize_qat(model, qconfig=Q.calibration_qconfig)
-
+
     # calculate scale
-    @jit.trace(symbolic=True)
     def calculate_scale(image, label):
         model.eval()
         enable_observer(model)
         logits = model(image)
-        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
-        acc1, acc5 = F.accuracy(logits, label, (1, 5))
+        loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
+        acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
         if dist.is_distributed():  # all_reduce_mean
-            loss = dist.all_reduce_sum(loss, "valid_loss") / dist.get_world_size()
-            acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / dist.get_world_size()
-            acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / dist.get_world_size()
+            loss = dist.functional.all_reduce_sum(loss) / dist.get_world_size()
+            acc1 = dist.functional.all_reduce_sum(acc1) / dist.get_world_size()
+            acc5 = dist.functional.all_reduce_sum(acc5) / dist.get_world_size()
         return loss, acc1, acc5
-
-    # model.fc.disable_quantize()
+
     infer(calculate_scale, valid_queue, args)
 
     # quantized
     model = quantize(model)
 
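
Taken together, this is MegEngine's three-step post-training quantization flow: quantize_qat with Q.calibration_qconfig swaps supported modules for observed variants, running evaluation batches with observers enabled records activation ranges, and quantize then folds those scales into a true int8 model. A condensed sketch of the flow in isolation (calib_queue stands in for any DataLoader of representative batches):

    model = quantize_qat(model, qconfig=Q.calibration_qconfig)  # insert observers
    model.eval()
    enable_observer(model)        # observers start recording value ranges
    for image, _ in calib_queue:  # a modest number of batches suffices
        model(mge.tensor(image, dtype="float32"))
    model = quantize(model)       # convert to the int8 quantized model
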
     # eval quantized model
-    @jit.trace(symbolic=True)
     def eval_func(image, label):
         model.eval()
         logits = model(image)
-        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
-        acc1, acc5 = F.accuracy(logits, label, (1, 5))
+        loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
+        acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
         if dist.is_distributed():  # all_reduce_mean
-            loss = dist.all_reduce_sum(loss, "valid_loss") / dist.get_world_size()
-            acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / dist.get_world_size()
-            acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / dist.get_world_size()
+            loss = dist.functional.all_reduce_sum(loss) / dist.get_world_size()
+            acc1 = dist.functional.all_reduce_sum(acc1) / dist.get_world_size()
+            acc5 = dist.functional.all_reduce_sum(acc5) / dist.get_world_size()
         return loss, acc1, acc5
-
+
     _, valid_acc, valid_acc5 = infer(eval_func, valid_queue, args)
     logger.info("TEST %f, %f", valid_acc, valid_acc5)
 
     # save quantized model
     mge.save(
         {"step": -1, "state_dict": model.state_dict()},
-        os.path.join(save_dir, "checkpoint-calibration.pkl")
+        os.path.join(save_dir, "checkpoint-calibration.pkl"),
+    )
+    logger.info(
+        "save in {}".format(os.path.join(save_dir, "checkpoint-calibration.pkl"))
     )
-    logger.info("save in {}".format(os.path.join(save_dir, "checkpoint-calibration.pkl")))
+
 
 def infer(model, data_queue, args):
     objs = AverageMeter("Loss")
@@ -195,8 +168,8 @@ def infer(model, data_queue, args):
     t = time.time()
     for step, (image, label) in enumerate(data_queue):
         n = image.shape[0]
-        image = image.astype("float32")  # convert np.uint8 to float32
-        label = label.astype("int32")
+        image = mge.tensor(image, dtype="float32")
+        label = mge.tensor(label, dtype="int32")
 
         loss, acc1, acc5 = model(image, label)
 
@@ -207,9 +180,8 @@ def infer(model, data_queue, args):
         t = time.time()
 
         if step % args.report_freq == 0 and dist.get_rank() == 0:
-            logger.info("Step %d, %s %s %s %s",
-                        step, objs, top1, top5, total_time)
-
+            logger.info("Step %d, %s %s %s %s", step, objs, top1, top5, total_time)
+
         # break
         if step == args.report_freq:
             break
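
infer relies on an AverageMeter helper defined elsewhere in this file; the visible code only constructs it and formats it via %s. A minimal sketch of a meter with that interface (the update signature is an assumption):

    class AverageMeter:
        """Running average with a printable summary, as used in logger.info."""

        def __init__(self, name):
            self.name, self.sum, self.cnt = name, 0.0, 0

        def update(self, val, n=1):
            self.sum += val * n
            self.cnt += n

        def __str__(self):
            return "{} {:.3f}".format(self.name, self.sum / max(self.cnt, 1))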