     import onnx
     import onnx_graphsurgeon as gs
     import tensorrt as trt
-
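+    # pycuda supplies the CUDA stream handed to TRTExecutor in sanity_check below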
+    import pycuda.driver as cuda
     prod_package_error = None
 except Exception as prod_package_error:
     pass
@@ -21,6 +21,8 @@
 from contextlib import redirect_stdout, ExitStack
 from alonet.torch2trt.onnx_hack import scope_name_workaround, get_scope_names, rename_tensors_
 from alonet.torch2trt import TRTEngineBuilder, TRTExecutor, utils
+from alonet.torch2trt.utils import get_nodes_by_op, rename_nodes_
+
 
 
 class BaseTRTExporter:
@@ -51,6 +53,7 @@ def __init__(
         operator_export_type=None,
         dynamic_axes: Union[Dict[str, Dict[int, str]], Dict[str, List[int]]] = None,
         opt_profiles: Dict[str, Tuple[List[int]]] = None,
+        skip_adapt_graph=False,
         **kwargs,
     ):
         """
@@ -108,6 +111,7 @@ def __init__(
         self.custom_opset = None  # to be redefined in the child class if needed
         self.use_scope_names = use_scope_names
         self.operator_export_type = operator_export_type
+        self.skip_adapt_graph = skip_adapt_graph
         if dynamic_axes is not None:
             assert opt_profiles is not None, "If dynamic_axes are to be used, opt_profiles must be provided"
             assert isinstance(dynamic_axes, dict)
@@ -117,13 +121,19 @@ def __init__(
         onnx_dir = os.path.split(onnx_path)[0]
         onnx_file_name = os.path.split(onnx_path)[1]
         model_name = onnx_file_name.split(".")[0]
-        self.adapted_onnx_path = os.path.join(onnx_dir, "trt_" + onnx_file_name)
+
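+        # When adaptation is skipped, the engine is built directly from the ONNX file exported by torch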
+        if not self.skip_adapt_graph:
+            self.adapted_onnx_path = os.path.join(onnx_dir, "trt_" + onnx_file_name)
+        else:
+            self.adapted_onnx_path = os.path.join(onnx_dir, onnx_file_name)
+
         self.engine_path = os.path.join(onnx_dir, model_name + f"_{precision.lower()}.engine")
 
         if self.verbose:
             trt_logger = trt.Logger(trt.Logger.VERBOSE)
         else:
             trt_logger = trt.Logger(trt.Logger.WARNING)
+
         self.engine_builder = TRTEngineBuilder(self.adapted_onnx_path, logger=trt_logger, opt_profiles=opt_profiles)
 
         if precision.lower() == "fp32":
@@ -147,15 +157,59 @@ def build_torch_model(self):
         pass
         raise Exception("Child class should implement this method")
 
+
     def adapt_graph(self, graph):
         """Modify the ONNX graph to ensure compatibility between ONNX and TensorRT
 
         Returns
         -------
         graph: onnx_graphsurgeon.Graph
         """
-        pass
-        raise Exception("Child class should implement this method")
+        return graph
+
+    def _adapt_graph(self, graph, **kwargs):
+        """Modify the ONNX graph to ensure compatibility between ONNX and TensorRT
+
+        Returns
+        -------
+        graph: onnx_graphsurgeon.Graph
+        """
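+        # TensorRT expects the min/max inputs of Clip nodes to be explicit constants,
+        # so fold the constant subgraphs feeding each Clip into plain gs.Constant inputs.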
+        clip_nodes = get_nodes_by_op("Clip", graph)
+
+        def handle_op_Clip(node: gs.Node):
+            max_constant = np.array(np.finfo(np.float32).max, dtype=np.float32)
+            if "value" in node.inputs[1].i().inputs[0].attrs:
+                min_constant = node.inputs[1].i().inputs[0].attrs["value"].values.astype(np.float32)
+                if len(node.inputs[2].inputs) > 0:
+                    max_constant = node.inputs[2].i().inputs[0].attrs["value"].values.astype(np.float32)
+            elif "to" in node.inputs[1].i().inputs[0].attrs:
+                min_constant = np.array(np.finfo(np.float32).min, dtype=np.float32)
+            else:
+                raise Exception("Unsupported Clip node: cannot resolve its min/max inputs to constants")
+            node.inputs.pop(1)
+            node.inputs.insert(1, gs.Constant(name=node.name + "_min", values=min_constant))
+            node.inputs.pop(2)
+            node.inputs.insert(2, gs.Constant(name=node.name + "_max", values=max_constant))
+
+        for n in clip_nodes:
+            handle_op_Clip(n)
+
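+        # Run onnx-simplifier to fold constants and strip redundant ops
+        # before the TensorRT parser sees the model.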
+        from onnxsim import simplify
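+
+        # Simplify the in-memory graph (not the file on disk) so the Clip fixes
+        # above are preserved in the simplified model.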
+        model = gs.export_onnx(graph)
+        model_simp, check = simplify(model)
+
+        if check:
+            print("\n[INFO] Simplified ONNX model validated. Graph optimized...")
+            graph = gs.import_onnx(model_simp)
+            graph.toposort()
+            graph.cleanup()
+        else:
+            print("\n[INFO] Simplified ONNX model could not be validated; keeping the unsimplified graph.")
+
+        # Call the child class for model-specific graph adaptation
+        graph = self.adapt_graph(graph)
+        return graph
 
     def prepare_sample_inputs(self) -> Tuple[Tuple[torch.Tensor], Dict[str, Union[torch.Tensor, None]]]:
         """
@@ -247,6 +301,7 @@ def _torch2onnx(self):
             number2scope = get_scope_names(onnx_export_log, strict=False)
             graph = gs.import_onnx(onnx.load(self.onnx_path))
             graph = rename_tensors_(graph, number2scope, verbose=True)
+            graph = rename_nodes_(graph, True)
             onnx.save(gs.export_onnx(graph), self.onnx_path)
 
         print("Saved ONNX at:", self.onnx_path)
@@ -265,15 +320,15 @@ def _onnx2engine(self, **kwargs):
         if prod_package_error is not None:
             raise prod_package_error
 
-        graph = gs.import_onnx(onnx.load(self.onnx_path))
-        graph.toposort()
-
-        # === Modify ONNX graph for TensorRT compatibility
-        graph = self.adapt_graph(graph, **kwargs)
-        utils.print_graph_io(graph)
+        if not self.skip_adapt_graph:
+            graph = gs.import_onnx(onnx.load(self.onnx_path))
+            graph.toposort()
 
-        # === Export adapted onnx for TRT engine
-        onnx.save(gs.export_onnx(graph), self.adapted_onnx_path)
+            # === Modify the ONNX graph for TensorRT compatibility
+            graph = self._adapt_graph(graph, **kwargs)
+            utils.print_graph_io(graph)
+            # === Export the adapted ONNX for the TRT engine
+            onnx.save(gs.export_onnx(graph), self.adapted_onnx_path)
 
         # === Build engine
         self.engine_builder.export_engine(self.engine_path)
@@ -286,7 +341,7 @@ def sanity_check(self, engine, sample_inputs, sample_outputs):
         threshold = 1e-1
         check = True
         # Get engine info
-        model = TRTExecutor(engine)
+        model = TRTExecutor(engine, stream=cuda.Stream())
         model.print_bindings_info()
         # Prepare engine inputs
         for i in range(len(sample_inputs)):
@@ -302,6 +357,7 @@ def sanity_check(self, engine, sample_inputs, sample_outputs):
         m_outputs = model.execute()
         print("==== Absolute / relative error:")
         for out in m_outputs:
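+            # Log the raw engine output next to the error metrics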
+            print("out", m_outputs[out])
             diff = m_outputs[out].astype(float) - sample_outputs[out].astype(float)
             abs_err = np.abs(diff)
             rel_err = np.abs(diff / (sample_outputs[out] + 1e-6))  # Avoid div by zero
@@ -332,7 +388,13 @@ def add_argparse_args(parent_parser):
             default=None,
             help="/path/onnx/will/be/exported, by default set as ~/.aloception/weights/MODEL/MODEL.onnx",
         )
+        parser.add_argument("--skip_adapt_graph", action="store_true", help="Skip the ONNX graph adaptation step")
         parser.add_argument("--batch_size", type=int, default=1, help="Engine batch size, default = 1")
         parser.add_argument("--precision", type=str, default="fp32", help="fp32/fp16/mix, default FP32")
         parser.add_argument("--verbose", action="store_true", help="Helpful when debugging")
+        parser.add_argument(
+            "--use_scope_names",
+            action="store_true",
+            help="Save scope names in the ONNX graph, to get per-scope profiles at inference, by default %(default)s",
+        )
         return parent_parser