Skip to content

Commit 7b423c2

Browse files
mlazospytorchmergebot
authored andcommitted
[user-streams] Mark stream ops as side effectful (pytorch#167152)
Pull Request resolved: pytorch#167152 Approved by: https://github.com/Lucaskabela ghstack dependencies: pytorch#167141, pytorch#167151
1 parent 46b3f91 commit 7b423c2

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

test/dynamo/test_streams.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,22 @@ def test_run_opcheck_wait_record(self):
486486
torch.accelerator.set_stream(original_stream)
487487
reset_user_object_tracking()
488488

489+
def test_is_marked_side_effectful(self):
490+
self.assertIn(
491+
torch.ops.streams.fork.default, torch.fx.node._side_effectful_functions
492+
)
493+
self.assertIn(
494+
torch.ops.streams.join.default, torch.fx.node._side_effectful_functions
495+
)
496+
self.assertIn(
497+
torch.ops.streams.wait_event.default,
498+
torch.fx.node._side_effectful_functions,
499+
)
500+
self.assertIn(
501+
torch.ops.streams.record_event.default,
502+
torch.fx.node._side_effectful_functions,
503+
)
504+
489505

490506
if __name__ == "__main__":
491507
from torch._dynamo.test_case import run_tests

torch/_dynamo/variables/streams.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
from torch._dynamo.variables.dicts import ConstDictVariable
77
from torch._dynamo.variables.lists import TupleVariable
8-
from torch.fx import Proxy
8+
from torch.fx import has_side_effect, Proxy
99

1010
from .. import graph_break_hints
1111
from ..bytecode_transformation import create_call_function
@@ -60,6 +60,9 @@ def _(
6060
pass
6161

6262

63+
has_side_effect(torch.ops.streams.fork.default)
64+
65+
6366
@custom_op("streams::join", mutates_args=())
6467
def join_stream(from_index: int, to_index: int) -> None:
6568
torch.accelerator.set_stream(_get_stream_by_index(to_index))
@@ -73,6 +76,9 @@ def _(
7376
pass
7477

7578

79+
has_side_effect(torch.ops.streams.join.default)
80+
81+
7682
@custom_op("streams::record_event", mutates_args=())
7783
def record_event(event_index: int, stream_index: int) -> None:
7884
event = _get_event_by_index(event_index)
@@ -88,6 +94,9 @@ def _(
8894
pass
8995

9096

97+
has_side_effect(torch.ops.streams.record_event.default)
98+
99+
91100
@custom_op("streams::wait_event", mutates_args=())
92101
def wait_event(event_index: int, stream_index: int) -> None:
93102
event = _get_event_by_index(event_index)
@@ -103,6 +112,9 @@ def _(
103112
pass
104113

105114

115+
has_side_effect(torch.ops.streams.wait_event.default)
116+
117+
106118
class SymbolicStreamState:
107119
"""Track the currently entered stream if any"""
108120

0 commit comments

Comments
 (0)