
Commit 5bf2f4a

chore: Apply precommit tooling

Signed-off-by: Dheeraj Peri <[email protected]>
Parent: 66573fc

5 files changed (+278, -253 lines)


tools/perf/README.md

Lines changed: 129 additions & 129 deletions
# Performance Benchmarking

This is a comprehensive Python benchmark suite for running performance measurements with different supported backends. The following backends are supported:

1. Torch
2. Torch-TensorRT
3. FX-TRT
4. TensorRT

Note: For ONNX models, you can convert the ONNX model to a TensorRT serialized engine and then use this package.
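The conversion itself happens outside this suite. As a minimal sketch (assuming TensorRT 8.x is installed; `model.onnx` is a placeholder filename), the TensorRT Python API can serialize an ONNX model to a `.plan` engine:

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
# ONNX parsing requires an explicit-batch network
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)

with open("model.onnx", "rb") as f:  # placeholder ONNX file
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError("Failed to parse ONNX model")

config = builder.create_builder_config()
engine = builder.build_serialized_network(network, config)

# perf_run.py treats files with a .plan extension as TensorRT engines
with open("model.plan", "wb") as f:
    f.write(engine)
```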
## Prerequisite

The benchmark scripts depend on the following Python packages, in addition to those listed in requirements.txt:

1. Torch-TensorRT
2. Torch
3. TensorRT
## Structure

```
./
├── config
│   ├── vgg16_trt.yml
│   └── vgg16.yml
├── models
├── perf_run.py
├── hub.py
├── custom_models.py
├── requirements.txt
├── benchmark.sh
└── README.md
```

* `config` - Directory containing sample YAML configuration files for the VGG network.
* `models` - Model directory
* `perf_run.py` - Performance benchmarking script supporting the torch, torch_tensorrt, fx2trt, and tensorrt backends
* `hub.py` - Script to download TorchScript models for VGG16, Resnet50, EfficientNet-B0, VIT, and HF-BERT
* `custom_models.py` - Script containing custom models beyond torchvision and timm (e.g. HF BERT)
* `utils.py` - Utility functions
* `benchmark.sh` - Used for internal performance testing of VGG16, Resnet50, EfficientNet-B0, VIT, and HF-BERT.
## Usage

There are two ways to run a performance benchmark.

### Using YAML config files

To run the benchmark for a given configuration file:

```sh
python perf_run.py --config=config/vgg16.yml
```

Two sample configuration files are provided:

* vgg16.yml demonstrates a configuration with all the supported backends (Torch, Torch-TensorRT, TensorRT)
* vgg16_trt.yml demonstrates how to use an external TensorRT serialized engine file directly.
### Supported fields

| Name | Supported Values | Description |
| ----------------- | ------------------------------------ | ------------------------------------------------------------ |
| backend | all, torch, torch_tensorrt, tensorrt | Backend(s) to use for inference. |
| input | - | Input binding names. Expected to list the shape of each input binding |
| model | - | Configures the model filename and name |
| filename | - | Model file name to load from disk. |
| name | - | Model name |
| runtime | - | Runtime configuration |
| device | 0 | Target device ID for inference. Range depends on available GPUs |
| precision | fp32, fp16 or half, int8 | Target precision(s) for inference. int8 cannot be used with the 'all' backend |
| calibration_cache | - | Calibration cache file, required by the torch_tensorrt runtime for int8 precision |
Additional sample use case:

```yaml
backend:
  - torch
  - torch_tensorrt
  - tensorrt
input:
  input0:
    - 3
    - 224
    - 224
  num_inputs: 1
model:
  filename: model.plan
  name: vgg16
runtime:
  device: 0
  precision:
    - fp32
    - fp16
```

Note:

1. Measuring INT8 performance is only supported via a `calibration cache` file or QAT mode for the `torch_tensorrt` backend.
2. A TensorRT engine filename should end with `.plan`; otherwise the file is treated as a TorchScript module.
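To sanity-check a configuration before a run, the file can be loaded directly. A minimal sketch (assuming PyYAML is installed and the nesting shown in the sample above):

```python
import yaml  # PyYAML, assumed installed

with open("config/vgg16.yml") as f:
    cfg = yaml.safe_load(f)

print(cfg["backend"])               # e.g. ['torch', 'torch_tensorrt', 'tensorrt']
print(cfg["input"]["input0"])       # e.g. [3, 224, 224]
print(cfg["runtime"]["precision"])  # e.g. ['fp32', 'fp16']
```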
### Using CompileSpec options via CLI

The following `CompileSpec` options can be provided directly on the command line to compile the PyTorch module:

* `--backends` : Comma-separated string of backends. Eg: torch,torch_tensorrt,tensorrt or fx2trt
* `--model` : Name of the model file (can be a TorchScript module or a TensorRT engine ending in the `.plan` extension). If the backend is `fx2trt`, the input should be a PyTorch module (instead of a TorchScript module) and the options for the model are (`vgg16` | `resnet50` | `efficientnet_b0`)
* `--inputs` : List of input shapes & dtypes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT
* `--batch_size` : Batch size
* `--precision` : Comma-separated list of precisions to build the TensorRT engine with. Eg: fp32,fp16
* `--device` : Device ID
* `--truncate` : Truncate long and double weights in the network (Torch-TensorRT)
* `--is_trt_engine` : Boolean flag to enable if the model file provided is a TensorRT engine.
* `--report` : Path of the output file where the performance summary is written.

Eg:
```sh
python perf_run.py --model ${MODELS_DIR}/vgg16_scripted.jit.pt \
  --precision fp32,fp16 --inputs="(1, 3, 224, 224)@fp32" \
  --batch_size 1 \
  --backends torch,torch_tensorrt,tensorrt \
  --report "vgg_perf_bs1.txt"
```
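For reference, the `(shape)@dtype` strings accepted by `--inputs` can be decomposed as in this illustrative sketch (a hypothetical parser for clarity, not the one `perf_run.py` actually uses):

```python
from typing import List, Tuple

def parse_inputs(spec: str) -> List[Tuple[Tuple[int, ...], str]]:
    """Split "(1, 3, 224, 224)@fp32;(1, 128)@int32" into (shape, dtype) pairs."""
    parsed = []
    for entry in spec.split(";"):
        shape_str, _, dtype = entry.partition("@")
        shape = tuple(int(dim) for dim in shape_str.strip("() ").split(","))
        # default to fp32 when no dtype suffix is given (an assumption here)
        parsed.append((shape, dtype or "fp32"))
    return parsed

print(parse_inputs("(1, 3, 224, 224)@fp32"))          # [((1, 3, 224, 224), 'fp32')]
print(parse_inputs("(1, 128)@int32;(1, 128)@int32"))  # two int32 inputs for BERT
```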

tools/perf/custom_models.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -3,6 +3,7 @@
 from transformers import BertModel, BertTokenizer, BertConfig
 import torch.nn.functional as F
 
+
 def BertModule():
     model_name = "bert-base-uncased"
     enc = BertTokenizer.from_pretrained(model_name)
```

tools/perf/hub.py

Lines changed: 24 additions & 33 deletions
```diff
@@ -17,36 +17,21 @@
     raise Exception("No GPU found. Please check if installed torch version is compatible with CUDA version")
 
 # Downloads all model files again if manifest file is not present
-MANIFEST_FILE = 'model_manifest.json'
+MANIFEST_FILE = "model_manifest.json"
 
 BENCHMARK_MODELS = {
-    "vgg16": {
-        "model": models.vgg16(weights=None),
-        "path": "script"
-    },
-    "resnet50": {
-        "model": models.resnet50(weights=None),
-        "path": "script"
-    },
-    "efficientnet_b0": {
-        "model": timm.create_model('efficientnet_b0', pretrained=True),
-        "path": "script"
-    },
-    "vit": {
-        "model": timm.create_model('vit_base_patch16_224', pretrained=True),
-        "path": "script"
-    },
-    "bert_base_uncased": {
-        "model": cm.BertModule(),
-        "path": "trace"
-    },
+    "vgg16": {"model": models.vgg16(weights=None), "path": "script"},
+    "resnet50": {"model": models.resnet50(weights=None), "path": "script"},
+    "efficientnet_b0": {"model": timm.create_model("efficientnet_b0", pretrained=True), "path": "script"},
+    "vit": {"model": timm.create_model("vit_base_patch16_224", pretrained=True), "path": "script"},
+    "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"},
 }
 
 
 def get(n, m, manifest):
     print("Downloading {}".format(n))
-    traced_filename = "models/" + n + '_traced.jit.pt'
-    script_filename = "models/" + n + '_scripted.jit.pt'
+    traced_filename = "models/" + n + "_traced.jit.pt"
+    script_filename = "models/" + n + "_scripted.jit.pt"
     x = torch.ones((1, 3, 300, 300)).cuda()
     if n == "bert-base-uncased":
         traced_model = m["model"]
@@ -80,9 +65,11 @@ def download_models(version_matches, manifest):
         scripted_filename = "models/" + n + "_scripted.jit.pt"
         traced_filename = "models/" + n + "_traced.jit.pt"
         # Check if model file exists on disk
-        if (m["path"] == "both" and os.path.exists(scripted_filename) and os.path.exists(traced_filename)) or \
-           (m["path"] == "script" and os.path.exists(scripted_filename)) or \
-           (m["path"] == "trace" and os.path.exists(traced_filename)):
+        if (
+            (m["path"] == "both" and os.path.exists(scripted_filename) and os.path.exists(traced_filename))
+            or (m["path"] == "script" and os.path.exists(scripted_filename))
+            or (m["path"] == "trace" and os.path.exists(traced_filename))
+        ):
             print("Skipping {} ".format(n))
             continue
         manifest = get(n, m, manifest)
@@ -98,27 +85,31 @@ def main():
         manifest = {"version": torch_version}
 
         # Creating an empty manifest file for overwriting post setup
-        os.system('touch {}'.format(MANIFEST_FILE))
+        os.system("touch {}".format(MANIFEST_FILE))
     else:
         manifest_exists = True
 
         # Load manifest if already exists
-        with open(MANIFEST_FILE, 'r') as f:
+        with open(MANIFEST_FILE, "r") as f:
             manifest = json.load(f)
-            if manifest['version'] == torch_version:
+            if manifest["version"] == torch_version:
                 version_matches = True
             else:
-                print("Torch version: {} mismatches \
+                print(
+                    "Torch version: {} mismatches \
 with manifest's version: {}. Re-downloading \
-all models".format(torch_version, manifest['version']))
+all models".format(
+                        torch_version, manifest["version"]
+                    )
+                )
 
     # Overwrite the manifest version as current torch version
-    manifest['version'] = torch_version
+    manifest["version"] = torch_version
 
     download_models(version_matches, manifest)
 
     # Write updated manifest file to disk
-    with open(MANIFEST_FILE, 'r+') as f:
+    with open(MANIFEST_FILE, "r+") as f:
         data = f.read()
         f.seek(0)
         record = json.dumps(manifest)
```
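For context, the `"path"` field in `BENCHMARK_MODELS` selects between TorchScript scripting and tracing when hub.py exports a model. A minimal sketch of the two export paths (paraphrasing what hub.py does; exact details may differ):

```python
import torch
import torchvision.models as models

model = models.vgg16(weights=None).eval().cuda()
x = torch.ones((1, 3, 300, 300)).cuda()  # example input, as in hub.py

# "path": "script" -> compile the model from its Python source
scripted = torch.jit.script(model)
torch.jit.save(scripted, "models/vgg16_scripted.jit.pt")

# "path": "trace" -> record the operations executed on the example input
traced = torch.jit.trace(model, [x])
torch.jit.save(traced, "models/vgg16_traced.jit.pt")
```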
