4
4
# Use system installed Python packages
5
5
PYT_PATH = '/opt/conda/lib/python3.8/site-packages' if not 'PYT_PATH' in os .environ else os .environ ["PYT_PATH" ]
6
6
7
- # Root directory for torch_tensorrt. Set according to docker container by default
8
- TOP_DIR = '/torchtrt' if not 'TOP_DIR' in os .environ else os .environ ["TOP_DIR" ]
7
+ # Set the root directory to the directory of the noxfile unless the user wants to
8
+ # TOP_DIR
9
+ TOP_DIR = os .path .dirname (os .path .realpath (__file__ )) if not 'TOP_DIR' in os .environ else os .environ ["TOP_DIR" ]
10
+
11
+ nox .options .sessions = ["developer_tests-3" ]
12
+
13
+ def install_deps (session ):
14
+ print ("Installing deps" )
15
+ session .install ("-r" , os .path .join (TOP_DIR , "py" , "requirements.txt" ))
16
+ session .install ("-r" , os .path .join (TOP_DIR , "tests" , "py" , "requirements.txt" ))
17
+
18
+ def download_models (session , use_host_env = False ):
19
+ print ("Downloading test models" )
20
+ session .install ('timm' )
21
+ print (TOP_DIR )
22
+ session .chdir (os .path .join (TOP_DIR , "tests" , "modules" ))
23
+ if use_host_env :
24
+ session .run_always ('python' , 'hub.py' , env = {'PYTHONPATH' : PYT_PATH })
25
+ else :
26
+ session .run_always ('python' , 'hub.py' )
27
+
28
+ def install_torch_trt (session ):
29
+ print ("Installing latest torch-tensorrt build" )
30
+ session .chdir (os .path .join (TOP_DIR , "py" ))
31
+ session .run ("python" , "setup.py" , "develop" )
32
+
33
+ def run_base_tests (session , use_host_env = False ):
34
+ print ("Running basic tests" )
35
+ session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
36
+ tests = [
37
+ "test_api.py" ,
38
+ "test_to_backend_api.py"
39
+ ]
40
+ for test in tests :
41
+ if use_host_env :
42
+ session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
43
+ else :
44
+ session .run_always ("python" , test )
45
+
46
+
47
+ # Install the latest build of torch-tensorrt
48
+ @nox .session (python = ["3" ], reuse_venv = True )
49
+ def developer_tests (session ):
50
+ """Basic set of tests that need to pass for code to get merged"""
51
+ install_deps (session )
52
+ install_torch_trt (session )
53
+ download_models (session )
54
+ run_base_tests (session )
9
55
10
56
# Download the dataset
11
57
@nox .session (python = ["3" ], reuse_venv = True )
@@ -14,33 +60,29 @@ def download_datasets(session):
14
60
session .chdir (os .path .join (TOP_DIR , 'examples/int8/training/vgg16' ))
15
61
session .run_always ('wget' , 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz' , external = True )
16
62
session .run_always ('tar' , '-xvzf' , 'cifar-10-binary.tar.gz' , external = True )
17
- session .run_always ('mkdir' , '-p' ,
63
+ session .run_always ('mkdir' , '-p' ,
18
64
os .path .join (TOP_DIR , 'tests/accuracy/datasets/data' ),
19
65
external = True )
20
- session .run_always ('cp' , '-rpf' ,
66
+ session .run_always ('cp' , '-rpf' ,
21
67
os .path .join (TOP_DIR , 'examples/int8/training/vgg16/cifar-10-batches-bin' ),
22
68
os .path .join (TOP_DIR , 'tests/accuracy/datasets/data/cidar-10-batches-bin' ),
23
69
external = True )
24
70
25
71
# Download the model
26
72
@nox .session (python = ["3" ], reuse_venv = True )
27
- def download_models (session ):
28
- session .install ('timm' )
29
- session .chdir ('tests/modules' )
30
- session .run_always ('python' ,
31
- 'hub.py' ,
32
- env = {'PYTHONPATH' : PYT_PATH })
73
+ def download_test_models (session ):
74
+ download_models (session , use_host_env = True )
33
75
34
76
# Train the model
35
77
@nox .session (python = ["3" ], reuse_venv = True )
36
78
def train_model (session ):
37
79
session .chdir (os .path .join (TOP_DIR , 'examples/int8/training/vgg16' ))
38
- session .run_always ('python' ,
39
- 'main.py' ,
40
- '--lr' , '0.01' ,
41
- '--batch-size' , '128' ,
42
- '--drop-ratio' , '0.15' ,
43
- '--ckpt-dir' , 'vgg16_ckpts' ,
80
+ session .run_always ('python' ,
81
+ 'main.py' ,
82
+ '--lr' , '0.01' ,
83
+ '--batch-size' , '128' ,
84
+ '--drop-ratio' , '0.15' ,
85
+ '--ckpt-dir' , 'vgg16_ckpts' ,
44
86
'--epochs' , '25' ,
45
87
env = {'PYTHONPATH' : PYT_PATH })
46
88
@@ -57,17 +99,17 @@ def finetune_model(session):
57
99
session .install ('pytorch-quantization' , '--extra-index-url' , 'https://pypi.ngc.nvidia.com' )
58
100
59
101
session .chdir (os .path .join (TOP_DIR , 'examples/int8/training/vgg16' ))
60
- session .run_always ('python' ,
61
- 'finetune_qat.py' ,
62
- '--lr' , '0.01' ,
63
- '--batch-size' , '128' ,
64
- '--drop-ratio' , '0.15' ,
65
- '--ckpt-dir' , 'vgg16_ckpts' ,
66
- '--start-from' , '25' ,
102
+ session .run_always ('python' ,
103
+ 'finetune_qat.py' ,
104
+ '--lr' , '0.01' ,
105
+ '--batch-size' , '128' ,
106
+ '--drop-ratio' , '0.15' ,
107
+ '--ckpt-dir' , 'vgg16_ckpts' ,
108
+ '--start-from' , '25' ,
67
109
'--epochs' , '26' ,
68
110
env = {'PYTHONPATH' : PYT_PATH })
69
-
70
- # Export model
111
+
112
+ # Export model
71
113
session .run_always ('python' ,
72
114
'export_qat.py' ,
73
115
'vgg16_ckpts/ckpt_epoch26.pth' ,
@@ -77,8 +119,8 @@ def finetune_model(session):
77
119
@nox .session (python = ["3" ], reuse_venv = True )
78
120
def ptq_test (session ):
79
121
session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
80
- session .run_always ('cp' , '-rf' ,
81
- os .path .join (TOP_DIR , 'examples/int8/training/vgg16' , 'trained_vgg16.jit.pt' ),
122
+ session .run_always ('cp' , '-rf' ,
123
+ os .path .join (TOP_DIR , 'examples/int8/training/vgg16' , 'trained_vgg16.jit.pt' ),
82
124
'.' ,
83
125
external = True )
84
126
tests = [
@@ -94,8 +136,8 @@ def ptq_test(session):
94
136
@nox .session (python = ["3" ], reuse_venv = True )
95
137
def qat_test (session ):
96
138
session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
97
- session .run_always ('cp' , '-rf' ,
98
- os .path .join (TOP_DIR , 'examples/int8/training/vgg16' , 'trained_vgg16_qat.jit.pt' ),
139
+ session .run_always ('cp' , '-rf' ,
140
+ os .path .join (TOP_DIR , 'examples/int8/training/vgg16' , 'trained_vgg16_qat.jit.pt' ),
99
141
'.' ,
100
142
external = True )
101
143
0 commit comments