Skip to content

Commit 77543a0

Browse files
committed
chore: Linter fixes
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 2186177 commit 77543a0

File tree

3 files changed

+147
-30
lines changed

3 files changed

+147
-30
lines changed

tools/perf/hub.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,24 @@
1414

1515
# Detect case of no GPU before deserialization of models on GPU
1616
if not torch.cuda.is_available():
17-
raise Exception("No GPU found. Please check if installed torch version is compatible with CUDA version")
17+
raise Exception(
18+
"No GPU found. Please check if installed torch version is compatible with CUDA version"
19+
)
1820

1921
# Downloads all model files again if manifest file is not present
2022
MANIFEST_FILE = "model_manifest.json"
2123

2224
BENCHMARK_MODELS = {
2325
"vgg16": {"model": models.vgg16(weights=None), "path": "script"},
2426
"resnet50": {"model": models.resnet50(weights=None), "path": "script"},
25-
"efficientnet_b0": {"model": timm.create_model("efficientnet_b0", pretrained=True), "path": "script"},
26-
"vit": {"model": timm.create_model("vit_base_patch16_224", pretrained=True), "path": "script"},
27+
"efficientnet_b0": {
28+
"model": timm.create_model("efficientnet_b0", pretrained=True),
29+
"path": "script",
30+
},
31+
"vit": {
32+
"model": timm.create_model("vit_base_patch16_224", pretrained=True),
33+
"path": "script",
34+
},
2735
"bert_base_uncased": {"model": cm.BertModule(), "path": "trace"},
2836
}
2937

@@ -66,7 +74,11 @@ def download_models(version_matches, manifest):
6674
traced_filename = "models/" + n + "_traced.jit.pt"
6775
# Check if model file exists on disk
6876
if (
69-
(m["path"] == "both" and os.path.exists(scripted_filename) and os.path.exists(traced_filename))
77+
(
78+
m["path"] == "both"
79+
and os.path.exists(scripted_filename)
80+
and os.path.exists(traced_filename)
81+
)
7082
or (m["path"] == "script" and os.path.exists(scripted_filename))
7183
or (m["path"] == "trace" and os.path.exists(traced_filename))
7284
):

tools/perf/perf_run.py

Lines changed: 114 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,13 @@
1919
from torch_tensorrt.fx.utils import LowerPrecision
2020

2121
import tensorrt as trt
22-
from utils import parse_inputs, parse_backends, precision_to_dtype, parse_precisions, BENCHMARK_MODELS
22+
from utils import (
23+
parse_inputs,
24+
parse_backends,
25+
precision_to_dtype,
26+
parse_precisions,
27+
BENCHMARK_MODELS,
28+
)
2329

2430
WARMUP_ITER = 10
2531
results = []
@@ -45,7 +51,8 @@ def get(self, key, default_value=None):
4551
if not key in self.params:
4652
if not default_value:
4753
raise ValueError(
48-
"Key {} is not present and default_value is not configured. Please run it with default value", key
54+
"Key {} is not present and default_value is not configured. Please run it with default value",
55+
key,
4956
)
5057
self.params[key] = default_value
5158
return self.params[key]
@@ -77,8 +84,15 @@ def run_torch(model, input_tensors, params, precision, batch_size):
7784

7885

