from torch_tensorrt.fx.utils import LowerPrecision

import tensorrt as trt
-from utils import parse_inputs, parse_backends, precision_to_dtype, parse_precisions, BENCHMARK_MODELS
+from utils import (
+    parse_inputs,
+    parse_backends,
+    precision_to_dtype,
+    parse_precisions,
+    BENCHMARK_MODELS,
+)

WARMUP_ITER = 10
results = []
@@ -45,7 +51,8 @@ def get(self, key, default_value=None):
        if not key in self.params:
            if not default_value:
                raise ValueError(
-                    "Key {} is not present and default_value is not configured. Please run it with default value", key
+                    "Key {} is not present and default_value is not configured. Please run it with default value",
+                    key,
                )
            self.params[key] = default_value
        return self.params[key]
@@ -77,8 +84,15 @@ def run_torch(model, input_tensors, params, precision, batch_size):


# Runs inference using Torch-TensorRT backend
-def run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, batch_size):
-    print("Running Torch-TensorRT for precision: ", precision, " batch_size : ", batch_size)
+def run_torch_tensorrt(
+    model, input_tensors, params, precision, truncate_long_and_double, batch_size
+):
+    print(
+        "Running Torch-TensorRT for precision: ",
+        precision,
+        " batch_size : ",
+        batch_size,
+    )
    # Compiling Torch-TensorRT model
    compile_settings = {
        "inputs": input_tensors,
@@ -176,7 +190,13 @@ def torch_device_from_trt(device):


def run_tensorrt(
-    model, input_tensors, params, precision, truncate_long_and_double=False, is_trt_engine=False, batch_size=1
+    model,
+    input_tensors,
+    params,
+    precision,
+    truncate_long_and_double=False,
+    is_trt_engine=False,
+    batch_size=1,
):
    engine = None

@@ -237,7 +257,14 @@ def run_tensorrt(

# Deploys inference run for different backend configurations
def run(
-    model, backends, input_tensors, params, precision, truncate_long_and_double=False, batch_size=1, is_trt_engine=False
+    model,
+    backends,
+    input_tensors,
+    params,
+    precision,
+    truncate_long_and_double=False,
+    batch_size=1,
+    is_trt_engine=False,
):
    for backend in backends:
        if precision == "int8":
@@ -257,20 +284,50 @@ def run(

        if backend == "all":
            run_torch(model, input_tensors, params, precision, batch_size)
-            run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, batch_size)
-            run_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, is_trt_engine, batch_size)
+            run_torch_tensorrt(
+                model,
+                input_tensors,
+                params,
+                precision,
+                truncate_long_and_double,
+                batch_size,
+            )
+            run_tensorrt(
+                model,
+                input_tensors,
+                params,
+                precision,
+                truncate_long_and_double,
+                is_trt_engine,
+                batch_size,
+            )

        elif backend == "torch":
            run_torch(model, input_tensors, params, precision, batch_size)

        elif backend == "torch_tensorrt":
-            run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, batch_size)
+            run_torch_tensorrt(
+                model,
+                input_tensors,
+                params,
+                precision,
+                truncate_long_and_double,
+                batch_size,
+            )

        elif backend == "fx2trt":
            run_fx2trt(model, input_tensors, params, precision, batch_size)

        elif backend == "tensorrt":
-            run_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, is_trt_engine, batch_size)
+            run_tensorrt(
+                model,
+                input_tensors,
+                params,
+                precision,
+                truncate_long_and_double,
+                is_trt_engine,
+                batch_size,
+            )


# Generate report
@@ -291,8 +348,8 @@ def recordStats(backend, timings, precision, batch_size=1):
        "Batch size": batch_size,
        "Median(FPS)": speed_med,
        "Mean(FPS)": speed_mean,
-        "Median-Latency(ms)" : time_med * 1000,
-        "Mean-Latency(ms)" : time_mean * 1000,
+        "Median-Latency(ms)": time_med * 1000,
+        "Mean-Latency(ms)": time_mean * 1000,
    }
    results.append(stats)

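
For reference (not part of the diff), the FPS and latency fields recorded above are simple arithmetic over the per-iteration timings, collected in seconds. A minimal sketch, with made-up timing values:

    import numpy as np

    timings = np.array([0.0102, 0.0098, 0.0101])  # seconds per batch (illustrative)
    batch_size = 8

    time_med = np.median(timings)
    time_mean = np.mean(timings)
    speeds = batch_size / timings      # samples per second for each iteration
    speed_med = np.median(speeds)
    speed_mean = np.mean(speeds)

    print(f"Median-Latency(ms): {time_med * 1000:.2f}, Median(FPS): {speed_med:.1f}")
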
@@ -330,32 +387,44 @@ def load_model(params):
    )
    # The following options are manual user provided settings
    arg_parser.add_argument(
-        "--backends", type=str, help="Comma separated string of backends. Eg: torch,torch_tensorrt,fx2trt,tensorrt"
+        "--backends",
+        type=str,
+        help="Comma separated string of backends. Eg: torch,torch_tensorrt,fx2trt,tensorrt",
    )
    arg_parser.add_argument("--model", type=str, help="Name of the model file")
    arg_parser.add_argument(
        "--inputs",
        type=str,
        help="List of input shapes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT",
    )
-    arg_parser.add_argument("--batch_size", type=int, default=1, help="Batch size to build and run")
+    arg_parser.add_argument(
+        "--batch_size", type=int, default=1, help="Batch size to build and run"
+    )
    arg_parser.add_argument(
        "--precision",
        default="fp32",
        type=str,
        help="Comma separated list of precisions to build TensorRT engine Eg: fp32,fp16",
    )
-    arg_parser.add_argument("--calibration_cache", type=str, help="Name of the calibration cache file")
+    arg_parser.add_argument(
+        "--calibration_cache", type=str, help="Name of the calibration cache file"
+    )
    arg_parser.add_argument("--device", type=int, help="device id")
    arg_parser.add_argument(
-        "--truncate", action="store_true", help="Truncate long and double weights in the network in Torch-TensorRT"
+        "--truncate",
+        action="store_true",
+        help="Truncate long and double weights in the network in Torch-TensorRT",
    )
    arg_parser.add_argument(
        "--is_trt_engine",
        action="store_true",
        help="Boolean flag to determine if the user provided model is a TRT engine or not",
    )
-    arg_parser.add_argument("--report", type=str, help="Path of the output file where performance summary is written.")
+    arg_parser.add_argument(
+        "--report",
+        type=str,
+        help="Path of the output file where performance summary is written.",
+    )
    args = arg_parser.parse_args()

    cudnn.benchmark = True
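
A quick way to sanity-check the reflowed parser (argument values here are illustrative, not taken from the commit) is to feed it a sample command line:

    # Assumes the arg_parser constructed above; flag values are made up.
    args = arg_parser.parse_args(
        [
            "--backends", "torch,torch_tensorrt,tensorrt",
            "--model", "resnet50.jit.pt",
            "--inputs", "(1, 3, 224, 224)@fp32",
            "--batch_size", "8",
            "--precision", "fp32,fp16",
            "--truncate",
            "--report", "perf.txt",
        ]
    )
    assert args.batch_size == 8 and args.truncate
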
@@ -372,15 +441,22 @@ def load_model(params):
        torch.cuda.set_device(params.get("runtime").get("device", 0))

        num_input = params.get("input").get("num_inputs")
-        truncate_long_and_double = params.get("runtime").get("truncate_long_and_double", False)
+        truncate_long_and_double = params.get("runtime").get(
+            "truncate_long_and_double", False
+        )
        batch_size = params.get("input").get("batch_size", 1)
        for precision in params.get("runtime").get("precision", "fp32"):
            input_tensors = []
            num_input = params.get("input").get("num_inputs", 1)
            for i in range(num_input):
                inp_tensor = params.get("input").get("input" + str(i))
                input_tensors.append(
-                    torch.randint(0, 2, tuple(d for d in inp_tensor), dtype=precision_to_dtype(precision)).cuda()
+                    torch.randint(
+                        0,
+                        2,
+                        tuple(d for d in inp_tensor),
+                        dtype=precision_to_dtype(precision),
+                    ).cuda()
                )

            if is_trt_engine:
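
As an aside (not part of this commit), the inner loop above builds one random tensor per configured input; the same call, unrolled with a concrete shape and dtype for clarity (values are illustrative):

    import torch

    inp_tensor = [1, 3, 224, 224]  # shape list as it would come from the YAML config
    dtype = torch.float32          # what precision_to_dtype("fp32") is expected to return

    # torch.randint accepts float dtypes here, yielding 0.0/1.0-valued tensors
    x = torch.randint(0, 2, tuple(d for d in inp_tensor), dtype=dtype).cuda()
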
@@ -395,7 +471,14 @@ def load_model(params):
            backends = params.get("backend")
            # Run inference
            status = run(
-                model, backends, input_tensors, params, precision, truncate_long_and_double, batch_size, is_trt_engine
+                model,
+                backends,
+                input_tensors,
+                params,
+                precision,
+                truncate_long_and_double,
+                batch_size,
+                is_trt_engine,
            )
    else:
        params = vars(args)
@@ -417,12 +500,21 @@ def load_model(params):
        precisions = parse_precisions(params["precision"])

        for precision in precisions:
-            input_tensors = parse_inputs(params["inputs"], precision_to_dtype(precision))
+            input_tensors = parse_inputs(
+                params["inputs"], precision_to_dtype(precision)
+            )
            if not is_trt_engine and (precision == "fp16" or precision == "half"):
                # If model is TensorRT serialized engine then model.half will report failure
                model = model.half()
            status = run(
-                model, backends, input_tensors, params, precision, truncate_long_and_double, batch_size, is_trt_engine
+                model,
+                backends,
+                input_tensors,
+                params,
+                precision,
+                truncate_long_and_double,
+                batch_size,
+                is_trt_engine,
            )

    # Generate report