@@ -339,60 +339,144 @@ def default_output_names(num_outputs):
     return ["output_%d" % i for i in range(num_outputs)]
 
 
-class LayerNamingNetworkWrapper(object):
+def device_type_str(device_type):
+    if device_type == trt.DeviceType.GPU:
+        return 'GPU'
+    elif device_type == trt.DeviceType.DLA:
+        return 'DLA'
+
+
+class NetworkWrapper(object):
     def __init__(self, ctx, network):
         self._ctx = ctx
         self._network = network
         self._layer_counts = defaultdict(lambda: 0)
 
-    def _set_layer_name(self, layer):
+    def _configure_layer(self, layer):
+
+        # set layer device type
+        device_type = self._ctx.current_device_type()
+        self._ctx.builder_config.set_device_type(layer, device_type)
+        orig_device_type = device_type
+        if not self._ctx.builder_config.can_run_on_DLA(layer) and device_type == trt.DeviceType.DLA:
+            if self._ctx.torch2trt_kwargs['gpu_fallback']:
+                device_type = trt.DeviceType.GPU  # layer will fall back to GPU
+
+        # set layer name
         def arg_str(arg):
             if isinstance(arg, torch.Tensor):
                 return "tensor(shape=%s, dtype=%s)" % (str(list(arg.shape)), str(arg.dtype))
             return str(arg)
-
-        self._layer_counts[layer.type.name] += 1
+        scope_name = self._ctx.current_module_name()  # + ':' + layer.type.name
+        self._layer_counts[scope_name] += 1
         args = [arg_str(arg) for arg in self._ctx.method_args]
         kwargs = ["%s=%s" % (key, arg_str(arg)) for key, arg in self._ctx.method_kwargs.items()]
-        layer.name = "[%s #%d] %s(%s)" % (layer.type.name, self._layer_counts[layer.type.name],
-                                          self._ctx.method_str, ", ".join(args + kwargs))
-
+        layer.name = scope_name + ':' + str(self._layer_counts[scope_name] - 1) + ':' + layer.type.name + ':' + device_type_str(device_type)
+
+        if orig_device_type != device_type:
+            layer.name = layer.name + '(' + device_type_str(orig_device_type) + ')'
+        # "%s [%s #%d, %s] %s(%s)" % (self._ctx.current_module_name(), layer.type.name, self._layer_counts[layer.type.name], device_type_str(device_type),
+        #                            self._ctx.method_str, ", ".join(args + kwargs))
+
+
     def __getattr__(self, name):
         attr = getattr(self._network, name)
         if callable(attr):
             def wrapper(*args, **kwargs):
                 ret = attr(*args, **kwargs)
                 if isinstance(ret, trt.ILayer):
-                    self._set_layer_name(ret)
+                    self._configure_layer(ret)
                 return ret
 
             return wrapper
         else:
             return attr
 
 
-class ConversionContext(object):
 
-    def __init__(self, network, converters=CONVERTERS, torch2trt_kwargs=None):
-        self.network = LayerNamingNetworkWrapper(self, network)
+class ConversionContext(object):
+
+    def __init__(self, network, converters=CONVERTERS, torch2trt_kwargs=None, builder_config=None):
+        self.network = NetworkWrapper(self, network)
         self.lock = False
         self.method_args = None
         self.method_kwargs = None
         self.method_return = None
         self.torch2trt_kwargs = torch2trt_kwargs
+        self.builder_config = builder_config
         self.hooks = [
             ConversionHook(self, key, converter)
             for key, converter in converters.items()
         ]
-
+
+        self.module_stack = []
+        self.module_handles = []
+        self.device_type_stack = []
+        self.module_name_map = {}
+        for name, module in torch2trt_kwargs['module'].named_modules():
+            self.module_name_map[module] = name
+
+    def current_module_name(self):
+        return self.get_module_name(self.current_module())
+
+    def current_module(self):
+        return self.module_stack[-1]
+
+    def get_module_name(self, module):
+        return self.module_name_map[module]
+
+    def _module_pre_hook(self, module, input):
+        # TODO(@jwelsh): add logging to show module entry / exit
+        self.module_stack.append(module)
+
+        # this hook is attached via register_forward_pre_hook and runs before the module executes
+        if module in self.torch2trt_kwargs['device_types']:
+            device_type = self.torch2trt_kwargs['device_types'][module]
+            self.device_type_stack.append((module, device_type))
+
+    def _module_post_hook(self, module, input, output):
+
+        # if module was used to set the current device type, pop device type from stack
+        if self.current_device_type_module() == module:
+            self.device_type_stack.pop()
+
+        self.module_stack.pop()
+
+    def current_device_type(self):
+        """Returns the current device type"""
+        if len(self.device_type_stack) > 0:
+            return self.device_type_stack[-1][1]
+        else:
+            return self.torch2trt_kwargs['default_device_type']
+
+    def current_device_type_module(self):
+        """Returns the module which controls the current device type"""
+        if len(self.device_type_stack) > 0:
+            return self.device_type_stack[-1][0]
+        else:
+            return None
+
     def __enter__(self):
+
+        # attach hooks which add converters to methods
         for hook in self.hooks:
             hook.__enter__()
+
+        # attach hooks which control the current device type
+        for name, module in self.torch2trt_kwargs['module'].named_modules():
+            pre_hook_handle = module.register_forward_pre_hook(self._module_pre_hook)
+            post_hook_handle = module.register_forward_hook(self._module_post_hook)
+            self.module_handles.append(pre_hook_handle)
+            self.module_handles.append(post_hook_handle)
+
         return self
 
     def __exit__(self, type, val, tb):
         for hook in self.hooks:
             hook.__exit__(type, val, tb)
+        for handle in self.module_handles:
+            handle.remove()
+
 
     def add_inputs(self, torch_inputs, names=None):
         if names is None:
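As background (not part of the patch): ConversionContext now tracks the currently executing module with PyTorch forward hooks, so _configure_layer can ask which module and device type are active while a layer is being added. A minimal standalone sketch of that mechanism, using a placeholder model, might look like this:

import torch
import torch.nn as nn

# Illustration only: a pre-hook pushes each module onto a stack before its
# forward() runs, and a post-hook pops it afterwards, so code running inside
# forward() can inspect the innermost active module via module_stack[-1].
module_stack = []

def pre_hook(module, input):
    module_stack.append(module)

def post_hook(module, input, output):
    module_stack.pop()

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))  # placeholder model
handles = []
for name, module in model.named_modules():
    handles.append(module.register_forward_pre_hook(pre_hook))
    handles.append(module.register_forward_hook(post_hook))

model(torch.randn(1, 4))  # during each submodule's forward, module_stack[-1] is that submodule

for handle in handles:
    handle.remove()  # detach the hooks, as ConversionContext.__exit__ does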
@@ -508,6 +592,10 @@ def torch2trt(module,
               int8_calib_algorithm=DEFAULT_CALIBRATION_ALGORITHM,
               int8_calib_batch_size=1,
               use_onnx=False,
+              default_device_type=trt.DeviceType.GPU,
+              dla_core=0,
+              gpu_fallback=True,
+              device_types={},
               **kwargs):
 
     # capture arguments to provide to context
@@ -549,7 +637,7 @@ def torch2trt(module,
 
     else:
         network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
-        with ConversionContext(network, torch2trt_kwargs=kwargs) as ctx:
+        with ConversionContext(network, torch2trt_kwargs=kwargs, builder_config=config) as ctx:
 
             ctx.add_inputs(inputs, input_names)
 
@@ -568,6 +656,11 @@ def torch2trt(module,
 
     builder.max_batch_size = max_batch_size
 
+    config.default_device_type = default_device_type
+    if gpu_fallback:
+        config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
+    config.DLA_core = dla_core
+
     if strict_type_constraints:
         config.set_flag(trt.BuilderFlag.STRICT_TYPES)
 
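For context, a hedged sketch of how the new keyword arguments added in this change might be used. The model choice, the fp16 setting, and the submodule pinned via device_types are illustrative assumptions rather than part of the patch; running on DLA assumes a Jetson-class device, and DLA engines typically require fp16 or int8 mode.

import torch
import tensorrt as trt
from torch2trt import torch2trt
from torchvision.models import resnet18

# Example only: resnet18 stands in for any module.
model = resnet18().cuda().eval()
x = torch.randn(1, 3, 224, 224).cuda()

model_trt = torch2trt(
    model,
    [x],
    fp16_mode=True,
    default_device_type=trt.DeviceType.DLA,        # place layers on DLA by default
    device_types={model.fc: trt.DeviceType.GPU},   # pin the final classifier to GPU (illustrative)
    dla_core=0,                                    # which DLA core to target
    gpu_fallback=True,                             # layers DLA can't run fall back to GPU
)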