Commit 68e00d8

Merge pull request #723 from jaybdub/add_dla_support: Add DLA support
2 parents fb52a66 + a3dba66

1 file changed: torch2trt/torch2trt.py (106 additions, 13 deletions)
@@ -339,60 +339,144 @@ def default_output_names(num_outputs):
     return ["output_%d" % i for i in range(num_outputs)]
 
 
-class LayerNamingNetworkWrapper(object):
+def device_type_str(device_type):
+    if device_type == trt.DeviceType.GPU:
+        return 'GPU'
+    elif device_type == trt.DeviceType.DLA:
+        return 'DLA'
+
+
+class NetworkWrapper(object):
     def __init__(self, ctx, network):
         self._ctx = ctx
         self._network = network
         self._layer_counts = defaultdict(lambda: 0)
 
-    def _set_layer_name(self, layer):
+    def _configure_layer(self, layer):
+
+        # set layer device type
+        device_type = self._ctx.current_device_type()
+        self._ctx.builder_config.set_device_type(layer, device_type)
+        orig_device_type = device_type
+        if not self._ctx.builder_config.can_run_on_DLA(layer) and device_type == trt.DeviceType.DLA:
+            if self._ctx.torch2trt_kwargs['gpu_fallback']:
+                device_type = trt.DeviceType.GPU  # layer will fall back to GPU
+
+        # set layer name
         def arg_str(arg):
             if isinstance(arg, torch.Tensor):
                 return "tensor(shape=%s, dtype=%s)" % (str(list(arg.shape)), str(arg.dtype))
             return str(arg)
-
-        self._layer_counts[layer.type.name] += 1
+        scope_name = self._ctx.current_module_name()  # + ':' + layer.type.name
+        self._layer_counts[scope_name] += 1
         args = [arg_str(arg) for arg in self._ctx.method_args]
         kwargs = ["%s=%s" % (key, arg_str(arg)) for key, arg in self._ctx.method_kwargs.items()]
-        layer.name = "[%s #%d] %s(%s)" % (layer.type.name, self._layer_counts[layer.type.name],
-                                          self._ctx.method_str, ", ".join(args + kwargs))
-
+        layer.name = scope_name + ':' + str(self._layer_counts[scope_name] - 1) + ':' + layer.type.name + ':' + device_type_str(device_type)
+
+        if orig_device_type != device_type:
+            layer.name = layer.name + '(' + device_type_str(orig_device_type) + ')'
+        # "%s [%s #%d, %s] %s(%s)" % (self._ctx.current_module_name(), layer.type.name, self._layer_counts[layer.type.name], device_type_str(device_type),
+        #                             self._ctx.method_str, ", ".join(args + kwargs))
+
+
     def __getattr__(self, name):
         attr = getattr(self._network, name)
         if callable(attr):
             def wrapper(*args, **kwargs):
                 ret = attr(*args, **kwargs)
                 if isinstance(ret, trt.ILayer):
-                    self._set_layer_name(ret)
+                    self._configure_layer(ret)
                 return ret
 
             return wrapper
         else:
             return attr
 
 
-class ConversionContext(object):
 
-    def __init__(self, network, converters=CONVERTERS, torch2trt_kwargs=None):
-        self.network = LayerNamingNetworkWrapper(self, network)
+class ConversionContext(object):
+
+    def __init__(self, network, converters=CONVERTERS, torch2trt_kwargs=None, builder_config=None):
+        self.network = NetworkWrapper(self, network)
         self.lock = False
         self.method_args = None
         self.method_kwargs = None
         self.method_return = None
         self.torch2trt_kwargs = torch2trt_kwargs
+        self.builder_config = builder_config
         self.hooks = [
             ConversionHook(self, key, converter)
             for key, converter in converters.items()
         ]
-
+
+        self.module_stack = []
+        self.module_handles = []
+        self.device_type_stack = []
+        self.module_name_map = {}
+        for name, module in torch2trt_kwargs['module'].named_modules():
+            self.module_name_map[module] = name
+
+    def current_module_name(self):
+        return self.get_module_name(self.current_module())
+
+    def current_module(self):
+        return self.module_stack[-1]
+
+    def get_module_name(self, module):
+        return self.module_name_map[module]
+
+    def _module_pre_hook(self, module, input):
+        # TODO(@jwelsh): add logging to show module entry / exit
+        self.module_stack.append(module)
+
+        # hook attached to the module with register_forward_pre_hook; runs before the module executes
+        if module in self.torch2trt_kwargs['device_types']:
+            device_type = self.torch2trt_kwargs['device_types'][module]
+            self.device_type_stack.append((module, device_type))
+
+    def _module_post_hook(self, module, input, output):
+
+        # if this module set the current device type, pop the device type from the stack
+        if self.current_device_type_module() == module:
+            self.device_type_stack.pop()
+
+        self.module_stack.pop()
+
+    def current_device_type(self):
+        """Returns the current device type"""
+        if len(self.device_type_stack) > 0:
+            return self.device_type_stack[-1][1]
+        else:
+            return self.torch2trt_kwargs['default_device_type']
+
+    def current_device_type_module(self):
+        """Returns the module which controls the current device type"""
+        if len(self.device_type_stack) > 0:
+            return self.device_type_stack[-1][0]
+        else:
+            return None
+
     def __enter__(self):
+
+        # attach hooks which add converters to methods
         for hook in self.hooks:
             hook.__enter__()
+
+        # attach hooks which control the current device type
+        for name, module in self.torch2trt_kwargs['module'].named_modules():
+            pre_hook_handle = module.register_forward_pre_hook(self._module_pre_hook)
+            post_hook_handle = module.register_forward_hook(self._module_post_hook)
+            self.module_handles.append(pre_hook_handle)
+            self.module_handles.append(post_hook_handle)
+
         return self
 
     def __exit__(self, type, val, tb):
         for hook in self.hooks:
             hook.__exit__(type, val, tb)
+        for handle in self.module_handles:
+            handle.remove()
+
 
     def add_inputs(self, torch_inputs, names=None):
         if names is None:
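To make the hook mechanism above concrete, here is a standalone sketch of the push/pop behaviour added to ConversionContext. It is illustrative only: it uses plain PyTorch hooks, a tiny made-up module, and string device types instead of the trt.DeviceType values the real code pushes.

    import torch
    import torch.nn as nn

    device_types = {}        # module -> requested device type (strings here for brevity)
    device_type_stack = []   # (module, device_type) pairs, innermost scope last
    DEFAULT_DEVICE_TYPE = 'GPU'

    def current_device_type():
        # innermost requested device type, or the default when no module set one
        return device_type_stack[-1][1] if device_type_stack else DEFAULT_DEVICE_TYPE

    def pre_hook(module, input):
        # pushed before the module runs, so layers created inside it inherit its device type
        if module in device_types:
            device_type_stack.append((module, device_types[module]))

    def post_hook(module, input, output):
        # popped after the module runs, restoring the enclosing device type
        if device_type_stack and device_type_stack[-1][0] is module:
            device_type_stack.pop()

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
    device_types[model[0]] = 'DLA'
    for m in model.modules():
        m.register_forward_pre_hook(pre_hook)
        m.register_forward_hook(post_hook)

    model(torch.randn(1, 3, 8, 8))  # while model[0] runs, current_device_type() == 'DLA'

This is the same scoping idea NetworkWrapper._configure_layer relies on: whatever module is currently executing determines the device type assigned to the TensorRT layers its converters create.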
@@ -508,6 +592,10 @@ def torch2trt(module,
               int8_calib_algorithm=DEFAULT_CALIBRATION_ALGORITHM,
               int8_calib_batch_size=1,
               use_onnx=False,
+              default_device_type=trt.DeviceType.GPU,
+              dla_core=0,
+              gpu_fallback=True,
+              device_types={},
               **kwargs):
 
     # capture arguments to provide to context
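For context only (not part of the commit), a hedged sketch of how these new keyword arguments might be used, assuming a DLA-capable platform such as Jetson Xavier and torchvision's resnet18 as the example module:

    import torch
    import tensorrt as trt
    from torch2trt import torch2trt
    from torchvision.models import resnet18

    model = resnet18().eval().cuda()
    x = torch.randn(1, 3, 224, 224).cuda()

    # Build the whole network for DLA core 0; gpu_fallback=True lets layers that
    # DLA cannot run fall back to the GPU (sets trt.BuilderFlag.GPU_FALLBACK).
    model_trt = torch2trt(
        model, [x],
        default_device_type=trt.DeviceType.DLA,
        dla_core=0,
        gpu_fallback=True,
    )

    # Alternatively, pin only selected submodules to a device type; the dict is
    # keyed by module, and the forward pre/post hooks scope the device type to it.
    model_trt_mixed = torch2trt(
        model, [x],
        default_device_type=trt.DeviceType.GPU,
        device_types={model.layer1: trt.DeviceType.DLA},
    )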
@@ -549,7 +637,7 @@ def torch2trt(module,
 
     else:
         network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
-    with ConversionContext(network, torch2trt_kwargs=kwargs) as ctx:
+    with ConversionContext(network, torch2trt_kwargs=kwargs, builder_config=config) as ctx:
 
         ctx.add_inputs(inputs, input_names)
 
@@ -568,6 +656,11 @@ def torch2trt(module,
 
     builder.max_batch_size = max_batch_size
 
+    config.default_device_type = default_device_type
+    if gpu_fallback:
+        config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
+    config.DLA_core = dla_core
+
     if strict_type_constraints:
         config.set_flag(trt.BuilderFlag.STRICT_TYPES)
 
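For reference, the same device placement expressed directly against the TensorRT Python builder API. This is a sketch: the logger and builder setup below is generic boilerplate, not taken from this file; only the three config settings mirror the hunk above.

    import tensorrt as trt

    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()

    config.default_device_type = trt.DeviceType.DLA  # place layers on DLA unless overridden
    config.set_flag(trt.BuilderFlag.GPU_FALLBACK)    # allow unsupported layers to run on the GPU
    config.DLA_core = 0                              # select which DLA core to target

    # Per-layer overrides (as in NetworkWrapper._configure_layer) would use
    # config.set_device_type(layer, ...) and config.can_run_on_DLA(layer).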