@@ -92,7 +92,6 @@ def __init__(
9292 target_values = target_values , initial_values = initial_values , ** kwargs
9393 )
9494 self .device = device
95- self .set_stream_id = get_interface_for_device (self .device )._set_stream_by_id
9695
9796 def enter (self , tx : "InstructionTranslator" ) -> "VariableTracker" :
9897 # to stream, from stream is the order of the arguments
@@ -124,7 +123,7 @@ def _initial_stream_proxies(self) -> tuple[Proxy, Proxy]:
124123
125124 def _target_stream_proxies (self ) -> tuple [Proxy , Proxy ]:
126125 return StreamContextVariable ._extract_stream_properties (
127- self .target_values [0 ].as_proxy ()
126+ self ._get_target_values () [0 ].as_proxy ()
128127 )
129128
130129 @staticmethod
@@ -152,6 +151,15 @@ def _get_current_stream(
152151 )
153152 return current_stream
154153
154+ def _get_target_values (self ) -> list ["StreamVariable" ]:
155+ # We need this to be overridable, since StreamVariable does
156+ # not store target values (it does not require any arguments)
157+ # and captures the current stream at the time of entering the context
158+ return self .target_values
159+
160+ def supports_graph_breaks (self ) -> bool :
161+ return True
162+
155163
156164class StreamVariable (StreamContextVariable ):
157165 """Represents the device-agnostic torch.Stream class"""
@@ -168,9 +176,7 @@ def __init__(
168176 assert value .device .type == device .type , (
169177 "stream value is not equal to the passed device"
170178 )
171- super ().__init__ (
172- target_values = [self ], initial_values = None , device = device , ** kwargs
173- )
179+ super ().__init__ (target_values = [], initial_values = None , device = device , ** kwargs )
174180 self .proxy = proxy
175181 self .value = value
176182 # pyrefly: ignore [read-only]
@@ -233,18 +239,23 @@ def call_method(
233239 return super ().call_method (tx , name , args , kwargs )
234240
235241 def enter (self , tx : "InstructionTranslator" ) -> "VariableTracker" :
236- # NB: Set initial values and target values when we enter
242+ # NB: Set initial values when we enter
237243 # Don't do this at object creation, as we need to record the current stream
238244 # at the time the context is entered.
239245 self .initial_values = [
240246 StreamContextVariable ._get_current_stream (self .device , tx )
241247 ]
242- self .target_values = [self ]
243248 return super ().enter (tx )
244249
245250 def as_proxy (self ) -> Proxy :
246251 return self .proxy
247252
253+ def module_name (self ) -> str :
254+ return "torch._C"
255+
256+ def fn_name (self ) -> str :
257+ return "Stream"
258+
248259 def reconstruct (self , codegen : "PyCodegen" ) -> None :
249260 # If we got here, this stream is fully subsumed by the graph - this means it is
250261 # not an input or global
@@ -259,6 +270,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None:
259270 name = codegen .tx .output .install_global_by_id (prefix , self .value )
260271 codegen .append_output (codegen .create_load_global (name , add = True ))
261272
273+ def _get_target_values (self ) -> list ["StreamVariable" ]:
274+ return [self ]
275+
262276
263277class EventVariable (VariableTracker ):
264278 def __init__ (self , proxy : Proxy , value : torch .Event , ** kwargs : Any ) -> None :
0 commit comments