Skip to content

Commit f7b7f40

Browse files
mlazospytorchmergebot
authored andcommitted
[user-streams] Enable stream ops to work in eager (pytorch#167141)
Pull Request resolved: pytorch#167141 Approved by: https://github.com/Lucaskabela
1 parent 91337ae commit f7b7f40

File tree

2 files changed

+34
-23
lines changed

2 files changed

+34
-23
lines changed

test/dynamo/test_streams.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
import torch
88
import torch._dynamo.test_case
99
import torch._dynamo.testing
10+
from torch._dynamo.graph_bytecode_inputs import (
11+
reset_user_object_tracking,
12+
store_user_object_weakrefs,
13+
)
1014
from torch._dynamo.testing import extract_graph, remove_trailing_space
1115
from torch.testing._internal.common_cuda import TEST_MULTIGPU
1216
from torch.testing._internal.common_utils import requires_cuda
@@ -441,13 +445,22 @@ def test_run_opcheck(self):
441445
from torch._dynamo.variables.streams import fork_stream, join_stream
442446
from torch.library import opcheck
443447

444-
sample_inputs = [
445-
(0, torch.device("cuda:0"), 1, torch.device("cuda:1")),
446-
(2, torch.device("cuda:2"), 3, torch.device("cuda:1")),
447-
]
448-
for args in sample_inputs:
449-
opcheck(fork_stream, args)
450-
opcheck(join_stream, args)
448+
original_stream = torch.accelerator.current_stream()
449+
try:
450+
s0 = torch.Stream()
451+
s1 = torch.Stream()
452+
store_user_object_weakrefs(s0, s1)
453+
454+
sample_inputs = [
455+
(0, 1),
456+
(1, 0),
457+
]
458+
for args in sample_inputs:
459+
opcheck(fork_stream, args)
460+
opcheck(join_stream, args)
461+
finally:
462+
torch.accelerator.set_stream(original_stream)
463+
reset_user_object_tracking()
451464

452465

453466
if __name__ == "__main__":

torch/_dynamo/variables/streams.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .. import graph_break_hints
1111
from ..bytecode_transformation import create_call_function
1212
from ..exc import TYPE_CHECKING, unimplemented_v2
13+
from ..graph_bytecode_inputs import get_external_object_by_index
1314
from .base import VariableTracker
1415
from .constant import ConstantVariable
1516
from .ctx_manager import FxTracebackAnnotateVariable
@@ -29,40 +30,37 @@
2930

3031
@custom_op("streams::fork", mutates_args=())
3132
def fork_stream(
32-
from_index: int,
33-
from_device: torch.device,
33+
from_index: int, # kept to make stream transitions clearer
3434
to_index: int,
35-
to_device: torch.device,
3635
) -> None:
37-
pass
36+
stream = get_external_object_by_index(to_index)
37+
assert isinstance(stream, torch.Stream), (
38+
f"fork_stream expects a stream object at index {to_index}"
39+
)
40+
torch.accelerator.set_stream(stream)
3841

3942

4043
@fork_stream.register_fake
4144
def _(
42-
from_index: int,
43-
from_device: torch.device,
45+
from_index: int, # kept to make stream transitions clearer
4446
to_index: int,
45-
to_device: torch.device,
4647
) -> None:
4748
pass
4849

4950

5051
@custom_op("streams::join", mutates_args=())
51-
def join_stream(
52-
from_index: int,
53-
from_device: torch.device,
54-
to_index: int,
55-
to_device: torch.device,
56-
) -> None:
57-
pass
52+
def join_stream(from_index: int, to_index: int) -> None:
53+
stream = get_external_object_by_index(to_index)
54+
assert isinstance(stream, torch.Stream), (
55+
f"join_stream expects a stream object at index {to_index}"
56+
)
57+
torch.accelerator.set_stream(stream)
5858

5959

6060
@join_stream.register_fake
6161
def _(
6262
from_index: int,
63-
from_device: torch.device,
6463
to_index: int,
65-
to_device: torch.device,
6664
) -> None:
6765
pass
6866

0 commit comments

Comments
 (0)