Skip to content

Commit 46b3f91

Browse files
mlazospytorchmergebot
authored andcommitted
[user-streams] Add record/wait ops (pytorch#167151)
Pull Request resolved: pytorch#167151 Approved by: https://github.com/Lucaskabela ghstack dependencies: pytorch#167141
1 parent f7b7f40 commit 46b3f91

File tree

2 files changed

+73
-11
lines changed

2 files changed

+73
-11
lines changed

test/dynamo/test_streams.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
441441
)
442442

443443
@requires_cuda
444-
def test_run_opcheck(self):
444+
def test_run_opcheck_fork_join(self):
445445
from torch._dynamo.variables.streams import fork_stream, join_stream
446446
from torch.library import opcheck
447447

@@ -462,6 +462,30 @@ def test_run_opcheck(self):
462462
torch.accelerator.set_stream(original_stream)
463463
reset_user_object_tracking()
464464

465+
@requires_cuda
466+
def test_run_opcheck_wait_record(self):
467+
from torch._dynamo.variables.streams import record_event, wait_event
468+
from torch.library import opcheck
469+
470+
original_stream = torch.accelerator.current_stream()
471+
try:
472+
s0 = torch.Stream()
473+
s1 = torch.Stream()
474+
e0 = torch.Event()
475+
e1 = torch.Event()
476+
store_user_object_weakrefs(s0, s1, e0, e1)
477+
478+
sample_inputs = [
479+
(2, 0),
480+
(3, 1),
481+
]
482+
for args in sample_inputs:
483+
opcheck(wait_event, args)
484+
opcheck(record_event, args)
485+
finally:
486+
torch.accelerator.set_stream(original_stream)
487+
reset_user_object_tracking()
488+
465489

466490
if __name__ == "__main__":
467491
from torch._dynamo.test_case import run_tests

torch/_dynamo/variables/streams.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,28 @@
2828
Tensor = torch.Tensor
2929

3030

31+
def _get_stream_by_index(index: int) -> torch.Stream:
32+
stream = get_external_object_by_index(index)
33+
assert isinstance(stream, torch.Stream), (
34+
f"Fork/join stream expected a stream object at index {index}"
35+
)
36+
return stream
37+
38+
39+
def _get_event_by_index(index: int) -> torch.Event:
40+
event = get_external_object_by_index(index)
41+
assert isinstance(event, torch.Event), (
42+
f"Record/wait event expected an event object at index {index}"
43+
)
44+
return event
45+
46+
3147
@custom_op("streams::fork", mutates_args=())
3248
def fork_stream(
3349
from_index: int, # kept to make stream transitions clearer
3450
to_index: int,
3551
) -> None:
36-
stream = get_external_object_by_index(to_index)
37-
assert isinstance(stream, torch.Stream), (
38-
f"fork_stream expects a stream object at index {to_index}"
39-
)
40-
torch.accelerator.set_stream(stream)
52+
torch.accelerator.set_stream(_get_stream_by_index(to_index))
4153

4254

4355
@fork_stream.register_fake
@@ -50,11 +62,7 @@ def _(
5062

5163
@custom_op("streams::join", mutates_args=())
5264
def join_stream(from_index: int, to_index: int) -> None:
53-
stream = get_external_object_by_index(to_index)
54-
assert isinstance(stream, torch.Stream), (
55-
f"join_stream expects a stream object at index {to_index}"
56-
)
57-
torch.accelerator.set_stream(stream)
65+
torch.accelerator.set_stream(_get_stream_by_index(to_index))
5866

5967

6068
@join_stream.register_fake
@@ -65,6 +73,36 @@ def _(
6573
pass
6674

6775

76+
@custom_op("streams::record_event", mutates_args=())
77+
def record_event(event_index: int, stream_index: int) -> None:
78+
event = _get_event_by_index(event_index)
79+
stream = _get_stream_by_index(stream_index)
80+
stream.record_event(event)
81+
82+
83+
@record_event.register_fake
84+
def _(
85+
event_index: int,
86+
stream_index: int,
87+
) -> None:
88+
pass
89+
90+
91+
@custom_op("streams::wait_event", mutates_args=())
92+
def wait_event(event_index: int, stream_index: int) -> None:
93+
event = _get_event_by_index(event_index)
94+
stream = _get_stream_by_index(stream_index)
95+
stream.wait_event(event)
96+
97+
98+
@wait_event.register_fake
99+
def _(
100+
event_index: int,
101+
stream_index: int,
102+
) -> None:
103+
pass
104+
105+
68106
class SymbolicStreamState:
69107
"""Track the currently entered stream if any"""
70108

0 commit comments

Comments
 (0)