
Commit 2ea309a

Merge pull request #371 from andi4191/anuragd/multi-gpu
Anuragd/multi-gpu (//py): Fixed multi-gpu scenario with Python set_device API support
2 parents 26d5c65 + c1ab5db commit 2ea309a

File tree

6 files changed: +101 −7 lines
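
Taken together, the change wires a new trtorch.set_device(gpu_id) call from Python down to cudaSetDevice in the conversion context. As a quick orientation before the per-file diffs, here is a minimal usage sketch distilled from the new tests/py/test_multi_gpu.py shown below; the two-GPU machine and the pretrained torchvision ResNet-18 are assumptions of the example, not requirements of the API.

    import torch
    import torchvision.models as models
    import trtorch

    # Example assumption: at least two CUDA devices are visible.
    trtorch.set_device(0)                      # start on GPU 0

    model = models.resnet18(pretrained=True).eval().to("cuda:1")
    data = torch.randn((1, 3, 224, 224)).to("cuda:1")
    traced_model = torch.jit.trace(model, [data])

    compile_spec = {
        "input_shapes": [data.shape],
        "device": {
            "device_type": trtorch.DeviceType.GPU,
            "gpu_id": 1,                       # build the engine for GPU 1
            "dla_core": 0,
            "allow_gpu_fallback": False,
            "disable_tf32": False
        }
    }
    trt_mod = trtorch.compile(traced_model, compile_spec)

    trtorch.set_device(1)                      # switch to the target GPU before running
    max_diff = (trt_mod(data) - traced_model(data)).abs().max()
    trtorch.set_device(0)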

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 5 additions & 4 deletions
@@ -47,6 +47,11 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
         util::logging::get_logger().get_reportable_severity(),
         util::logging::get_logger().get_is_colored_output_on()) {
   // TODO: Support FP16 and FP32 from JIT information
+  if (settings.device.gpu_id) {
+    TRTORCH_CHECK(
+        cudaSetDevice(settings.device.gpu_id) == cudaSuccess, "Unable to set gpu id: " << settings.device.gpu_id);
+  }
+
   builder = nvinfer1::createInferBuilder(logger);
   net = builder->createNetworkV2(1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));
 
@@ -108,10 +113,6 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
   cfg->setDefaultDeviceType(settings.device.device_type);
   cfg->setEngineCapability(settings.capability);
 
-  if (settings.device.gpu_id) {
-    TRTORCH_CHECK(cudaSetDevice(settings.device.gpu_id), "Unable to set gpu id: " << settings.device.gpu_id);
-  }
-
   if (settings.device.device_type == nvinfer1::DeviceType::kDLA) {
     auto nbDLACores = builder->getNbDLACores();
     TRTORCH_CHECK(

docsrc/py_api/trtorch.rst

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,8 @@ trtorch
 Functions
 ------------
 
+.. autofunction:: set_device
+
 .. autofunction:: compile
 
 .. autofunction:: convert_method_to_trt_engine

py/trtorch/_compiler.py

Lines changed: 3 additions & 0 deletions
@@ -156,3 +156,6 @@ def get_build_info() -> str:
     build_info = trtorch._C.get_build_info()
     build_info = "TRTorch Version: " + str(__version__) + '\n' + build_info
     return build_info
+
+def set_device(gpu_id):
+    trtorch._C.set_device(gpu_id)
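
The new wrapper simply forwards to the compiled extension. As a small sketch (not part of the commit), a defensive call pattern using the same torch.cuda.device_count() check the new test relies on; gpu_id here is just an illustrative value:

    import torch
    import trtorch

    gpu_id = 1  # hypothetical target device for this sketch
    if gpu_id < torch.cuda.device_count():
        trtorch.set_device(gpu_id)  # forwards to trtorch._C.set_device(gpu_id)
    else:
        raise RuntimeError("GPU id {} is not available on this machine".format(gpu_id))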

py/trtorch/csrc/trtorch_py.cpp

Lines changed: 5 additions & 0 deletions
@@ -15,6 +15,10 @@ namespace py = pybind11;
 namespace trtorch {
 namespace pyapi {
 
+void set_device(const int device_id) {
+  core::set_device(device_id);
+}
+
 torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec& info) {
   py::gil_scoped_acquire gil;
   auto trt_mod = core::CompileGraph(mod, info.toInternalCompileSpec());
@@ -146,6 +150,7 @@ PYBIND11_MODULE(_C, m) {
   m.def("_get_is_colored_output_on", &logging::get_is_colored_output_on, "Get if the logging output will be colored");
   m.def("_set_is_colored_output_on", &logging::set_is_colored_output_on, "Set if the logging output should be colored");
   m.def("_log", &logging::log, "Add a message to the logger");
+  m.def("set_device", &trtorch::pyapi::set_device, "Set CUDA device id");
 
   py::enum_<core::util::logging::LogLevel>(m, "LogLevel", py::arithmetic())
       .value("INTERNAL_ERROR", core::util::logging::LogLevel::kINTERNAL_ERROR)

tests/py/BUILD

Lines changed: 17 additions & 3 deletions
@@ -15,9 +15,23 @@ py_test(
         "test_api.py",
         "model_test_case.py"
     ] + select({
-    ":aarch64_linux": [
-        "test_api_dla.py"
-    ],
+        ":aarch64_linux": [
+            "test_api_dla.py"
+        ],
+        "//conditions:default" : []
+    }),
+    deps = [
+        requirement("torchvision")
+    ]
+)
+
+# Following multi_gpu test is only targeted for multi-gpu configurations. It is not included in the test suite by default.
+py_test(
+    name = "test_multi_gpu",
+    srcs = [
+        "test_multi_gpu.py",
+        "model_test_case.py"
+    ],
         "//conditions:default" : []
     }),
     deps = [

tests/py/test_multi_gpu.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+import unittest
+import trtorch
+import torch
+import torchvision.models as models
+
+from model_test_case import ModelTestCase
+
+class TestMultiGpuSwitching(ModelTestCase):
+    def setUp(self):
+        if torch.cuda.device_count() < 2:
+            self.fail("Test is not relevant for this platform since number of available CUDA devices is less than 2")
+
+        trtorch.set_device(0)
+        self.target_gpu = 1
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda:1")
+        self.model = self.model.to("cuda:1")
+        self.traced_model = torch.jit.trace(self.model, [self.input])
+        self.scripted_model = torch.jit.script(self.model)
+
+    def test_compile_traced(self):
+        trtorch.set_device(0)
+        compile_spec = {
+            "input_shapes": [self.input.shape],
+            "device": {
+                "device_type": trtorch.DeviceType.GPU,
+                "gpu_id": self.target_gpu,
+                "dla_core": 0,
+                "allow_gpu_fallback": False,
+                "disable_tf32": False
+            }
+        }
+
+        trt_mod = trtorch.compile(self.traced_model, compile_spec)
+        trtorch.set_device(self.target_gpu)
+        same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
+        trtorch.set_device(0)
+        self.assertTrue(same < 2e-3)
+
+    def test_compile_script(self):
+        trtorch.set_device(0)
+        compile_spec = {
+            "input_shapes": [self.input.shape],
+            "device": {
+                "device_type": trtorch.DeviceType.GPU,
+                "gpu_id": self.target_gpu,
+                "dla_core": 0,
+                "allow_gpu_fallback": False,
+                "disable_tf32": False
+            }
+        }
+
+        trt_mod = trtorch.compile(self.scripted_model, compile_spec)
+        trtorch.set_device(self.target_gpu)
+        same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
+        trtorch.set_device(0)
+        self.assertTrue(same < 2e-3)
+
+def test_suite():
+    suite = unittest.TestSuite()
+    suite.addTest(TestMultiGpuSwitching.parametrize(TestMultiGpuSwitching, model=models.resnet18(pretrained=True)))
+
+    return suite
+
+suite = test_suite()
+
+runner = unittest.TextTestRunner()
+result = runner.run(suite)
+
+exit(int(not result.wasSuccessful()))
