@@ -48,6 +48,13 @@
     has_dynamo = False


+_FMT_EXT = {
+    'json': '.json',
+    'json-split': '.json',
+    'parquet': '.parquet',
+    'csv': '.csv',
+}
+
 torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('inference')

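Quick illustration of the new extension map (hypothetical base name; this mirrors how the `save_results` helper added at the end of this diff resolves output paths):

_FMT_EXT = {'json': '.json', 'json-split': '.json', 'parquet': '.parquet', 'csv': '.csv'}

for fmt in ('csv', 'json-split', 'parquet'):
    print('resnet50-224' + _FMT_EXT[fmt])
# resnet50-224.csv
# resnet50-224.json
# resnet50-224.parquet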
@@ -103,8 +110,6 @@
                     help='use Native AMP for mixed precision training')
 parser.add_argument('--amp-dtype', default='float16', type=str,
                     help='lower precision AMP dtype (default: float16)')
-parser.add_argument('--use-ema', dest='use_ema', action='store_true',
-                    help='use ema version of weights if present')
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--dynamo-backend', default=None, type=str,
@@ -122,21 +127,23 @@
                     help='folder for output results')
 parser.add_argument('--results-file', type=str, default=None,
                     help='results filename (relative to results-dir)')
-parser.add_argument('--results-format', type=str, default='csv',
+parser.add_argument('--results-format', type=str, nargs='+', default=['csv'],
                     help='results format (one of "csv", "json", "json-split", "parquet")')
+parser.add_argument('--results-separate-col', action='store_true', default=False,
+                    help='separate output columns per result index.')
 parser.add_argument('--topk', default=1, type=int,
                     metavar='N', help='Top-k to output to CSV')
 parser.add_argument('--fullname', action='store_true', default=False,
                     help='use full sample name in output (not just basename).')
-parser.add_argument('--indices-name', default='index',
+parser.add_argument('--filename-col', default='filename',
+                    help='name for filename / sample name column')
+parser.add_argument('--index-col', default='index',
                     help='name for output indices column(s)')
-parser.add_argument('--outputs-name', default=None,
+parser.add_argument('--output-col', default=None,
                     help='name for logit/probs output column(s)')
-parser.add_argument('--outputs-type', default='prob',
+parser.add_argument('--output-type', default='prob',
                     help='output type column ("prob" for probabilities, "logit" for raw logits)')
-parser.add_argument('--separate-columns', action='store_true', default=False,
-                    help='separate output columns per result index.')
-parser.add_argument('--exclude-outputs', action='store_true', default=False,
+parser.add_argument('--exclude-output', action='store_true', default=False,
                     help='exclude logits/probs from results, just indices. topk must be set !=0.')


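Note the behavior change above: with `nargs='+'`, `--results-format` now accepts one or more formats in a single run. A minimal standalone sketch of the parsing behavior (parser stripped down to just this flag):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--results-format', type=str, nargs='+', default=['csv'])

print(parser.parse_args(['--results-format', 'csv', 'parquet']).results_format)  # ['csv', 'parquet']
print(parser.parse_args([]).results_format)  # ['csv'] (default stays a list)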
@@ -179,9 +186,6 @@ def main():
         assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
         args.num_classes = model.num_classes

-    if args.checkpoint:
-        load_checkpoint(model, args.checkpoint, args.use_ema)
-
     _logger.info(
         f'Model {args.model} created, param count: {sum([m.numel() for m in model.parameters()])}')
@@ -221,11 +225,12 @@ def main():
     if test_time_pool:
         data_config['crop_pct'] = 1.0

+    workers = 1 if 'tfds' in args.dataset or 'wds' in args.dataset else args.workers
     loader = create_loader(
         dataset,
         batch_size=args.batch_size,
         use_prefetcher=True,
-        num_workers=args.workers,
+        num_workers=workers,
         **data_config,
     )

@@ -234,7 +239,7 @@ def main():
     end = time.time()
     all_indices = []
     all_outputs = []
-    use_probs = args.outputs_type == 'prob'
+    use_probs = args.output_type == 'prob'
     with torch.no_grad():
         for batch_idx, (input, _) in enumerate(loader):

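For context, `use_probs` presumably gates a softmax before top-k selection inside the batch loop (not shown in this hunk); a standalone sketch of that pattern, with shapes and variable names assumed rather than taken from the file:

import torch

logits = torch.randn(4, 1000)  # assumed: batch of raw classifier outputs
use_probs = True               # i.e. args.output_type == 'prob'
output = logits.softmax(-1) if use_probs else logits
values, indices = output.topk(1)     # args.topk
print(values.shape, indices.shape)   # torch.Size([4, 1]) torch.Size([4, 1])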
@@ -262,52 +267,53 @@ def main():
     all_outputs = np.concatenate(all_outputs, axis=0).astype(np.float32)
     filenames = loader.dataset.filenames(basename=not args.fullname)

-    outputs_name = args.outputs_name or ('prob' if use_probs else 'logit')
-    data_dict = {'filename': filenames}
-    if args.separate_columns and all_outputs.shape[-1] > 1:
+    output_col = args.output_col or ('prob' if use_probs else 'logit')
+    data_dict = {args.filename_col: filenames}
+    if args.results_separate_col and all_outputs.shape[-1] > 1:
         if all_indices is not None:
             for i in range(all_indices.shape[-1]):
-                data_dict[f'{args.indices_name}_{i}'] = all_indices[:, i]
+                data_dict[f'{args.index_col}_{i}'] = all_indices[:, i]
         for i in range(all_outputs.shape[-1]):
-            data_dict[f'{outputs_name}_{i}'] = all_outputs[:, i]
+            data_dict[f'{output_col}_{i}'] = all_outputs[:, i]
     else:
         if all_indices is not None:
             if all_indices.shape[-1] == 1:
                 all_indices = all_indices.squeeze(-1)
-            data_dict[args.indices_name] = list(all_indices)
+            data_dict[args.index_col] = list(all_indices)
         if all_outputs.shape[-1] == 1:
             all_outputs = all_outputs.squeeze(-1)
-        data_dict[outputs_name] = list(all_outputs)
+        data_dict[output_col] = list(all_outputs)

     df = pd.DataFrame(data=data_dict)

     results_filename = args.results_file
-    needs_ext = False
-    if not results_filename:
+    if results_filename:
+        filename_no_ext, ext = os.path.splitext(results_filename)
+        if ext and ext in _FMT_EXT.values():
+            # if filename provided with one of expected ext,
+            # remove it as it will be added back
+            results_filename = filename_no_ext
+    else:
         # base default filename on model name + img-size
         img_size = data_config["input_size"][1]
         results_filename = f'{args.model}-{img_size}'
-        needs_ext = True

     if args.results_dir:
         results_filename = os.path.join(args.results_dir, results_filename)

-    if args.results_format == 'parquet':
-        if needs_ext:
-            results_filename += '.parquet'
-        df = df.set_index('filename')
-        df.to_parquet(results_filename)
-    elif args.results_format == 'json':
-        if needs_ext:
-            results_filename += '.json'
+    for fmt in args.results_format:
+        save_results(df, results_filename, fmt, args.filename_col)
+
+
+def save_results(df, results_filename, results_format='csv', filename_col='filename'):
+    results_filename += _FMT_EXT[results_format]
+    if results_format == 'parquet':
+        df.set_index(filename_col).to_parquet(results_filename)
+    elif results_format == 'json':
         df.to_json(results_filename, lines=True, orient='records')
-    elif args.results_format == 'json-split':
-        if needs_ext:
-            results_filename += '.json'
+    elif results_format == 'json-split':
         df.to_json(results_filename, indent=4, orient='split', index=False)
     else:
-        if needs_ext:
-            results_filename += '.csv'
         df.to_csv(results_filename, index=False)

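A quick usage sketch of the new `save_results` helper (hypothetical data; assumes `_FMT_EXT` and `save_results` from this diff are in scope, and pyarrow installed for the parquet path). The extension is appended from `_FMT_EXT`, which is why the caller first strips a recognized extension from a user-supplied `--results-file`:

import pandas as pd

df = pd.DataFrame({
    'filename': ['a.jpg', 'b.jpg'],
    'index': [1, 7],
    'prob': [0.93, 0.81],
})
for fmt in ('csv', 'json', 'parquet'):
    save_results(df, 'resnet50-224', results_format=fmt)
# writes resnet50-224.csv, resnet50-224.json, resnet50-224.parquet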