7986
# Runs inference using Torch-TensorRT backend
80-
def run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, batch_size):
81-
print("Running Torch-TensorRT for precision: ", precision, " batch_size : ", batch_size)
87+
def run_torch_tensorrt(
88+
model, input_tensors, params, precision, truncate_long_and_double, batch_size
89+
):
90+
print(
91+
"Running Torch-TensorRT for precision: ",
92+
precision,
93+
" batch_size : ",
94+
batch_size,
95+
)
8296
# Compiling Torch-TensorRT model
8397
compile_settings = {
8498
"inputs": input_tensors,
@@ -176,7 +190,13 @@ def torch_device_from_trt(device):
176190

177191

178192
def run_tensorrt(
179-
model, input_tensors, params, precision, truncate_long_and_double=False, is_trt_engine=False, batch_size=1
193+
model,
194+
input_tensors,
195+
params,
196+
precision,
197+
truncate_long_and_double=False,
198+
is_trt_engine=False,
199+
batch_size=1,
180200
):
181201
engine = None
182202

@@ -237,7 +257,14 @@ def run_tensorrt(
237257

238258
# Deploys inference run for different backend configurations
239259
def run(
240-
model, backends, input_tensors, params, precision, truncate_long_and_double=False, batch_size=1, is_trt_engine=False
260+
model,
261+
backends,
262+
input_tensors,
263+
params,
264+
precision,
265+
truncate_long_and_double=False,
266+
batch_size=1,
267+
is_trt_engine=False,
241268
):
242269
for backend in backends:
243270
if precision == "int8":
@@ -257,20 +284,50 @@ def run(
257284

258285
if backend == "all":
259286
run_torch(model, input_tensors, params, precision, batch_size)
260-
run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, batch_size)
261-
run_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, is_trt_engine, batch_size)
287+
run_torch_tensorrt(
288+
model,
289+
input_tensors,
290+
params,
291+
precision,
292+
truncate_long_and_double,
293+
batch_size,
294+
)
295+
run_tensorrt(
296+
model,
297+
input_tensors,
298+
params,
299+
precision,
300+
truncate_long_and_double,
301+
is_trt_engine,
302+
batch_size,
303+
)
262304

263305
elif backend == "torch":
264306
run_torch(model, input_tensors, params, precision, batch_size)
265307

266308
elif backend == "torch_tensorrt":
267-
run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, batch_size)
309+
run_torch_tensorrt(
310+
model,
311+
input_tensors,
312+
params,
313+
precision,
314+
truncate_long_and_double,
315+
batch_size,
316+
)
268317

269318
elif backend == "fx2trt":
270319
run_fx2trt(model, input_tensors, params, precision, batch_size)
271320

272321
elif backend == "tensorrt":
273-
run_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, is_trt_engine, batch_size)
322+
run_tensorrt(
323+
model,
324+
input_tensors,
325+
params,
326+
precision,
327+
truncate_long_and_double,
328+
is_trt_engine,
329+
batch_size,
330+
)
274331

275332

276333
# Generate report
@@ -291,8 +348,8 @@ def recordStats(backend, timings, precision, batch_size=1):
291348
"Batch size": batch_size,
292349
"Median(FPS)": speed_med,
293350
"Mean(FPS)": speed_mean,
294-
"Median-Latency(ms)": time_med*1000,
295-
"Mean-Latency(ms)": time_mean*1000,
351+
"Median-Latency(ms)": time_med * 1000,
352+
"Mean-Latency(ms)": time_mean * 1000,
296353
}
297354
results.append(stats)
298355

@@ -330,32 +387,44 @@ def load_model(params):
330387
)
331388
# The following options are manual user provided settings
332389
arg_parser.add_argument(
333-
"--backends", type=str, help="Comma separated string of backends. Eg: torch,torch_tensorrt,fx2trt,tensorrt"
390+
"--backends",
391+
type=str,
392+
help="Comma separated string of backends. Eg: torch,torch_tensorrt,fx2trt,tensorrt",
334393
)
335394
arg_parser.add_argument("--model", type=str, help="Name of the model file")
336395
arg_parser.add_argument(
337396
"--inputs",
338397
type=str,
339398
help="List of input shapes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT",
340399
)
341-
arg_parser.add_argument("--batch_size", type=int, default=1, help="Batch size to build and run")
400+
arg_parser.add_argument(
401+
"--batch_size", type=int, default=1, help="Batch size to build and run"
402+
)
342403
arg_parser.add_argument(
343404
"--precision",
344405
default="fp32",
345406
type=str,
346407
help="Comma separated list of precisions to build TensorRT engine Eg: fp32,fp16",
347408
)
348-
arg_parser.add_argument("--calibration_cache", type=str, help="Name of the calibration cache file")
409+
arg_parser.add_argument(
410+
"--calibration_cache", type=str, help="Name of the calibration cache file"
411+
)
349412
arg_parser.add_argument("--device", type=int, help="device id")
350413
arg_parser.add_argument(
351-
"--truncate", action="store_true", help="Truncate long and double weights in the network in Torch-TensorRT"
414+
"--truncate",
415+
action="store_true",
416+
help="Truncate long and double weights in the network in Torch-TensorRT",
352417
)
353418
arg_parser.add_argument(
354419
"--is_trt_engine",
355420
action="store_true",
356421
help="Boolean flag to determine if the user provided model is a TRT engine or not",
357422
)
358-
arg_parser.add_argument("--report", type=str, help="Path of the output file where performance summary is written.")
423+
arg_parser.add_argument(
424+
"--report",
425+
type=str,
426+
help="Path of the output file where performance summary is written.",
427+
)
359428
args = arg_parser.parse_args()
360429

