Skip to content

Commit 23669d0

Browse files
mlazospytorchmergebot
authored andcommitted
[user-cuda-streams] Add cuda streams test suite (pytorch#162901)
Pull Request resolved: pytorch#162901 Approved by: https://github.com/williamwen42 ghstack dependencies: pytorch#162903, pytorch#164343, pytorch#164344, pytorch#164507
1 parent e8d887a commit 23669d0

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

test/dynamo/test_streams.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Owner(s): ["module: dynamo"]
2+
3+
import torch
4+
import torch._dynamo.test_case
5+
import torch._dynamo.testing
6+
from torch.testing._internal.common_utils import requires_cuda
7+
8+
9+
class TestStreams(torch._dynamo.test_case.TestCase):
10+
@classmethod
11+
def setUpClass(cls):
12+
super().setUpClass()
13+
14+
@classmethod
15+
def tearDownClass(cls):
16+
super().tearDownClass()
17+
18+
@requires_cuda
19+
def test_run_opcheck(self):
20+
from torch._dynamo.variables.streams import fork_stream, join_stream
21+
from torch.library import opcheck
22+
23+
sample_inputs = [
24+
(0, torch.device("cuda:0"), 1, torch.device("cuda:1")),
25+
(2, torch.device("cuda:2"), 3, torch.device("cuda:1")),
26+
]
27+
for args in sample_inputs:
28+
opcheck(fork_stream, args)
29+
opcheck(join_stream, args)
30+
31+
32+
if __name__ == "__main__":
33+
from torch._dynamo.test_case import run_tests
34+
35+
run_tests()

0 commit comments

Comments
 (0)