Skip to content

Commit 05637a4

Browse files
committed
More inference script changes, arg naming, multiple output fmts at once
1 parent eceeb94 commit 05637a4

File tree

1 file changed

+43
-37
lines changed

1 file changed

+43
-37
lines changed

inference.py

Lines changed: 43 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@
4848
has_dynamo = False
4949

5050

51+
_FMT_EXT = {
52+
'json': '.json',
53+
'json-split': '.json',
54+
'parquet': '.parquet',
55+
'csv': '.csv',
56+
}
57+
5158
torch.backends.cudnn.benchmark = True
5259
_logger = logging.getLogger('inference')
5360

@@ -103,8 +110,6 @@
103110
help='use Native AMP for mixed precision training')
104111
parser.add_argument('--amp-dtype', default='float16', type=str,
105112
help='lower precision AMP dtype (default: float16)')
106-
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
107-
help='use ema version of weights if present')
108113
parser.add_argument('--fuser', default='', type=str,
109114
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
110115
parser.add_argument('--dynamo-backend', default=None, type=str,
@@ -122,21 +127,23 @@
122127
help='folder for output results')
123128
parser.add_argument('--results-file', type=str, default=None,
124129
help='results filename (relative to results-dir)')
125-
parser.add_argument('--results-format', type=str, default='csv',
130+
parser.add_argument('--results-format', type=str, nargs='+', default=['csv'],
126131
help='results format (one of "csv", "json", "json-split", "parquet")')
132+
parser.add_argument('--results-separate-col', action='store_true', default=False,
133+
help='separate output columns per result index.')
127134
parser.add_argument('--topk', default=1, type=int,
128135
metavar='N', help='Top-k to output to CSV')
129136
parser.add_argument('--fullname', action='store_true', default=False,
130137
help='use full sample name in output (not just basename).')
131-
parser.add_argument('--indices-name', default='index',
138+
parser.add_argument('--filename-col', default='filename',
139+
help='name for filename / sample name column')
140+
parser.add_argument('--index-col', default='index',
132141
help='name for output indices column(s)')
133-
parser.add_argument('--outputs-name', default=None,
142+
parser.add_argument('--output-col', default=None,
134143
help='name for logit/probs output column(s)')
135-
parser.add_argument('--outputs-type', default='prob',
144+
parser.add_argument('--output-type', default='prob',
136145
help='output type colum ("prob" for probabilities, "logit" for raw logits)')
137-
parser.add_argument('--separate-columns', action='store_true', default=False,
138-
help='separate output columns per result index.')
139-
parser.add_argument('--exclude-outputs', action='store_true', default=False,
146+
parser.add_argument('--exclude-output', action='store_true', default=False,
140147
help='exclude logits/probs from results, just indices. topk must be set !=0.')
141148

142149

@@ -179,9 +186,6 @@ def main():
179186
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
180187
args.num_classes = model.num_classes
181188

182-
if args.checkpoint:
183-
load_checkpoint(model, args.checkpoint, args.use_ema)
184-
185189
_logger.info(
186190
f'Model {args.model} created, param count: {sum([m.numel() for m in model.parameters()])}')
187191

@@ -221,11 +225,12 @@ def main():
221225
if test_time_pool:
222226
data_config['crop_pct'] = 1.0
223227

228+
workers = 1 if 'tfds' in args.dataset or 'wds' in args.dataset else args.workers
224229
loader = create_loader(
225230
dataset,
226231
batch_size=args.batch_size,
227232
use_prefetcher=True,
228-
num_workers=args.workers,
233+
num_workers=workers,
229234
**data_config,
230235
)
231236

@@ -234,7 +239,7 @@ def main():
234239
end = time.time()
235240
all_indices = []
236241
all_outputs = []
237-
use_probs = args.outputs_type == 'prob'
242+
use_probs = args.output_type == 'prob'
238243
with torch.no_grad():
239244
for batch_idx, (input, _) in enumerate(loader):
240245

@@ -262,52 +267,53 @@ def main():
262267
all_outputs = np.concatenate(all_outputs, axis=0).astype(np.float32)
263268
filenames = loader.dataset.filenames(basename=not args.fullname)
264269

265-
outputs_name = args.outputs_name or ('prob' if use_probs else 'logit')
266-
data_dict = {'filename': filenames}
267-
if args.separate_columns and all_outputs.shape[-1] > 1:
270+
output_col = args.output_col or ('prob' if use_probs else 'logit')
271+
data_dict = {args.filename_col: filenames}
272+
if args.results_separate_col and all_outputs.shape[-1] > 1:
268273
if all_indices is not None:
269274
for i in range(all_indices.shape[-1]):
270-
data_dict[f'{args.indices_name}_{i}'] = all_indices[:, i]
275+
data_dict[f'{args.index_col}_{i}'] = all_indices[:, i]
271276
for i in range(all_outputs.shape[-1]):
272-
data_dict[f'{outputs_name}_{i}'] = all_outputs[:, i]
277+
data_dict[f'{output_col}_{i}'] = all_outputs[:, i]
273278
else:
274279
if all_indices is not None:
275280
if all_indices.shape[-1] == 1:
276281
all_indices = all_indices.squeeze(-1)
277-
data_dict[args.indices_name] = list(all_indices)
282+
data_dict[args.index_col] = list(all_indices)
278283
if all_outputs.shape[-1] == 1:
279284
all_outputs = all_outputs.squeeze(-1)
280-
data_dict[outputs_name] = list(all_outputs)
285+
data_dict[output_col] = list(all_outputs)
281286

282287
df = pd.DataFrame(data=data_dict)
283288

284289
results_filename = args.results_file
285-
needs_ext = False
286-
if not results_filename:
290+
if results_filename:
291+
filename_no_ext, ext = os.path.splitext(results_filename)[-1]
292+
if ext and ext in _FMT_EXT.values():
293+
# if filename provided with one of expected ext,
294+
# remove it as it will be added back
295+
results_filename = filename_no_ext
296+
else:
287297
# base default filename on model name + img-size
288298
img_size = data_config["input_size"][1]
289299
results_filename = f'{args.model}-{img_size}'
290-
needs_ext = True
291300

292301
if args.results_dir:
293302
results_filename = os.path.join(args.results_dir, results_filename)
294303

295-
if args.results_format == 'parquet':
296-
if needs_ext:
297-
results_filename += '.parquet'
298-
df = df.set_index('filename')
299-
df.to_parquet(results_filename)
300-
elif args.results_format == 'json':
301-
if needs_ext:
302-
results_filename += '.json'
304+
for fmt in args.results_format:
305+
save_results(df, results_filename, fmt)
306+
307+
308+
def save_results(df, results_filename, results_format='csv', filename_col='filename'):
    """Write the results DataFrame to disk in one output format.

    The extension for ``results_format`` is looked up in ``_FMT_EXT`` and
    appended to ``results_filename``, so callers pass the path WITHOUT an
    extension.

    Args:
        df: DataFrame of results to write.
        results_filename: Output path without a file extension.
        results_format: One of 'csv', 'json' (one record per line),
            'json-split' (indented, ``orient='split'``) or 'parquet'.
            Any other value is written as CSV.
        filename_col: Column of ``df`` used as the index for parquet output.

    Raises:
        KeyError: If ``results_format`` is 'parquet' and ``filename_col`` is
            not a column of ``df``.
    """
    if results_format not in _FMT_EXT:
        # A plain `_FMT_EXT[results_format]` lookup would raise KeyError here and
        # make the CSV fallback branch below unreachable, discarding results that
        # were already computed. Fall back to CSV explicitly and say so.
        _logger.warning(f'Unknown results format "{results_format}", writing CSV instead.')
    results_filename += _FMT_EXT.get(results_format, '.csv')

    if results_format == 'parquet':
        # The sample name column becomes the index of the parquet table.
        df.set_index(filename_col).to_parquet(results_filename)
    elif results_format == 'json':
        df.to_json(results_filename, lines=True, orient='records')
    elif results_format == 'json-split':
        df.to_json(results_filename, indent=4, orient='split', index=False)
    else:
        df.to_csv(results_filename, index=False)
312318

313319

0 commit comments

Comments
 (0)