1- from torch2trt import *
2- from .module_test import ModuleTest , MODULE_TESTS
3- import time
4- import argparse
5- import re
6- import runpy
7- import traceback
8- from termcolor import colored
9- import math
10- import tempfile
11- import numpy as np
1+ if __name__ == "__main__" :
122
13- def pSNR (model_op ,trt_op ):
14- #model_op = model_op.cpu().detach().numpy().flatten()
15- #trt_op = trt_op.cpu().detach().numpy().flatten()
16-
17- # Calculating Mean Squared Error
18- mse = np .sum (np .square (model_op - trt_op )) / len (model_op )
19- # Calcuating peak signal to noise ratio
20- try :
21- psnr_db = 20 * math .log10 (np .max (abs (model_op ))) - 10 * math .log10 (mse )
22- except :
23- psnr_db = np .nan
24- return mse ,psnr_db
25-
26-
27-
28- def run (self , serialize = False ):
29- # create module
30- module = self .module_fn ()
31- module = module .to (self .device )
32- module = module .type (self .dtype )
33- module = module .eval ()
34-
35- # create inputs for conversion
36- inputs_conversion = ()
37- for shape in self .input_shapes :
38- inputs_conversion += (torch .zeros (shape ).to (self .device ).type (self .dtype ), )
39-
40-
41- # convert module
42- module_trt = torch2trt (module , inputs_conversion , max_workspace_size = 1 << 20 , ** self .torch2trt_kwargs )
43-
44- if serialize :
45- with tempfile .TemporaryFile () as f :
46- torch .save (module_trt .state_dict (), f )
47- f .seek (0 )
48- module_trt = TRTModule ()
49- module_trt .load_state_dict (torch .load (f ))
50-
51- # create inputs for torch/trt.. copy of inputs to handle inplace ops
52- inputs = ()
53- for shape in self .input_shapes :
54- inputs += (torch .randn (shape ).to (self .device ).type (self .dtype ), )
55- inputs_trt = tuple ([tensor .clone () for tensor in inputs ])
56-
57-
58- # test output against original
59- outputs = module (* inputs )
60- outputs_trt = module_trt (* inputs_trt )
61-
62- if not isinstance (outputs , tuple ):
63- outputs = (outputs , )
64- if not isinstance (outputs_trt , tuple ):
65- outputs_trt = (outputs_trt ,)
66-
67- # compute max error
68- max_error = 0
69- for i in range (len (outputs )):
70- max_error_i = 0
71- if outputs [i ].dtype == torch .bool :
72- max_error_i = torch .sum (outputs [i ] ^ outputs_trt [i ])
73- else :
74- max_error_i = torch .max (torch .abs (outputs [i ] - outputs_trt [i ]))
75-
76- if max_error_i > max_error :
77- max_error = max_error_i
78-
79- ## calculate peak signal to noise ratio
80- assert (len (outputs ) == len (outputs_trt ))
81-
82- ## Check if output is boolean
83- # if yes, then dont calculate psnr
84- if outputs [0 ].dtype == torch .bool :
85- mse = np .nan
86- psnr_db = np .nan
87- else :
88- model_op = []
89- trt_op = []
90- for i in range (len (outputs )):
91- model_op .extend (outputs [i ].detach ().cpu ().numpy ().flatten ())
92- trt_op .extend (outputs_trt [i ].detach ().cpu ().numpy ().flatten ())
93- model_op = np .array (model_op )
94- trt_op = np .array (trt_op )
95- mse ,psnr_db = pSNR (model_op ,trt_op )
96-
97- # benchmark pytorch throughput
98- torch .cuda .current_stream ().synchronize ()
99- t0 = time .time ()
100- for i in range (50 ):
101- outputs = module (* inputs )
102- torch .cuda .current_stream ().synchronize ()
103- t1 = time .time ()
104-
105- fps = 50.0 / (t1 - t0 )
106-
107- # benchmark tensorrt throughput
108- torch .cuda .current_stream ().synchronize ()
109- t0 = time .time ()
110- for i in range (50 ):
111- outputs = module_trt (* inputs )
112- torch .cuda .current_stream ().synchronize ()
113- t1 = time .time ()
114-
115- fps_trt = 50.0 / (t1 - t0 )
116-
117- # benchmark pytorch latency
118- torch .cuda .current_stream ().synchronize ()
119- t0 = time .time ()
120- for i in range (50 ):
121- outputs = module (* inputs )
122- torch .cuda .current_stream ().synchronize ()
123- t1 = time .time ()
124-
125- ms = 1000.0 * (t1 - t0 ) / 50.0
126-
127- # benchmark tensorrt latency
128- torch .cuda .current_stream ().synchronize ()
129- t0 = time .time ()
130- for i in range (50 ):
131- outputs = module_trt (* inputs )
132- torch .cuda .current_stream ().synchronize ()
133- t1 = time .time ()
134-
135- ms_trt = 1000.0 * (t1 - t0 ) / 50.0
136-
137- return max_error ,psnr_db ,mse , fps , fps_trt , ms , ms_trt
138-
139-
140- if __name__ == '__main__' :
141-
142- parser = argparse .ArgumentParser ()
143- parser .add_argument ('--output' , '-o' , help = 'Test output file path' , type = str , default = 'torch2trt_test.md' )
144- parser .add_argument ('--name' , help = 'Regular expression to filter modules to test by name' , type = str , default = '.*' )
145- parser .add_argument ('--tolerance' , help = 'Maximum error to print warning for entry' , type = float , default = '-1' )
146- parser .add_argument ('--include' , help = 'Addition python file to include defining additional tests' , action = 'append' , default = [])
147- parser .add_argument ('--use_onnx' , help = 'Whether to test using ONNX or torch2trt tracing' , action = 'store_true' )
148- parser .add_argument ('--serialize' , help = 'Whether to use serialization / deserialization of TRT modules before test' , action = 'store_true' )
149- args = parser .parse_args ()
150-
151- for include in args .include :
152- runpy .run_module (include )
153-
154- num_tests , num_success , num_tolerance , num_error , num_tolerance_psnr = 0 , 0 , 0 , 0 , 0
155- for test in MODULE_TESTS :
156-
157- # filter by module name
158- name = test .module_name ()
159- if not re .search (args .name , name ):
160- continue
161-
162- num_tests += 1
163- # run test
164- try :
165- if args .use_onnx :
166- test .torch2trt_kwargs .update ({'use_onnx' : True })
167-
168- max_error ,psnr_db ,mse , fps , fps_trt , ms , ms_trt = run (test , serialize = args .serialize )
169-
170- # write entry
171- line = '| %70s | %s | %25s | %s | %.2E | %.2f | %.2E | %.3g | %.3g | %.3g | %.3g |' % (name , test .dtype .__repr__ ().split ('.' )[- 1 ], str (test .input_shapes ), str (test .torch2trt_kwargs ), max_error ,psnr_db ,mse , fps , fps_trt , ms , ms_trt )
172-
173- if args .tolerance >= 0 and max_error > args .tolerance :
174- print (colored (line , 'yellow' ))
175- num_tolerance += 1
176- elif psnr_db < 100 :
177- print (colored (line , 'magenta' ))
178- num_tolerance_psnr += 1
179- else :
180- print (line )
181- num_success += 1
182- except :
183- line = '| %s | %s | %s | %s | N/A | N/A | N/A | N/A | N/A |' % (name , test .dtype .__repr__ ().split ('.' )[- 1 ], str (test .input_shapes ), str (test .torch2trt_kwargs ))
184- print (colored (line , 'red' ))
185- num_error += 1
186- tb = traceback .format_exc ()
187- print (tb )
188-
189- with open (args .output , 'a+' ) as f :
190- f .write (line + '\n ' )
191-
192- print ('NUM_TESTS: %d' % num_tests )
193- print ('NUM_SUCCESSFUL_CONVERSION: %d' % num_success )
194- print ('NUM_FAILED_CONVERSION: %d' % num_error )
195- print ('NUM_ABOVE_TOLERANCE: %d' % num_tolerance )
196- print ('NUM_pSNR_TOLERANCE: %d' % num_tolerance_psnr )
3+ print ("torch2trt.test is no longer supported. Please implement unit tests in the tests directory instead." )
0 commit comments