Commit 80743b0

feat: Safety Mode for Runtime (#2512)

1 parent f93a732

16 files changed: +420 −15 lines

core/runtime/TRTEngine.cpp

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ TRTEngine::TRTEngine(
   auto most_compatible_device = get_most_compatible_device(cuda_device);
   TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");
   device_info = most_compatible_device.value();
+  multi_gpu_device_check();
   set_rt_device(device_info);

   rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger()));

core/runtime/execute_engine.cpp

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     LOG_INFO("" << log_info);
   }

-  {
+  if (MULTI_DEVICE_SAFE_MODE) {
     std::unique_ptr<torch::autograd::profiler::RecordProfile> device_profiler_guard;
     if (compiled_engine->profile_execution) {
       device_profiler_guard =

core/runtime/register_jit_hooks.cpp

Lines changed: 4 additions & 0 deletions
@@ -114,6 +114,10 @@ TORCH_LIBRARY(tensorrt, m) {
   m.def("execute_engine", execute_engine);
   m.def("SERIALIZED_ENGINE_BINDING_DELIM", []() -> std::string { return std::string(1, TRTEngine::BINDING_DELIM); });
   m.def("ABI_VERSION", []() -> std::string { return ABI_VERSION; });
+  m.def("get_multi_device_safe_mode", []() -> bool { return MULTI_DEVICE_SAFE_MODE; });
+  m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void {
+    MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
+  });
 }

 } // namespace
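
Because these two ops are registered in the `tensorrt` namespace via TORCH_LIBRARY, they are reachable from Python through torch.ops once the runtime library is loaded. A minimal sketch (assuming `import torch_tensorrt` loads libtorchtrt and its op schemas; the convenience wrapper documented in runtime.rst below builds on these raw ops):

    # Sketch: exercising the raw TORCH_LIBRARY ops registered above
    import torch
    import torch_tensorrt  # noqa: F401  (loads the runtime op registrations)

    print(torch.ops.tensorrt.get_multi_device_safe_mode())  # False by default
    torch.ops.tensorrt.set_multi_device_safe_mode(True)     # flip the runtime flag
    print(torch.ops.tensorrt.get_multi_device_safe_mode())  # True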

core/runtime/runtime.cpp

Lines changed: 15 additions & 2 deletions
@@ -7,6 +7,8 @@ namespace torch_tensorrt {
 namespace core {
 namespace runtime {

+bool MULTI_DEVICE_SAFE_MODE = false;
+
 c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device, const RTDevice& curr_device) {
   LOG_DEBUG("Target Device: " << target_device);
   auto device_options = find_compatible_devices(target_device);

@@ -31,13 +33,13 @@ c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device
     if (device.device_name == target_device.device_name) {
       // First priority is selecting a candidate which agrees with the current device ID
       // If such a device is found, we can select it and break out of the loop
-      if (device.id == current_device.id && best_match.id != current_device.id) {
+      if (device.id == current_device.id) {
         best_match = device;
         break;
       }
       // Second priority is selecting a candidate which agrees with the target device ID
       // At deserialization time, the current device and target device may not agree
-      else if (device.id == target_device.id && best_match.id != target_device.id) {
+      else if (device.id == target_device.id) {
         best_match = device;
       }
       // If no such GPU ID is found, select the first available candidate GPU

@@ -103,6 +105,17 @@ RTDevice get_current_device() {
   return RTDevice(device_id, nvinfer1::DeviceType::kGPU);
 }

+void multi_gpu_device_check() {
+  // If multi-device safe mode is disabled and more than 1 device is registered on the machine, warn user
+  if (!(MULTI_DEVICE_SAFE_MODE) && get_available_device_list().get_devices().size() > 1) {
+    LOG_WARNING(
+        "Detected this engine is being instantiated in a multi-GPU system with "
+        << "multi-device safe mode disabled. For more on the implications of this "
+        << "as well as workarounds, see the linked documentation "
+        << "(https://pytorch.org/TensorRT/user_guide/runtime.html#multi-device-safe-mode)");
+  }
+}
+
 namespace {
 static DeviceList cuda_device_list;
 }

core/runtime/runtime.h

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,7 @@ namespace runtime {

 using EngineID = int64_t;
 const std::string ABI_VERSION = "4";
+extern bool MULTI_DEVICE_SAFE_MODE;
 typedef enum {
   ABI_TARGET_IDX = 0,
   NAME_IDX,

@@ -33,6 +34,8 @@ std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device);

 std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine);

+void multi_gpu_device_check();
+
 class DeviceList {
   using DeviceMap = std::unordered_map<int, RTDevice>;
   DeviceMap device_list;

docsrc/user_guide/runtime.rst

Lines changed: 34 additions & 0 deletions
@@ -34,3 +34,37 @@ Plugin Library
 In the case you use Torch-TensorRT as a converter to a TensorRT engine and your engine uses plugins provided by Torch-TensorRT, Torch-TensorRT
 ships the library ``libtorchtrt_plugins.so`` which contains the implementation of the TensorRT plugins used by Torch-TensorRT during
 compilation. This library can be ``DL_OPEN`` or ``LD_PRELOAD`` similar to other TensorRT plugin libraries.
+
+Multi Device Safe Mode
+----------------------
+
+Multi-device safe mode is a setting in Torch-TensorRT which allows the user to determine whether
+the runtime checks for device consistency prior to every inference call.
+
+There is a non-negligible, fixed cost per inference call when multi-device safe mode is enabled, which is why
+it is now disabled by default. It can be controlled via the following convenience function, which
+doubles as a context manager.
+
+.. code-block:: python
+
+    # Enables Multi Device Safe Mode
+    torch_tensorrt.runtime.set_multi_device_safe_mode(True)
+
+    # Disables Multi Device Safe Mode [Default Behavior]
+    torch_tensorrt.runtime.set_multi_device_safe_mode(False)
+
+    # Enables Multi Device Safe Mode, then resets the safe mode to its prior setting
+    with torch_tensorrt.runtime.set_multi_device_safe_mode(True):
+        ...
+
+TensorRT requires that each engine be associated with the CUDA context in the active thread from which it is invoked.
+Therefore, if the device changes in the active thread, as can happen when invoking
+engines on multiple GPUs from the same Python process, safe mode causes Torch-TensorRT to display
+an alert and switch GPUs accordingly. If safe mode is not enabled, the engine device and the
+CUDA context device can fall out of agreement, which can cause the program to crash.
+
+One technique for managing multiple TRT engines on different GPUs, without sacrificing performance to
+multi-device safe mode, is to use Python threads. Each thread is responsible for all of the TRT engines
+on a single GPU, and the default CUDA device on each thread corresponds to the GPU for which it is
+responsible (this can be set via ``torch.cuda.set_device(...)``). In this way, multiple threads can be used in the same
+Python script without needing to switch CUDA contexts and incur the performance overhead of safe mode.
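
A minimal sketch of the thread-per-GPU pattern described in that last paragraph; the compiled modules trt_mod_0/trt_mod_1 and inputs inp_0/inp_1 are hypothetical placeholders standing in for Torch-TensorRT modules already built for GPUs 0 and 1:

    import threading

    import torch
    import torch_tensorrt  # noqa: F401


    def infer_on_gpu(device_id, trt_mod, example_input):
        # Pin this thread's default CUDA device to the GPU that owns the engine,
        # so calls from this thread already match the engine's device and no
        # per-call context switch (or safe-mode check) is needed
        torch.cuda.set_device(device_id)
        with torch.no_grad():
            return trt_mod(example_input)


    # trt_mod_0 / trt_mod_1 and inp_0 / inp_1 are placeholders (see lead-in)
    threads = [
        threading.Thread(target=infer_on_gpu, args=(0, trt_mod_0, inp_0)),
        threading.Thread(target=infer_on_gpu, args=(1, trt_mod_1, inp_1)),
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()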

py/torch_tensorrt/__init__.py

Lines changed: 5 additions & 3 deletions
@@ -85,15 +85,17 @@ def _find_lib(name: str, paths: List[str]) -> str:
 from torch_tensorrt._Device import Device  # noqa: F401
 from torch_tensorrt._enums import *  # noqa: F403
 from torch_tensorrt._Input import Input  # noqa: F401
-from torch_tensorrt.logging import *
-from torch_tensorrt.ptq import *
 from torch_tensorrt._utils import *  # noqa: F403
 from torch_tensorrt._utils import sanitized_torch_version
+from torch_tensorrt.logging import *
+from torch_tensorrt.ptq import *
+from torch_tensorrt.runtime import *  # noqa: F403

 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
-    from torch_tensorrt import dynamo  # noqa: F401
     from torch_tensorrt.dynamo import backend  # noqa: F401

+    from torch_tensorrt import dynamo  # noqa: F401
+

 def _register_with_torch() -> None:
     trtorch_dir = os.path.dirname(__file__)

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 2 additions & 0 deletions
@@ -60,6 +60,8 @@ def convert_module(
         engine=interpreter_result.engine,
         input_names=list(interpreter_result.input_names),
         output_names=list(interpreter_result.output_names),
+        target_device=settings.device,
+        profiling_enabled=settings.debug,
     )

 else:

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 2 additions & 2 deletions
@@ -42,10 +42,10 @@ def is_node_supported(
         node_name = ConverterRegistry.qualified_name_or_str(node.target)

         if (
-            node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
+            node in CONVERTERS or node.op == "get_attr"
         ) and node_name not in self.torch_executed_ops:
             # If node is a proper, supported computational node, store the operator
-            if not node.is_impure():
+            if not node.is_impure() and node.op != "get_attr":
                 if node_name not in self.supported_operators:
                     self.supported_operators[node_name] = 1
                 else:

py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py

Lines changed: 2 additions & 2 deletions
@@ -150,10 +150,10 @@ def is_node_supported(
         node_name = ConverterRegistry.qualified_name_or_str(node.target)

         if (
-            node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
+            node in CONVERTERS or node.op == "get_attr"
         ) and node_name not in self.torch_executed_ops:
             # If node is a proper, supported computational node, store the operator
-            if not node.is_impure():
+            if not node.is_impure() and node.op != "get_attr":
                 if node_name not in self.supported_operators:
                     self.supported_operators[node_name] = 1
                 else:
