Skip to content

Commit 00cf1d5

Browse files
author
Anurag Dixit
committed
Added fix for multi-gpu configuration
Signed-off-by: Anurag Dixit <[email protected]>
1 parent 26d5c65 commit 00cf1d5

File tree

6 files changed

+98
-4
lines changed

6 files changed

+98
-4
lines changed

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
4747
util::logging::get_logger().get_reportable_severity(),
4848
util::logging::get_logger().get_is_colored_output_on()) {
4949
// TODO: Support FP16 and FP32 from JIT information
50+
if (settings.device.gpu_id) {
51+
TRTORCH_CHECK(
52+
cudaSetDevice(settings.device.gpu_id) == cudaSuccess, "Unable to set gpu id: " << settings.device.gpu_id);
53+
}
54+
5055
builder = nvinfer1::createInferBuilder(logger);
5156
net = builder->createNetworkV2(1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));
5257

@@ -108,10 +113,6 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
108113
cfg->setDefaultDeviceType(settings.device.device_type);
109114
cfg->setEngineCapability(settings.capability);
110115

111-
if (settings.device.gpu_id) {
112-
TRTORCH_CHECK(cudaSetDevice(settings.device.gpu_id), "Unable to set gpu id: " << settings.device.gpu_id);
113-
}
114-
115116
if (settings.device.device_type == nvinfer1::DeviceType::kDLA) {
116117
auto nbDLACores = builder->getNbDLACores();
117118
TRTORCH_CHECK(

docsrc/py_api/trtorch.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ trtorch
1111
Functions
1212
------------
1313

14+
.. autofunction:: set_device
15+
1416
.. autofunction:: compile
1517

1618
.. autofunction:: convert_method_to_trt_engine

py/trtorch/_compiler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,6 @@ def get_build_info() -> str:
156156
build_info = trtorch._C.get_build_info()
157157
build_info = "TRTorch Version: " + str(__version__) + '\n' + build_info
158158
return build_info
159+
160+
def set_device(gpu_id):
    """Set the active CUDA device used by TRTorch.

    Forwards to the native binding ``trtorch._C.set_device`` so that
    subsequent CUDA allocations happen on the selected device.

    Args:
        gpu_id: Zero-indexed id of the CUDA device to make current.
    """
    trtorch._C.set_device(gpu_id)

py/trtorch/csrc/trtorch_py.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ namespace py = pybind11;
1515
namespace trtorch {
1616
namespace pyapi {
1717

18+
// Sets the active CUDA device for subsequent TRTorch operations.
// Thin pybind-facing wrapper; the actual device switch is performed by
// core::set_device (which is expected to wrap cudaSetDevice -- confirm in core).
void set_device(const int device_id) {
  core::set_device(device_id);
}
21+
1822
torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec& info) {
1923
py::gil_scoped_acquire gil;
2024
auto trt_mod = core::CompileGraph(mod, info.toInternalCompileSpec());
@@ -146,6 +150,7 @@ PYBIND11_MODULE(_C, m) {
146150
m.def("_get_is_colored_output_on", &logging::get_is_colored_output_on, "Get if the logging output will be colored");
147151
m.def("_set_is_colored_output_on", &logging::set_is_colored_output_on, "Set if the logging output should be colored");
148152
m.def("_log", &logging::log, "Add a message to the logger");
153+
m.def("set_device", &trtorch::pyapi::set_device, "Set CUDA device id");
149154

150155
py::enum_<core::util::logging::LogLevel>(m, "LogLevel", py::arithmetic())
151156
.value("INTERNAL_ERROR", core::util::logging::LogLevel::kINTERNAL_ERROR)

tests/py/BUILD

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,23 @@ py_test(
2525
]
2626
)
2727

28+
# Following multi_gpu test is only targeted for multi-gpu configurations. It is not included in the test suite by default.
py_test(
    name = "test_api_multi_gpu",
    srcs = [
        "test_api_multi_gpu.py",
        "model_test_case.py"
    ] + select({
        # NOTE(review): pulling test_api_dla.py into the multi-gpu target on
        # aarch64 looks like a copy-paste from the DLA test rule -- confirm
        # whether this source is really needed here.
        ":aarch64_linux": [
            "test_api_dla.py"
        ],
        "//conditions:default" : []
    }),
    deps = [
        requirement("torchvision")
    ]
)
44+
2845
py_test(
2946
name = "test_to_backend_api",
3047
srcs = [

tests/py/test_api_multi_gpu.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import unittest
2+
import trtorch
3+
import torch
4+
import torchvision.models as models
5+
6+
from model_test_case import ModelTestCase
7+
8+
class TestCompile(ModelTestCase):
    """Compile tests pinned to a non-default GPU (requires >1 CUDA device).

    BUGFIX: the original declared `class TestCompile(MultiGpuTestCase)`, but no
    `MultiGpuTestCase` is defined or imported anywhere in this file -- the only
    imported base is `ModelTestCase` (which supplies `parametrize` and
    `self.model` used below), so the original raised NameError at import time.
    """

    def setUp(self):
        # Guard: these tests only make sense when a second GPU exists.
        if not torch.cuda.device_count() > 1:
            raise ValueError("This test case is applicable for multi-gpu configurations only")

        self.gpu_id = 1
        # Setting it up here so that all CUDA allocations are done on correct device
        trtorch.set_device(self.gpu_id)
        # NOTE(review): .to("cuda") targets torch's current device, which
        # trtorch.set_device (cudaSetDevice) may not change -- confirm this
        # actually allocates on gpu 1 rather than "cuda:1" being needed.
        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
        self.traced_model = torch.jit.trace(self.model, [self.input])
        self.scripted_model = torch.jit.script(self.model)

    def test_compile_traced(self):
        """Compile the traced module on gpu 1 and compare against eager output."""
        compile_spec = {
            "input_shapes": [self.input.shape],
            "device": {
                "device_type": trtorch.DeviceType.GPU,
                "gpu_id": self.gpu_id,
                "dla_core": 0,
                "allow_gpu_fallback": False,
                "disable_tf32": False
            }
        }

        trt_mod = trtorch.compile(self.traced_model, compile_spec)
        # Max absolute element-wise difference between TRT and Torch outputs.
        same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
        self.assertTrue(same < 2e-3)

    def test_compile_script(self):
        """Compile the scripted module on gpu 1 and compare against eager output."""
        compile_spec = {
            "input_shapes": [self.input.shape],
            "device": {
                "device_type": trtorch.DeviceType.GPU,
                "gpu_id": self.gpu_id,
                "dla_core": 0,
                "allow_gpu_fallback": False,
                "disable_tf32": False
            }
        }

        trt_mod = trtorch.compile(self.scripted_model, compile_spec)
        same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
        self.assertTrue(same < 2e-3)
52+
53+
54+
55+
def test_suite():
    """Build the suite: TestCompile parametrized with a pretrained resnet18."""
    multi_gpu_suite = unittest.TestSuite()
    resnet = models.resnet18(pretrained=True)
    multi_gpu_suite.addTest(TestCompile.parametrize(TestCompile, model=resnet))
    return multi_gpu_suite


# Run at import/execution time and propagate failure through the exit code
# (0 on success, 1 otherwise) so CI can consume the result.
suite = test_suite()
runner = unittest.TextTestRunner()
result = runner.run(suite)

exit(int(not result.wasSuccessful()))

0 commit comments

Comments
 (0)