@@ -48,44 +48,85 @@ def process_checkpoint(in_file, out_file):
     return final_file


-def get_final_epoch(config):
+def is_by_epoch(config):
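+    # Clarifying note: True when the config uses an EpochBasedRunner, i.e. training length is counted in epochs.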
     cfg = mmcv.Config.fromfile('./configs/' + config)
-    return cfg.runner.max_epochs
+    return cfg.runner.type == 'EpochBasedRunner'


-def get_best_epoch(exp_dir):
-    best_epoch_full_path = list(
+def get_final_epoch_or_iter(config):
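+    # Clarifying note: returns the configured training length, max_epochs for epoch-based runners and max_iters otherwise.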
+    cfg = mmcv.Config.fromfile('./configs/' + config)
+    if cfg.runner.type == 'EpochBasedRunner':
+        return cfg.runner.max_epochs
+    else:
+        return cfg.runner.max_iters
+
+
+def get_best_epoch_or_iter(exp_dir):
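+    # Clarifying note: takes the last best_*.pth checkpoint (sorted by name) and parses the epoch/iteration number from its filename.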
+    best_epoch_iter_full_path = list(
         sorted(glob.glob(osp.join(exp_dir, 'best_*.pth'))))[-1]
-    best_epoch_model_path = best_epoch_full_path.split('/')[-1]
-    best_epoch = best_epoch_model_path.split('_')[-1].split('.')[0]
-    return best_epoch_model_path, int(best_epoch)
+    best_epoch_or_iter_model_path = best_epoch_iter_full_path.split('/')[-1]
+    best_epoch_or_iter = best_epoch_or_iter_model_path.\
+        split('_')[-1].split('.')[0]
+    return best_epoch_or_iter_model_path, int(best_epoch_or_iter)


-def get_real_epoch(config):
+def get_real_epoch_or_iter(config):
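+    # Clarifying note: with RepeatDataset the effective epoch count is max_epochs times the repeat factor; iteration-based runs return max_iters unchanged.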
     cfg = mmcv.Config.fromfile('./configs/' + config)
-    epoch = cfg.runner.max_epochs
-    if cfg.data.train.type == 'RepeatDataset':
-        epoch *= cfg.data.train.times
-    return epoch
+    if cfg.runner.type == 'EpochBasedRunner':
+        epoch = cfg.runner.max_epochs
+        if cfg.data.train.type == 'RepeatDataset':
+            epoch *= cfg.data.train.times
+        return epoch
+    else:
+        return cfg.runner.max_iters


-def get_final_results(log_json_path, epoch, results_lut):
+def get_final_results(log_json_path,
+                      epoch_or_iter,
+                      results_lut,
+                      by_epoch=True):
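+    # Clarifying note: with by_epoch=True the entries logged at the requested epoch are returned; otherwise the last train/val entries in the log are used.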
     result_dict = dict()
+    last_val_line = None
+    last_train_line = None
+    last_val_line_idx = -1
+    last_train_line_idx = -1
     with open(log_json_path, 'r') as f:
-        for line in f.readlines():
+        for i, line in enumerate(f.readlines()):
             log_line = json.loads(line)
             if 'mode' not in log_line.keys():
                 continue

-            if log_line['mode'] == 'train' and log_line['epoch'] == epoch:
-                result_dict['memory'] = log_line['memory']
-
-            if log_line['mode'] == 'val' and log_line['epoch'] == epoch:
-                result_dict.update({
-                    key: log_line[key]
-                    for key in results_lut if key in log_line
-                })
-                return result_dict
+            if by_epoch:
+                if (log_line['mode'] == 'train'
+                        and log_line['epoch'] == epoch_or_iter):
+                    result_dict['memory'] = log_line['memory']
+
+                if (log_line['mode'] == 'val'
+                        and log_line['epoch'] == epoch_or_iter):
+                    result_dict.update({
+                        key: log_line[key]
+                        for key in results_lut if key in log_line
+                    })
+                    return result_dict
+            else:
+                if log_line['mode'] == 'train':
+                    last_train_line_idx = i
+                    last_train_line = log_line
+
+                if log_line and log_line['mode'] == 'val':
+                    last_val_line_idx = i
+                    last_val_line = log_line
+
+    # bug: max_iters = 768, last_train_line['iter'] = 750
+    assert last_val_line_idx == last_train_line_idx + 1, \
+        'Log file is incomplete'
+    result_dict['memory'] = last_train_line['memory']
+    result_dict.update({
+        key: last_val_line[key]
+        for key in results_lut if key in last_val_line
+    })
+
+    return result_dict


 def get_dataset_name(config):
@@ -116,10 +157,12 @@ def convert_model_info_to_pwc(model_infos):

         # get metadata
         memory = round(model['results']['memory'] / 1024, 1)
-        epochs = get_real_epoch(model['config'])
         meta_data = OrderedDict()
         meta_data['Training Memory (GB)'] = memory
-        meta_data['Epochs'] = epochs
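+        # Clarifying note: main() stores the training length under 'epochs' or 'iterations' depending on the runner type.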
+        if 'epochs' in model:
+            meta_data['Epochs'] = get_real_epoch_or_iter(model['config'])
+        else:
+            meta_data['Iterations'] = get_real_epoch_or_iter(model['config'])
         pwc_model_info['Metadata'] = meta_data

         # get dataset name
@@ -200,12 +243,14 @@ def main():
     model_infos = []
     for used_config in used_configs:
         exp_dir = osp.join(models_root, used_config)
+        by_epoch = is_by_epoch(used_config)
         # check whether the exps is finished
         if args.best is True:
-            final_model, final_epoch = get_best_epoch(exp_dir)
+            final_model, final_epoch_or_iter = get_best_epoch_or_iter(exp_dir)
         else:
-            final_epoch = get_final_epoch(used_config)
-            final_model = 'epoch_{}.pth'.format(final_epoch)
+            final_epoch_or_iter = get_final_epoch_or_iter(used_config)
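+            # Clarifying note: builds the checkpoint filename, epoch_<n>.pth for epoch-based runs and iter_<n>.pth for iteration-based ones.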
+            final_model = '{}_{}.pth'.format('epoch' if by_epoch else 'iter',
+                                             final_epoch_or_iter)

         model_path = osp.join(exp_dir, final_model)
         # skip if the model is still training
@@ -225,21 +270,23 @@ def main():
         for i, key in enumerate(results_lut):
             if 'mAP' not in key and 'PQ' not in key:
                 results_lut[i] = key + 'm_AP'
-        model_performance = get_final_results(log_json_path, final_epoch,
-                                              results_lut)
+        model_performance = get_final_results(log_json_path,
+                                              final_epoch_or_iter, results_lut,
+                                              by_epoch)

         if model_performance is None:
             continue

         model_time = osp.split(log_txt_path)[-1].split('.')[0]
-        model_infos.append(
-            dict(
-                config=used_config,
-                results=model_performance,
-                epochs=final_epoch,
-                model_time=model_time,
-                final_model=final_model,
-                log_json_path=osp.split(log_json_path)[-1]))
+        model_info = dict(
+            config=used_config,
+            results=model_performance,
+            model_time=model_time,
+            final_model=final_model,
+            log_json_path=osp.split(log_json_path)[-1])
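+        # Clarifying note: the training length is recorded under 'epochs' or 'iterations' so convert_model_info_to_pwc knows which metadata field to fill.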
+        model_info['epochs' if by_epoch else 'iterations'] = \
+            final_epoch_or_iter
+        model_infos.append(model_info)

     # publish model for each checkpoint
     publish_model_infos = []