Skip to content

Commit 00c9370

Browse files
committed
remove module test
1 parent 5345e68 commit 00c9370

File tree

2 files changed

+2
-229
lines changed

2 files changed

+2
-229
lines changed

torch2trt/module_test.py

Lines changed: 0 additions & 34 deletions
This file was deleted.

torch2trt/test.py

Lines changed: 2 additions & 195 deletions
Original file line numberDiff line numberDiff line change
@@ -1,196 +1,3 @@
1-
from torch2trt import *
2-
from .module_test import ModuleTest, MODULE_TESTS
3-
import time
4-
import argparse
5-
import re
6-
import runpy
7-
import traceback
8-
from termcolor import colored
9-
import math
10-
import tempfile
11-
import numpy as np
1+
if __name__ == "__main__":
122

13-
def pSNR(model_op,trt_op):
14-
#model_op = model_op.cpu().detach().numpy().flatten()
15-
#trt_op = trt_op.cpu().detach().numpy().flatten()
16-
17-
# Calculating Mean Squared Error
18-
mse = np.sum(np.square(model_op - trt_op)) / len(model_op)
19-
# Calcuating peak signal to noise ratio
20-
try:
21-
psnr_db = 20 * math.log10(np.max(abs(model_op))) - 10 * math.log10(mse)
22-
except:
23-
psnr_db = np.nan
24-
return mse,psnr_db
25-
26-
27-
28-
def run(self, serialize=False):
29-
# create module
30-
module = self.module_fn()
31-
module = module.to(self.device)
32-
module = module.type(self.dtype)
33-
module = module.eval()
34-
35-
# create inputs for conversion
36-
inputs_conversion = ()
37-
for shape in self.input_shapes:
38-
inputs_conversion += (torch.zeros(shape).to(self.device).type(self.dtype), )
39-
40-
41-
# convert module
42-
module_trt = torch2trt(module, inputs_conversion, max_workspace_size=1 << 20, **self.torch2trt_kwargs)
43-
44-
if serialize:
45-
with tempfile.TemporaryFile() as f:
46-
torch.save(module_trt.state_dict(), f)
47-
f.seek(0)
48-
module_trt = TRTModule()
49-
module_trt.load_state_dict(torch.load(f))
50-
51-
# create inputs for torch/trt.. copy of inputs to handle inplace ops
52-
inputs = ()
53-
for shape in self.input_shapes:
54-
inputs += (torch.randn(shape).to(self.device).type(self.dtype), )
55-
inputs_trt = tuple([tensor.clone() for tensor in inputs])
56-
57-
58-
# test output against original
59-
outputs = module(*inputs)
60-
outputs_trt = module_trt(*inputs_trt)
61-
62-
if not isinstance(outputs, tuple):
63-
outputs = (outputs, )
64-
if not isinstance(outputs_trt, tuple):
65-
outputs_trt = (outputs_trt,)
66-
67-
# compute max error
68-
max_error = 0
69-
for i in range(len(outputs)):
70-
max_error_i = 0
71-
if outputs[i].dtype == torch.bool:
72-
max_error_i = torch.sum(outputs[i] ^ outputs_trt[i])
73-
else:
74-
max_error_i = torch.max(torch.abs(outputs[i] - outputs_trt[i]))
75-
76-
if max_error_i > max_error:
77-
max_error = max_error_i
78-
79-
## calculate peak signal to noise ratio
80-
assert(len(outputs) == len(outputs_trt))
81-
82-
## Check if output is boolean
83-
# if yes, then dont calculate psnr
84-
if outputs[0].dtype == torch.bool:
85-
mse = np.nan
86-
psnr_db = np.nan
87-
else:
88-
model_op = []
89-
trt_op = []
90-
for i in range(len(outputs)):
91-
model_op.extend(outputs[i].detach().cpu().numpy().flatten())
92-
trt_op.extend(outputs_trt[i].detach().cpu().numpy().flatten())
93-
model_op = np.array(model_op)
94-
trt_op = np.array(trt_op)
95-
mse,psnr_db = pSNR(model_op,trt_op)
96-
97-
# benchmark pytorch throughput
98-
torch.cuda.current_stream().synchronize()
99-
t0 = time.time()
100-
for i in range(50):
101-
outputs = module(*inputs)
102-
torch.cuda.current_stream().synchronize()
103-
t1 = time.time()
104-
105-
fps = 50.0 / (t1 - t0)
106-
107-
# benchmark tensorrt throughput
108-
torch.cuda.current_stream().synchronize()
109-
t0 = time.time()
110-
for i in range(50):
111-
outputs = module_trt(*inputs)
112-
torch.cuda.current_stream().synchronize()
113-
t1 = time.time()
114-
115-
fps_trt = 50.0 / (t1 - t0)
116-
117-
# benchmark pytorch latency
118-
torch.cuda.current_stream().synchronize()
119-
t0 = time.time()
120-
for i in range(50):
121-
outputs = module(*inputs)
122-
torch.cuda.current_stream().synchronize()
123-
t1 = time.time()
124-
125-
ms = 1000.0 * (t1 - t0) / 50.0
126-
127-
# benchmark tensorrt latency
128-
torch.cuda.current_stream().synchronize()
129-
t0 = time.time()
130-
for i in range(50):
131-
outputs = module_trt(*inputs)
132-
torch.cuda.current_stream().synchronize()
133-
t1 = time.time()
134-
135-
ms_trt = 1000.0 * (t1 - t0) / 50.0
136-
137-
return max_error,psnr_db,mse, fps, fps_trt, ms, ms_trt
138-
139-
140-
if __name__ == '__main__':
141-
142-
parser = argparse.ArgumentParser()
143-
parser.add_argument('--output', '-o', help='Test output file path', type=str, default='torch2trt_test.md')
144-
parser.add_argument('--name', help='Regular expression to filter modules to test by name', type=str, default='.*')
145-
parser.add_argument('--tolerance', help='Maximum error to print warning for entry', type=float, default='-1')
146-
parser.add_argument('--include', help='Addition python file to include defining additional tests', action='append', default=[])
147-
parser.add_argument('--use_onnx', help='Whether to test using ONNX or torch2trt tracing', action='store_true')
148-
parser.add_argument('--serialize', help='Whether to use serialization / deserialization of TRT modules before test', action='store_true')
149-
args = parser.parse_args()
150-
151-
for include in args.include:
152-
runpy.run_module(include)
153-
154-
num_tests, num_success, num_tolerance, num_error, num_tolerance_psnr = 0, 0, 0, 0, 0
155-
for test in MODULE_TESTS:
156-
157-
# filter by module name
158-
name = test.module_name()
159-
if not re.search(args.name, name):
160-
continue
161-
162-
num_tests += 1
163-
# run test
164-
try:
165-
if args.use_onnx:
166-
test.torch2trt_kwargs.update({'use_onnx': True})
167-
168-
max_error,psnr_db,mse, fps, fps_trt, ms, ms_trt = run(test, serialize=args.serialize)
169-
170-
# write entry
171-
line = '| %70s | %s | %25s | %s | %.2E | %.2f | %.2E | %.3g | %.3g | %.3g | %.3g |' % (name, test.dtype.__repr__().split('.')[-1], str(test.input_shapes), str(test.torch2trt_kwargs), max_error,psnr_db,mse, fps, fps_trt, ms, ms_trt)
172-
173-
if args.tolerance >= 0 and max_error > args.tolerance:
174-
print(colored(line, 'yellow'))
175-
num_tolerance += 1
176-
elif psnr_db < 100:
177-
print(colored(line, 'magenta'))
178-
num_tolerance_psnr +=1
179-
else:
180-
print(line)
181-
num_success += 1
182-
except:
183-
line = '| %s | %s | %s | %s | N/A | N/A | N/A | N/A | N/A |' % (name, test.dtype.__repr__().split('.')[-1], str(test.input_shapes), str(test.torch2trt_kwargs))
184-
print(colored(line, 'red'))
185-
num_error += 1
186-
tb = traceback.format_exc()
187-
print(tb)
188-
189-
with open(args.output, 'a+') as f:
190-
f.write(line + '\n')
191-
192-
print('NUM_TESTS: %d' % num_tests)
193-
print('NUM_SUCCESSFUL_CONVERSION: %d' % num_success)
194-
print('NUM_FAILED_CONVERSION: %d' % num_error)
195-
print('NUM_ABOVE_TOLERANCE: %d' % num_tolerance)
196-
print('NUM_pSNR_TOLERANCE: %d' %num_tolerance_psnr)
3+
print("torch2trt.test is no longer supported. Please implement unit tests in the tests directory instead.")

0 commit comments

Comments
 (0)