Skip to content

Commit c1ab5db

Browse files
committed
tests(multi_gpu): Refactor the multi gpu test and remove redundant code
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent ac9f2d4 commit c1ab5db

File tree

3 files changed

+24
-45
lines changed

3 files changed

+24
-45
lines changed

tests/py/BUILD

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ py_test(
1515
"test_api.py",
1616
"model_test_case.py"
1717
] + select({
18-
":aarch64_linux": [
19-
"test_api_dla.py"
20-
],
18+
":aarch64_linux": [
19+
"test_api_dla.py"
20+
],
2121
"//conditions:default" : []
2222
}),
2323
deps = [
@@ -27,14 +27,11 @@ py_test(
2727

2828
# Following multi_gpu test is only targeted for multi-gpu configurations. It is not included in the test suite by default.
2929
py_test(
30-
name = "test_api_multi_gpu",
30+
name = "test_multi_gpu",
3131
srcs = [
32-
"test_api_multi_gpu.py",
32+
"test_multi_gpu.py",
3333
"model_test_case.py"
34-
] + select({
35-
":aarch64_linux": [
36-
"test_api_dla.py"
37-
],
34+
],
3835
"//conditions:default" : []
3936
}),
4037
deps = [

tests/py/multi_gpu_test_case.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

tests/py/test_api_multi_gpu.py renamed to tests/py/test_multi_gpu.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,61 +3,64 @@
33
import torch
44
import torchvision.models as models
55

6-
from multi_gpu_test_case import MultiGpuTestCase
7-
8-
gpu_id = 1
9-
class TestCompile(MultiGpuTestCase):
6+
from model_test_case import ModelTestCase
107

8+
class TestMultiGpuSwitching(ModelTestCase):
119
def setUp(self):
12-
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
10+
if torch.cuda.device_count() < 2:
11+
self.fail("Test is not relevant for this platform since number of available CUDA devices is less than 2")
12+
13+
trtorch.set_device(0)
14+
self.target_gpu = 1
15+
self.input = torch.randn((1, 3, 224, 224)).to("cuda:1")
16+
self.model = self.model.to("cuda:1")
1317
self.traced_model = torch.jit.trace(self.model, [self.input])
1418
self.scripted_model = torch.jit.script(self.model)
1519

1620
def test_compile_traced(self):
21+
trtorch.set_device(0)
1722
compile_spec = {
1823
"input_shapes": [self.input.shape],
1924
"device": {
2025
"device_type": trtorch.DeviceType.GPU,
21-
"gpu_id": gpu_id,
26+
"gpu_id": self.target_gpu,
2227
"dla_core": 0,
2328
"allow_gpu_fallback": False,
2429
"disable_tf32": False
2530
}
2631
}
2732

2833
trt_mod = trtorch.compile(self.traced_model, compile_spec)
34+
trtorch.set_device(self.target_gpu)
2935
same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
36+
trtorch.set_device(0)
3037
self.assertTrue(same < 2e-3)
3138

3239
def test_compile_script(self):
40+
trtorch.set_device(0)
3341
compile_spec = {
3442
"input_shapes": [self.input.shape],
3543
"device": {
3644
"device_type": trtorch.DeviceType.GPU,
37-
"gpu_id": gpu_id,
45+
"gpu_id": self.target_gpu,
3846
"dla_core": 0,
3947
"allow_gpu_fallback": False,
4048
"disable_tf32": False
4149
}
4250
}
4351

4452
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
53+
trtorch.set_device(self.target_gpu)
4554
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
55+
trtorch.set_device(0)
4656
self.assertTrue(same < 2e-3)
4757

48-
49-
5058
def test_suite():
5159
suite = unittest.TestSuite()
52-
suite.addTest(TestCompile.parametrize(TestCompile, model=models.resnet18(pretrained=True)))
60+
suite.addTest(TestMultiGpuSwitching.parametrize(TestMultiGpuSwitching, model=models.resnet18(pretrained=True)))
5361

5462
return suite
5563

56-
if not torch.cuda.device_count() > 1:
57-
raise ValueError("This test case is applicable for multi-gpu configurations only")
58-
59-
# Setting it up here so that all CUDA allocations are done on correct device
60-
trtorch.set_device(gpu_id)
6164
suite = test_suite()
6265

6366
runner = unittest.TextTestRunner()

0 commit comments

Comments (0)