# NOTE(review): this region was embedded diff output (old/new line numbers, '+'/'-'
# markers). Reconstructed below as the clean post-patch code; the '--checkpoint' help
# text and the '--pretrained' argument were elided by diff context and are restored
# from the removed-side lines / upstream convention — confirm against upstream.
import os  # used by main() for results path joining; harmless if also imported above
import time
import argparse
import logging
from contextlib import suppress
from functools import partial

import numpy as np
import pandas as pd
import torch

from timm.models import create_model, apply_test_time_pool, load_checkpoint
from timm.data import create_dataset, create_loader, resolve_data_config
from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser


# Optional-dependency probes: each flag records whether the corresponding
# acceleration backend is importable in this environment.
try:
    from apex import amp
    has_apex = True
except ImportError:
    has_apex = False

has_native_amp = False
try:
    # torch.cuda.amp.autocast exists on any reasonably recent PyTorch
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass

try:
    from functorch.compile import memory_efficient_fusion
    has_functorch = True
except ImportError:
    has_functorch = False

try:
    import torch._dynamo
    has_dynamo = True
except ImportError:
    has_dynamo = False


torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('inference')


parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--dataset', '-d', metavar='NAME', default='',
                    help='dataset type (default: ImageFolder/ImageTar if empty)')
parser.add_argument('--split', metavar='NAME', default='validation',
                    help='dataset split (default: validation)')
parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
                    help='model architecture (default: dpn92)')
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                    help='number of data loading workers (default: 2)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
                    metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
                    metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--use-train-size', action='store_true', default=False,
                    help='force use of train input size, even when test size is specified in pretrained cfg')
parser.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='Input image center crop pct')
parser.add_argument('--crop-mode', default=None, type=str,
                    metavar='N', help='Input image crop mode (squash, border, center). Model default if None.')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=None,
                    help='Number classes in dataset')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                    help='path to class to idx mapping file (default: "")')
parser.add_argument('--log-freq', default=10, type=int,
                    metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
# restored from elided diff context — main() reads args.pretrained
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUS to use')
parser.add_argument('--test-pool', dest='test_pool', action='store_true',
                    help='enable test time pool')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--device', default='cuda', type=str,
                    help="Device (accelerator) to use.")
parser.add_argument('--amp', action='store_true', default=False,
                    help='use Native AMP for mixed precision training')
parser.add_argument('--amp-dtype', default='float16', type=str,
                    help='lower precision AMP dtype (default: float16)')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
                    help='use ema version of weights if present')
parser.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--dynamo-backend', default=None, type=str,
                    help="Select dynamo backend. Default: None")

# the three graph-compilation modes are mutually exclusive
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', default=False, action='store_true',
                             help='torch.jit.script the full model')
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                             help="Enable AOT Autograd support.")
scripting_group.add_argument('--dynamo', default=False, action='store_true',
                             help="Enable Dynamo optimization.")

parser.add_argument('--results-dir', type=str, default=None,
                    help='folder for output results')
parser.add_argument('--results-file', type=str, default=None,
                    help='results filename (relative to results-dir)')
parser.add_argument('--results-format', type=str, default='csv',
                    help='results format (one of "csv", "json", "json-split", "parquet")')
parser.add_argument('--topk', default=1, type=int,
                    metavar='N', help='Top-k to output to CSV')
parser.add_argument('--fullname', action='store_true', default=False,
                    help='use full sample name in output (not just basename).')
parser.add_argument('--indices-name', default='index',
                    help='name for output indices column(s)')
parser.add_argument('--outputs-name', default=None,
                    help='name for logit/probs output column(s)')
parser.add_argument('--outputs-type', default='prob',
                    help='output type colum ("prob" for probabilities, "logit" for raw logits)')
parser.add_argument('--separate-columns', action='store_true', default=False,
                    help='separate output columns per result index.')
parser.add_argument('--exclude-outputs', action='store_true', default=False,
                    help='exclude logits/probs from results, just indices. topk must be set !=0.')
