Skip to content

Commit 81024cc

Browse files
authored
allow direct method setting to support custom layers (#460)
1 parent adccbf1 commit 81024cc

File tree

1 file changed

+46
-22
lines changed

1 file changed

+46
-22
lines changed

torch2trt/torch2trt.py

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import torch
22
import tensorrt as trt
3-
from copy import copy
3+
import copy
44
import numpy as np
55
import io
66
from collections import defaultdict
7+
import importlib
78

89
from .calibration import (
910
TensorBatchDataset,
@@ -297,30 +298,24 @@ def wrapper(*args, **kwargs):
297298
class ConversionHook(object):
298299
"""Attaches TensorRT converter to PyTorch method call"""
299300

300-
def __init__(self, ctx, method, converter):
301+
def __init__(self, ctx, key, converter):
301302
self.ctx = ctx
302-
self.method_str = method
303+
self.key = key
303304
self.converter = converter
304305

305306
def _set_method(self, method):
306-
exec("%s = method" % self.method_str)
307+
module = self.converter['module']
308+
exec('module.%s = method' % self.converter['qual_name'])
307309

308310
def __enter__(self):
309-
try:
310-
self.method_impl = eval(self.method_str)
311-
except AttributeError:
312-
self.method_impl = None
313-
314-
if self.method_impl:
315-
self._set_method(
316-
attach_converter(
317-
self.ctx, self.method_impl, self.converter, self.method_str
318-
)
311+
self._set_method(
312+
attach_converter(
313+
self.ctx, self.converter['method_impl'], self.converter, self.converter['method_str']
319314
)
315+
)
320316

321317
def __exit__(self, type, val, tb):
322-
if self.method_impl:
323-
self._set_method(self.method_impl)
318+
self._set_method(self.converter['method_impl'])
324319

325320
def default_input_names(num_inputs):
326321
return ["input_%d" % i for i in range(num_inputs)]
@@ -369,8 +364,8 @@ def __init__(self, network, converters=CONVERTERS):
369364
self.method_kwargs = None
370365
self.method_return = None
371366
self.hooks = [
372-
ConversionHook(self, method, converter)
373-
for method, converter in converters.items()
367+
ConversionHook(self, key, converter)
368+
for key, converter in converters.items()
374369
]
375370

376371
def __enter__(self):
@@ -569,11 +564,40 @@ def torch2trt(module,
569564

570565
# DEFINE ALL CONVERSION FUNCTIONS
571566

567+
def get_module_qualname(name):
568+
s = name.split('.')
569+
570+
for i in range(len(s)):
571+
idx = len(s) - i - 1
572+
modulename, qualname = ".".join(s[:idx]), ".".join(s[idx:])
573+
try:
574+
module = importlib.import_module(modulename)
575+
return module, modulename, qualname
576+
except:
577+
pass
578+
579+
raise RuntimeError("Could not import module")
580+
572581

573-
def tensorrt_converter(method, is_real=True, enabled=True):
574-
582+
def tensorrt_converter(method, is_real=True, enabled=True, imports=[]):
583+
584+
if isinstance(method, str):
585+
module, module_name, qual_name = get_module_qualname(method)
586+
else:
587+
module, module_name, qual_name = importlib.import_module(method.__module__), method.__module__, method.__qualname__
588+
589+
method_impl = eval('copy.deepcopy(module.%s)' % qual_name)
590+
575591
def register_converter(converter):
576-
CONVERTERS[method] = {"converter": converter, "is_real": is_real}
592+
CONVERTERS[method] = {
593+
"converter": converter,
594+
"is_real": is_real,
595+
"module": module,
596+
"module_name": module_name,
597+
"qual_name": qual_name,
598+
"method_str": module_name + '.' + qual_name,
599+
"method_impl": method_impl
600+
}
577601
return converter
578602

579603
def pass_converter(converter):
@@ -584,4 +608,4 @@ def pass_converter(converter):
584608
else:
585609
return pass_converter
586610

587-
return register_converter
611+
return register_converter

0 commit comments

Comments
 (0)