361430
cudnn.benchmark = True
@@ -372,15 +441,22 @@ def load_model(params):
372441
torch.cuda.set_device(params.get("runtime").get("device", 0))
373442

374443
num_input = params.get("input").get("num_inputs")
375-
truncate_long_and_double = params.get("runtime").get("truncate_long_and_double", False)
444+
truncate_long_and_double = params.get("runtime").get(
445+
"truncate_long_and_double", False
446+
)
376447
batch_size = params.get("input").get("batch_size", 1)
377448
for precision in params.get("runtime").get("precision", "fp32"):
378449
input_tensors = []
379450
num_input = params.get("input").get("num_inputs", 1)
380451
for i in range(num_input):
381452
inp_tensor = params.get("input").get("input" + str(i))
382453
input_tensors.append(
383-
torch.randint(0, 2, tuple(d for d in inp_tensor), dtype=precision_to_dtype(precision)).cuda()
454+
torch.randint(
455+
0,
456+
2,
457+
tuple(d for d in inp_tensor),
458+
dtype=precision_to_dtype(precision),
459+
).cuda()
384460
)
385461

386462
if is_trt_engine:
@@ -395,7 +471,14 @@ def load_model(params):
395471
backends = params.get("backend")
396472
# Run inference
397473
status = run(
398-
model, backends, input_tensors, params, precision, truncate_long_and_double, batch_size, is_trt_engine
474+
model,
475+
backends,
476+
input_tensors,
477+
params,
478+
precision,
479+
truncate_long_and_double,
480+
batch_size,
481+
is_trt_engine,
399482
)
400483
else:
401484
params = vars(args)
@@ -417,12 +500,21 @@ def load_model(params):
417500
precisions = parse_precisions(params["precision"])
418501

419502
for precision in precisions:
420-
input_tensors = parse_inputs(params["inputs"], precision_to_dtype(precision))
503+
input_tensors = parse_inputs(
504+
params["inputs"], precision_to_dtype(precision)
505+
)
421506
if not is_trt_engine and (precision == "fp16" or precision == "half"):
422507
# If model is TensorRT serialized engine then model.half will report failure
423508
model = model.half()
424509
status = run(
425-
model, backends, input_tensors, params, precision, truncate_long_and_double, batch_size, is_trt_engine
510+
model,
511+
backends,
512+
input_tensors,
513+
params,
514+
precision,
515+
truncate_long_and_double,
516+
batch_size,
517+
is_trt_engine,
426518
)
427519

428520
# Generate report

tools/perf/utils.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,18 @@
66

77
BENCHMARK_MODELS = {
88
"vgg16": {"model": models.vgg16(pretrained=True), "path": "script"},
9-
"resnet50": {"model": torch.hub.load("pytorch/vision:v0.9.0", "resnet50", pretrained=True), "path": "script"},
10-
"efficientnet_b0": {"model": timm.create_model("efficientnet_b0", pretrained=True), "path": "script"},
11-
"vit": {"model": timm.create_model("vit_base_patch16_224", pretrained=True), "path": "script"},
9+
"resnet50": {
10+
"model": torch.hub.load("pytorch/vision:v0.9.0", "resnet50", pretrained=True),
11+
"path": "script",
12+
},
13+
"efficientnet_b0": {
14+
"model": timm.create_model("efficientnet_b0", pretrained=True),
15+
"path": "script",
16+
},
17+
"vit": {
18+
"model": timm.create_model("vit_base_patch16_224", pretrained=True),
19+
"path": "script",
20+
},
1221
"bert_base_uncased": {"model": cm.BertModule(), "path": "trace"},
1322
}
1423

@@ -32,7 +41,11 @@ def parse_inputs(user_inputs, dtype):
3241
for input in parsed_inputs:
3342
input_shape = []
3443
input_shape_and_dtype = input.split("@")
35-
dtype = precision_to_dtype(input_shape_and_dtype[1]) if len(input_shape_and_dtype) == 2 else dtype
44+
dtype = (
45+
precision_to_dtype(input_shape_and_dtype[1])
46+
if len(input_shape_and_dtype) == 2
47+
else dtype
48+
)
3649
for input_dim in input_shape_and_dtype[0][1:-1].split(","):
3750
input_shape.append(int(input_dim))
3851
torchtrt_inputs.append(torch.randint(0, 5, input_shape, dtype=dtype).cuda())

0 commit comments

Comments
 (0)