def main():
    """Run batched inference over a dataset and write top-k results to file.

    Creates a timm model from CLI args, runs it over the requested dataset
    split, collects top-k indices and probabilities/logits per sample, and
    writes a pandas DataFrame as csv/json/json-split/parquet.

    NOTE(review): this function was embedded diff output; reconstructed as the
    clean post-patch code. The two small spans elided by diff context (argparse
    setup at the top, log-frequency check in the loop) were restored from the
    removed-side lines — confirm against upstream.
    """
    setup_default_logging()
    args = parser.parse_args()
    # might as well try to do something useful...
    args.pretrained = args.pretrained or not args.checkpoint

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    device = torch.device(args.device)

    # resolve AMP arguments based on PyTorch / Apex availability
    amp_autocast = suppress  # no-op context manager when AMP is off
    if args.amp:
        assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
        assert args.amp_dtype in ('float16', 'bfloat16')
        amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
        amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
        _logger.info('Running inference in mixed precision with native PyTorch AMP.')
    else:
        _logger.info('Running inference in float32. AMP not enabled.')

    if args.fuser:
        set_jit_fuser(args.fuser)

    # create model
    model = create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained,
        checkpoint_path=args.checkpoint,
    )
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    _logger.info(
        f'Model {args.model} created, param count: {sum([m.numel() for m in model.parameters()])}')

    data_config = resolve_data_config(vars(args), model=model)
    test_time_pool = False
    if args.test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config)

    model = model.to(device)
    model.eval()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # optional graph compilation — modes are mutually exclusive (argparse group)
    if args.torchscript:
        model = torch.jit.script(model)
    elif args.aot_autograd:
        assert has_functorch, "functorch is needed for --aot-autograd"
        model = memory_efficient_fusion(model)
    elif args.dynamo:
        assert has_dynamo, "torch._dynamo is needed for --dynamo"
        torch._dynamo.reset()
        if args.dynamo_backend is not None:
            model = torch._dynamo.optimize(args.dynamo_backend)(model)
        else:
            model = torch._dynamo.optimize()(model)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    dataset = create_dataset(
        root=args.data,
        name=args.dataset,
        split=args.split,
        class_map=args.class_map,
    )

    if test_time_pool:
        # test-time pooling consumes the full image, so disable center cropping
        data_config['crop_pct'] = 1.0

    loader = create_loader(
        dataset,
        batch_size=args.batch_size,
        use_prefetcher=True,
        num_workers=args.workers,
        **data_config,
    )

    top_k = min(args.topk, args.num_classes)
    batch_time = AverageMeter()
    end = time.time()
    all_indices = []
    all_outputs = []
    use_probs = args.outputs_type == 'prob'
    with torch.no_grad():
        for batch_idx, (input, _) in enumerate(loader):

            with amp_autocast():
                output = model(input)

            if use_probs:
                output = output.softmax(-1)

            if top_k:
                # keep only the top-k values and their class indices
                output, indices = output.topk(top_k)
                all_indices.append(indices.cpu().numpy())

            all_outputs.append(output.cpu().numpy())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
                    batch_idx, len(loader), batch_time=batch_time))

    all_indices = np.concatenate(all_indices, axis=0) if all_indices else None
    all_outputs = np.concatenate(all_outputs, axis=0).astype(np.float32)
    filenames = loader.dataset.filenames(basename=not args.fullname)

    outputs_name = args.outputs_name or ('prob' if use_probs else 'logit')
    data_dict = {'filename': filenames}
    if args.separate_columns and all_outputs.shape[-1] > 1:
        # one scalar column per top-k position
        if all_indices is not None:
            for i in range(all_indices.shape[-1]):
                data_dict[f'{args.indices_name}_{i}'] = all_indices[:, i]
        for i in range(all_outputs.shape[-1]):
            data_dict[f'{outputs_name}_{i}'] = all_outputs[:, i]
    else:
        # single column holding a scalar (k == 1) or an array per row
        if all_indices is not None:
            if all_indices.shape[-1] == 1:
                all_indices = all_indices.squeeze(-1)
            data_dict[args.indices_name] = list(all_indices)
        if all_outputs.shape[-1] == 1:
            all_outputs = all_outputs.squeeze(-1)
        data_dict[outputs_name] = list(all_outputs)

    df = pd.DataFrame(data=data_dict)

    results_filename = args.results_file
    needs_ext = False
    if not results_filename:
        # base default filename on model name + img-size
        img_size = data_config["input_size"][1]
        results_filename = f'{args.model}-{img_size}'
        needs_ext = True  # extension appended below per results-format

    if args.results_dir:
        results_filename = os.path.join(args.results_dir, results_filename)

    if args.results_format == 'parquet':
        if needs_ext:
            results_filename += '.parquet'
        df = df.set_index('filename')
        df.to_parquet(results_filename)
    elif args.results_format == 'json':
        if needs_ext:
            results_filename += '.json'
        df.to_json(results_filename, lines=True, orient='records')
    elif args.results_format == 'json-split':
        if needs_ext:
            results_filename += '.json'
        df.to_json(results_filename, indent=4, orient='split', index=False)
    else:
        # csv is the default / fallback format
        if needs_ext:
            results_filename += '.csv'
        df.to_csv(results_filename, index=False)


# NOTE(review): the call to main() was truncated in the corrupted source
# (replaced by page-scrape junk); restored to the standard entry-point guard.
if __name__ == '__main__':
    main()