1
+ from distutils .command .clean import clean
1
2
import nox
2
3
import os
3
4
8
9
# TOP_DIR
9
10
TOP_DIR = os .path .dirname (os .path .realpath (__file__ )) if not 'TOP_DIR' in os .environ else os .environ ["TOP_DIR" ]
10
11
11
- nox .options .sessions = ["developer_tests -3" ]
12
+ nox .options .sessions = ["l0_api_tests -3" ]
12
13
13
14
def install_deps (session ):
14
15
print ("Installing deps" )
@@ -30,31 +31,6 @@ def install_torch_trt(session):
30
31
session .chdir (os .path .join (TOP_DIR , "py" ))
31
32
session .run ("python" , "setup.py" , "develop" )
32
33
33
- def run_base_tests (session , use_host_env = False ):
34
- print ("Running basic tests" )
35
- session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
36
- tests = [
37
- "test_api.py" ,
38
- "test_to_backend_api.py"
39
- ]
40
- for test in tests :
41
- if use_host_env :
42
- session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
43
- else :
44
- session .run_always ("python" , test )
45
-
46
-
47
- # Install the latest build of torch-tensorrt
48
- @nox .session (python = ["3" ], reuse_venv = True )
49
- def developer_tests (session ):
50
- """Basic set of tests that need to pass for code to get merged"""
51
- install_deps (session )
52
- install_torch_trt (session )
53
- download_models (session )
54
- run_base_tests (session )
55
-
56
- # Download the dataset
57
- @nox .session (python = ["3" ], reuse_venv = True )
58
34
def download_datasets (session ):
59
35
print ("Downloading dataset to path" , os .path .join (TOP_DIR , 'examples/int8/training/vgg16' ))
60
36
session .chdir (os .path .join (TOP_DIR , 'examples/int8/training/vgg16' ))
@@ -68,98 +44,70 @@ def download_datasets(session):
68
44
os .path .join (TOP_DIR , 'tests/accuracy/datasets/data/cidar-10-batches-bin' ),
69
45
external = True )
70
46
71
- # Download the model
72
- @nox .session (python = ["3" ], reuse_venv = True )
73
- def download_test_models (session ):
74
- download_models (session , use_host_env = True )
75
-
76
- # Train the model
77
- @nox .session (python = ["3" ], reuse_venv = True )
78
- def train_model (session ):
47
+ def train_model (session , use_host_env = False ):
79
48
session .chdir (os .path .join (TOP_DIR , 'examples/int8/training/vgg16' ))
80
- session . run_always ( 'python' ,
81
- 'main.py ' ,
82
- '--lr' , '0.01 ' ,
83
- '--batch-size ' , '128 ' ,
84
- '--drop-ratio ' , '0.15 ' ,
85
- '--ckpt-dir ' , 'vgg16_ckpts ' ,
86
- '--epochs ' , '25 ' ,
87
- env = { 'PYTHONPATH' : PYT_PATH })
88
-
89
- # Export model
90
- session .run_always ('python' ,
49
+ if use_host_env :
50
+ session . run_always ( 'python ' ,
51
+ 'main.py ' ,
52
+ '--lr ' , '0.01 ' ,
53
+ '--batch-size ' , '128 ' ,
54
+ '--drop-ratio ' , '0.15 ' ,
55
+ '--ckpt-dir ' , 'vgg16_ckpts ' ,
56
+ '--epochs' , '25' ,
57
+ env = { 'PYTHONPATH' : PYT_PATH })
58
+
59
+ session .run_always ('python' ,
91
60
'export_ckpt.py' ,
92
- 'vgg16_ckpts/ckpt_epoch25.pth' ,
93
- env = {'PYTHONPATH' : PYT_PATH })
61
+ 'vgg16_ckpts/ckpt_epoch25.pth' )
62
+ else :
63
+ session .run_always ('python' ,
64
+ 'main.py' ,
65
+ '--lr' , '0.01' ,
66
+ '--batch-size' , '128' ,
67
+ '--drop-ratio' , '0.15' ,
68
+ '--ckpt-dir' , 'vgg16_ckpts' ,
69
+ '--epochs' , '25' )
94
70
95
- # Finetune the model
96
- @nox .session (python = ["3" ], reuse_venv = True )
97
- def finetune_model (session ):
71
+ session .run_always ('python' ,
72
+ 'export_ckpt.py' ,
73
+ 'vgg16_ckpts/ckpt_epoch25.pth' )
74
+
75
+ def finetune_model (session , use_host_env = False ):
98
76
# Install pytorch-quantization dependency
99
77
session .install ('pytorch-quantization' , '--extra-index-url' , 'https://pypi.ngc.nvidia.com' )
100
-
101
78
session .chdir (os .path .join (TOP_DIR , 'examples/int8/training/vgg16' ))
102
- session .run_always ('python' ,
103
- 'finetune_qat.py' ,
104
- '--lr' , '0.01' ,
105
- '--batch-size' , '128' ,
106
- '--drop-ratio' , '0.15' ,
107
- '--ckpt-dir' , 'vgg16_ckpts' ,
108
- '--start-from' , '25' ,
109
- '--epochs' , '26' ,
110
- env = {'PYTHONPATH' : PYT_PATH })
111
-
112
- # Export model
113
- session .run_always ('python' ,
114
- 'export_qat.py' ,
115
- 'vgg16_ckpts/ckpt_epoch26.pth' ,
116
- env = {'PYTHONPATH' : PYT_PATH })
117
-
118
- # Run PTQ tests
119
- @nox .session (python = ["3" ], reuse_venv = True )
120
- def ptq_test (session ):
121
- session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
122
- session .run_always ('cp' , '-rf' ,
123
- os .path .join (TOP_DIR , 'examples/int8/training/vgg16' , 'trained_vgg16.jit.pt' ),
124
- '.' ,
125
- external = True )
126
- tests = [
127
- 'test_ptq_dataloader_calibrator.py' ,
128
- 'test_ptq_to_backend.py' ,
129
- 'test_ptq_trt_calibrator.py'
130
- ]
131
- for test in tests :
132
- session .run_always ('python' , test ,
133
- env = {'PYTHONPATH' : PYT_PATH })
134
79
135
- # Run QAT tests
136
- @nox .session (python = ["3" ], reuse_venv = True )
137
- def qat_test (session ):
138
- session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
139
- session .run_always ('cp' , '-rf' ,
140
- os .path .join (TOP_DIR , 'examples/int8/training/vgg16' , 'trained_vgg16_qat.jit.pt' ),
141
- '.' ,
142
- external = True )
143
-
144
- session .run_always ('python' ,
145
- 'test_qat_trt_accuracy.py' ,
146
- env = {'PYTHONPATH' : PYT_PATH })
80
+ if use_host_env :
81
+ session .run_always ('python' ,
82
+ 'finetune_qat.py' ,
83
+ '--lr' , '0.01' ,
84
+ '--batch-size' , '128' ,
85
+ '--drop-ratio' , '0.15' ,
86
+ '--ckpt-dir' , 'vgg16_ckpts' ,
87
+ '--start-from' , '25' ,
88
+ '--epochs' , '26' ,
89
+ env = {'PYTHONPATH' : PYT_PATH })
147
90
148
- # Run Python API tests
149
- @nox .session (python = ["3" ], reuse_venv = True )
150
- def api_test (session ):
151
- session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
152
- tests = [
153
- "test_api.py" ,
154
- "test_to_backend_api.py"
155
- ]
156
- for test in tests :
91
+ # Export model
157
92
session .run_always ('python' ,
158
- test ,
93
+ 'export_qat.py' ,
94
+ 'vgg16_ckpts/ckpt_epoch26.pth' ,
159
95
env = {'PYTHONPATH' : PYT_PATH })
96
+ else :
97
+ session .run_always ('python' ,
98
+ 'finetune_qat.py' ,
99
+ '--lr' , '0.01' ,
100
+ '--batch-size' , '128' ,
101
+ '--drop-ratio' , '0.15' ,
102
+ '--ckpt-dir' , 'vgg16_ckpts' ,
103
+ '--start-from' , '25' ,
104
+ '--epochs' , '26' )
105
+
106
+ # Export model
107
+ session .run_always ('python' ,
108
+ 'export_qat.py' ,
109
+ 'vgg16_ckpts/ckpt_epoch26.pth' )
160
110
161
- # Clean up
162
- @nox .session (reuse_venv = True )
163
111
def cleanup (session ):
164
112
target = [
165
113
'examples/int8/training/vgg16/*.jit.pt' ,
@@ -173,4 +121,186 @@ def cleanup(session):
173
121
target = ' ' .join (x for x in [os .path .join (TOP_DIR , i ) for i in target ])
174
122
session .run_always ('bash' , '-c' ,
175
123
str ('rm -rf ' ) + target ,
176
- external = True )
124
+ external = True )
125
+
126
+ def run_base_tests (session , use_host_env = False ):
127
+ print ("Running basic tests" )
128
+ session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
129
+ tests = [
130
+ "test_api.py" ,
131
+ "test_to_backend_api.py"
132
+ ]
133
+ for test in tests :
134
+ if use_host_env :
135
+ session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
136
+ else :
137
+ session .run_always ("python" , test )
138
+
139
+ def run_accuracy_tests (session , use_host_env = False ):
140
+ print ("Running accuracy tests" )
141
+ session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
142
+ tests = []
143
+ for test in tests :
144
+ if use_host_env :
145
+ session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
146
+ else :
147
+ session .run_always ("python" , test )
148
+
149
+ def run_int8_accuracy_tests (session , use_host_env = False ):
150
+ print ("Running accuracy tests" )
151
+ session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
152
+ tests = [
153
+ "test_ptq_dataloader.py" ,
154
+ "test_ptq_to_backend.py" ,
155
+ "test_qat_trt_accuracy" ,
156
+ ]
157
+ for test in tests :
158
+ if use_host_env :
159
+ session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
160
+ else :
161
+ session .run_always ("python" , test )
162
+
163
+ def run_trt_compatibility_tests (session , use_host_env = False ):
164
+ print ("Running TensorRT compatibility tests" )
165
+ session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
166
+ tests = [
167
+ "test_trt_intercompatibilty.py" ,
168
+ "test_ptq_trt_calibrator.py" ,
169
+ ]
170
+ for test in tests :
171
+ if use_host_env :
172
+ session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
173
+ else :
174
+ session .run_always ("python" , test )
175
+
176
+ def run_dla_tests (session , use_host_env = False ):
177
+ print ("Running DLA tests" )
178
+ session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
179
+ tests = [
180
+ "test_api_dla.py" ,
181
+ ]
182
+ for test in tests :
183
+ if use_host_env :
184
+ session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
185
+ else :
186
+ session .run_always ("python" , test )
187
+
188
+ def run_multi_gpu_tests (session , use_host_env = False ):
189
+ print ("Running multi GPU tests" )
190
+ session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
191
+ tests = [
192
+ "test_multi_gpu.py" ,
193
+ ]
194
+ for test in tests :
195
+ if use_host_env :
196
+ session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
197
+ else :
198
+ session .run_always ("python" , test )
199
+
200
+ def run_l0_api_tests (session , use_host_env = False ):
201
+ if not use_host_env :
202
+ install_deps (session )
203
+ install_torch_trt (session )
204
+ download_models (session , use_host_env )
205
+ run_base_tests (session , use_host_env )
206
+ cleanup (session )
207
+
208
+ def run_l0_dla_tests (session , use_host_env = False ):
209
+ if not use_host_env :
210
+ install_deps (session )
211
+ install_torch_trt (session )
212
+ download_models (session , use_host_env )
213
+ run_base_tests (session , use_host_env )
214
+ cleanup (session )
215
+
216
+ def run_l1_accuracy_tests (session , use_host_env = False ):
217
+ if not use_host_env :
218
+ install_deps (session )
219
+ install_torch_trt (session )
220
+ download_models (session , use_host_env )
221
+ download_datasets (session , use_host_env )
222
+ train_model (session , use_host_env )
223
+ run_accuracy_tests (session , use_host_env )
224
+ cleanup (session )
225
+
226
+ def run_l1_int8_accuracy_tests (session , use_host_env = False ):
227
+ if not use_host_env :
228
+ install_deps (session )
229
+ install_torch_trt (session )
230
+ download_models (session , use_host_env )
231
+ download_datasets (session , use_host_env )
232
+ train_model (session , use_host_env )
233
+ finetune_model (session , use_host_env )
234
+ run_int8_accuracy_tests (session , use_host_env )
235
+ cleanup (session )
236
+
237
+ def run_l2_trt_compatibility_tests (session , use_host_env = False ):
238
+ if not use_host_env :
239
+ install_deps (session )
240
+ install_torch_trt (session )
241
+ download_models (session , use_host_env )
242
+ run_trt_compatibility_tests (session , use_host_env )
243
+ cleanup (session )
244
+
245
+ def run_l2_multi_gpu_tests (session , use_host_env = False ):
246
+ if not use_host_env :
247
+ install_deps (session )
248
+ install_torch_trt (session )
249
+ download_models (session , use_host_env )
250
+ run_multi_gpu_tests (session , use_host_env )
251
+ cleanup (session )
252
+
253
+ @nox .session (python = ["3" ], reuse_venv = True )
254
+ def l0_api_tests (session ):
255
+ """When a developer needs to check correctness for a PR or something"""
256
+ run_l0_api_tests (session , use_host_env = False )
257
+
258
+ @nox .session (python = ["3" ], reuse_venv = True )
259
+ def l0_api_tests_host_deps (session ):
260
+ """When a developer needs to check basic api functionality using host dependencies"""
261
+ run_l0_api_tests (session , use_host_env = True )
262
+
263
+ @nox .session (python = ["3" ], reuse_venv = True )
264
+ def l0_dla_tests_host_deps (session ):
265
+ """When a developer needs to check basic api functionality using host dependencies"""
266
+ run_l0_dla_tests (session , use_host_env = True )
267
+
268
+ @nox .session (python = ["3" ], reuse_venv = True )
269
+ def l1_accuracy_tests (session ):
270
+ """Checking accuracy performance on various usecases"""
271
+ run_l1_accuracy_tests (session , use_host_env = False )
272
+
273
+ @nox .session (python = ["3" ], reuse_venv = True )
274
+ def l1_accuracy_tests_host_deps (session ):
275
+ """Checking accuracy performance on various usecases using host dependencies"""
276
+ run_l1_accuracy_tests (session , use_host_env = True )
277
+
278
+ @nox .session (python = ["3" ], reuse_venv = True )
279
+ def l1_int8_accuracy_tests (session ):
280
+ """Checking accuracy performance on various usecases"""
281
+ run_l1_int8_accuracy_tests (session , use_host_env = False )
282
+
283
+ @nox .session (python = ["3" ], reuse_venv = True )
284
+ def l1_int8_accuracy_tests_host_deps (session ):
285
+ """Checking accuracy performance on various usecases using host dependencies"""
286
+ run_l1_int8_accuracy_tests (session , use_host_env = True )
287
+
288
+ @nox .session (python = ["3" ], reuse_venv = True )
289
+ def l2_trt_compatibility_tests (session ):
290
+ """Makes sure that TensorRT Python and Torch-TensorRT can work together"""
291
+ run_l2_trt_compatibility_tests (session , use_host_env = False )
292
+
293
+ @nox .session (python = ["3" ], reuse_venv = True )
294
+ def l2_trt_compatibility_tests_host_deps (session ):
295
+ """Makes sure that TensorRT Python and Torch-TensorRT can work together using host dependencies"""
296
+ run_l2_trt_compatibility_tests (session , use_host_env = True )
297
+
298
+ @nox .session (python = ["3" ], reuse_venv = True )
299
+ def l2_multi_gpu_tests (session ):
300
+ """Makes sure that Torch-TensorRT can operate on multi-gpu systems"""
301
+ run_l2_multi_gpu_tests (session , use_host_env = False )
302
+
303
+ @nox .session (python = ["3" ], reuse_venv = True )
304
+ def l2_multi_gpu_tests_host_deps (session ):
305
+ """Makes sure that Torch-TensorRT can operate on multi-gpu systems using host dependencies"""
306
+ run_l2_multi_gpu_tests (session , use_host_env = True )
0 commit comments