11import torch
22import tensorrt as trt
3- from copy import copy
3+ import copy
44import numpy as np
55import io
66from collections import defaultdict
7+ import importlib
78
89from .calibration import (
910 TensorBatchDataset ,
@@ -297,30 +298,24 @@ def wrapper(*args, **kwargs):
297298class ConversionHook (object ):
298299 """Attaches TensorRT converter to PyTorch method call"""
299300
300- def __init__ (self , ctx , method , converter ):
301+ def __init__ (self , ctx , key , converter ):
301302 self .ctx = ctx
302- self .method_str = method
303+ self .key = key
303304 self .converter = converter
304305
305306 def _set_method (self , method ):
306- exec ("%s = method" % self .method_str )
307+ module = self .converter ['module' ]
308+ exec ('module.%s = method' % self .converter ['qual_name' ])
307309
308310 def __enter__ (self ):
309- try :
310- self .method_impl = eval (self .method_str )
311- except AttributeError :
312- self .method_impl = None
313-
314- if self .method_impl :
315- self ._set_method (
316- attach_converter (
317- self .ctx , self .method_impl , self .converter , self .method_str
318- )
311+ self ._set_method (
312+ attach_converter (
313+ self .ctx , self .converter ['method_impl' ], self .converter , self .converter ['method_str' ]
319314 )
315+ )
320316
321317 def __exit__ (self , type , val , tb ):
322- if self .method_impl :
323- self ._set_method (self .method_impl )
318+ self ._set_method (self .converter ['method_impl' ])
324319
325320def default_input_names (num_inputs ):
326321 return ["input_%d" % i for i in range (num_inputs )]
@@ -369,8 +364,8 @@ def __init__(self, network, converters=CONVERTERS):
369364 self .method_kwargs = None
370365 self .method_return = None
371366 self .hooks = [
372- ConversionHook (self , method , converter )
373- for method , converter in converters .items ()
367+ ConversionHook (self , key , converter )
368+ for key , converter in converters .items ()
374369 ]
375370
376371 def __enter__ (self ):
@@ -569,11 +564,40 @@ def torch2trt(module,
569564
570565# DEFINE ALL CONVERSION FUNCTIONS
571566
567+ def get_module_qualname (name ):
568+ s = name .split ('.' )
569+
570+ for i in range (len (s )):
571+ idx = len (s ) - i - 1
572+ modulename , qualname = "." .join (s [:idx ]), "." .join (s [idx :])
573+ try :
574+ module = importlib .import_module (modulename )
575+ return module , modulename , qualname
576+ except :
577+ pass
578+
579+ raise RuntimeError ("Could not import module" )
580+
572581
573- def tensorrt_converter (method , is_real = True , enabled = True ):
574-
582+ def tensorrt_converter (method , is_real = True , enabled = True , imports = []):
583+
584+ if isinstance (method , str ):
585+ module , module_name , qual_name = get_module_qualname (method )
586+ else :
587+ module , module_name , qual_name = importlib .import_module (method .__module__ ), method .__module__ , method .__qualname__
588+
589+ method_impl = eval ('copy.deepcopy(module.%s)' % qual_name )
590+
575591 def register_converter (converter ):
576- CONVERTERS [method ] = {"converter" : converter , "is_real" : is_real }
592+ CONVERTERS [method ] = {
593+ "converter" : converter ,
594+ "is_real" : is_real ,
595+ "module" : module ,
596+ "module_name" : module_name ,
597+ "qual_name" : qual_name ,
598+ "method_str" : module_name + '.' + qual_name ,
599+ "method_impl" : method_impl
600+ }
577601 return converter
578602
579603 def pass_converter (converter ):
@@ -584,4 +608,4 @@ def pass_converter(converter):
584608 else :
585609 return pass_converter
586610
587- return register_converter
611+ return register_converter
0 commit comments