Skip to content

Commit e8d887a

Browse files
mlazospytorchmergebot
authored andcommitted
[user-streams] Support streams as contexts (pytorch#164507)
Pull Request resolved: pytorch#164507 Approved by: https://github.com/williamwen42 ghstack dependencies: pytorch#162903, pytorch#164343, pytorch#164344
1 parent 774abb0 commit e8d887a

File tree

2 files changed

+26
-7
lines changed

2 files changed

+26
-7
lines changed

test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_streams.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66

77
class TestStream(TestCase):
8+
@skipIfTorchDynamo()
89
def test_stream_create(self):
910
stream = torch.Stream(device="openreg")
1011
self.assertEqual(stream.device_index, torch.openreg.current_device())
@@ -24,6 +25,7 @@ def test_stream_create(self):
2425
)
2526
self.assertEqual(stream, stream1)
2627

28+
@skipIfTorchDynamo()
2729
def test_stream_context(self):
2830
with torch.Stream(device="openreg:1") as stream:
2931
self.assertEqual(torch.accelerator.current_stream(), stream)
@@ -40,6 +42,7 @@ def test_stream_switch(self):
4042
current_stream = torch.accelerator.current_stream()
4143
self.assertEqual(current_stream, stream2)
4244

45+
@skipIfTorchDynamo()
4346
def test_stream_synchronize(self):
4447
stream = torch.Stream(device="openreg:1")
4548
self.assertEqual(True, stream.query())
@@ -49,12 +52,14 @@ def test_stream_synchronize(self):
4952
stream.synchronize()
5053
self.assertEqual(True, stream.query())
5154

55+
@skipIfTorchDynamo()
5256
def test_stream_repr(self):
5357
stream = torch.Stream(device="openreg:1")
5458
self.assertTrue(
5559
"torch.Stream device_type=openreg, device_index=1" in repr(stream)
5660
)
5761

62+
@skipIfTorchDynamo()
5863
def test_stream_wait_stream(self):
5964
stream_1 = torch.Stream(device="openreg:0")
6065
stream_2 = torch.Stream(device="openreg:1")

torch/_dynamo/variables/streams.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ def __init__(
9292
target_values=target_values, initial_values=initial_values, **kwargs
9393
)
9494
self.device = device
95-
self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id
9695

9796
def enter(self, tx: "InstructionTranslator") -> "VariableTracker":
9897
# to stream, from stream is the order of the arguments
@@ -124,7 +123,7 @@ def _initial_stream_proxies(self) -> tuple[Proxy, Proxy]:
124123

125124
def _target_stream_proxies(self) -> tuple[Proxy, Proxy]:
126125
return StreamContextVariable._extract_stream_properties(
127-
self.target_values[0].as_proxy()
126+
self._get_target_values()[0].as_proxy()
128127
)
129128

130129
@staticmethod
@@ -152,6 +151,15 @@ def _get_current_stream(
152151
)
153152
return current_stream
154153

154+
def _get_target_values(self) -> list["StreamVariable"]:
155+
# We need this to be overridable, since StreamVariable does
156+
# not store target values (it does not require any arguments)
157+
# and captures the current stream at the time of entering the context
158+
return self.target_values
159+
160+
def supports_graph_breaks(self) -> bool:
161+
return True
162+
155163

156164
class StreamVariable(StreamContextVariable):
157165
"""Represents the device-agnostic torch.Stream class"""
@@ -168,9 +176,7 @@ def __init__(
168176
assert value.device.type == device.type, (
169177
"stream value is not equal to the passed device"
170178
)
171-
super().__init__(
172-
target_values=[self], initial_values=None, device=device, **kwargs
173-
)
179+
super().__init__(target_values=[], initial_values=None, device=device, **kwargs)
174180
self.proxy = proxy
175181
self.value = value
176182
# pyrefly: ignore [read-only]
@@ -233,18 +239,23 @@ def call_method(
233239
return super().call_method(tx, name, args, kwargs)
234240

235241
def enter(self, tx: "InstructionTranslator") -> "VariableTracker":
236-
# NB: Set initial values and target values when we enter
242+
# NB: Set initial values when we enter
237243
# Don't do this at object creation, as we need to record the current stream
238244
# at the time the context is entered.
239245
self.initial_values = [
240246
StreamContextVariable._get_current_stream(self.device, tx)
241247
]
242-
self.target_values = [self]
243248
return super().enter(tx)
244249

245250
def as_proxy(self) -> Proxy:
246251
return self.proxy
247252

253+
def module_name(self) -> str:
254+
return "torch._C"
255+
256+
def fn_name(self) -> str:
257+
return "Stream"
258+
248259
def reconstruct(self, codegen: "PyCodegen") -> None:
249260
# If we got here, this stream is fully subsumed by the graph - this means it is
250261
# not an input or global
@@ -259,6 +270,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None:
259270
name = codegen.tx.output.install_global_by_id(prefix, self.value)
260271
codegen.append_output(codegen.create_load_global(name, add=True))
261272

273+
def _get_target_values(self) -> list["StreamVariable"]:
274+
return [self]
275+
262276

263277
class EventVariable(VariableTracker):
264278
def __init__(self, proxy: Proxy, value: torch.Event, **kwargs: Any) -> None:

0 commit comments

Comments
 (0)