@@ -52,10 +52,21 @@ def new_stream(*args: tuple[Any], **kwargs: Any) -> int:
5252 )
5353
5454
55+ def _codegen_current_stream (device : torch .device , cg : "PyCodegen" ) -> None :
56+ cg .add_push_null (
57+ lambda : cg .load_import_from (
58+ torch ._dynamo .graph_bytecode_inputs .__name__ , # type: ignore[implicit-imports]
59+ "stash_graph_created_object" ,
60+ )
61+ )
62+ cg (CurrentStreamSource (device ))
63+ cg .extend_output (create_call_function (1 , False ))
64+
65+
5566def get_current_stream (device : torch .device ) -> int :
56- stream = torch .accelerator .current_stream ()
67+ stream = torch .accelerator .current_stream (device )
5768 return register_graph_created_object (
58- stream , lambda _ , cg : cg ( CurrentStreamSource ( device ) )
69+ stream , lambda _ , cg : _codegen_current_stream ( device , cg )
5970 )
6071
6172
@@ -362,6 +373,12 @@ def make_construct_in_graph_stream_fn(
362373 args : TupleVariable , kwargs : ConstDictVariable
363374 ) -> Callable [[int , "PyCodegen" ], None ]:
364375 def fn (index : int , codegen : "PyCodegen" ) -> None :
376+ codegen .add_push_null (
377+ lambda : codegen .load_import_from (
378+ torch ._dynamo .graph_bytecode_inputs .__name__ , # type: ignore[implicit-imports]
379+ "stash_graph_created_object" ,
380+ )
381+ )
365382 codegen .add_push_null (
366383 lambda : codegen .load_import_from (
367384 torch ._dynamo .utils .__name__ , "build_stream"
@@ -370,6 +387,7 @@ def fn(index: int, codegen: "PyCodegen") -> None:
370387 codegen (args )
371388 codegen (kwargs )
372389 codegen .extend_output (create_call_function (2 , False ))
390+ codegen .extend_output (create_call_function (1 , False ))
373391
374392 return fn
375393
@@ -473,6 +491,12 @@ def make_construct_in_graph_event_fn(
473491 args : TupleVariable , kwargs : ConstDictVariable
474492 ) -> Callable [[int , "PyCodegen" ], None ]:
475493 def fn (index : int , codegen : "PyCodegen" ) -> None :
494+ codegen .add_push_null (
495+ lambda : codegen .load_import_from (
496+ torch ._dynamo .graph_bytecode_inputs .__name__ , # type: ignore[implicit-imports]
497+ "stash_graph_created_object" ,
498+ )
499+ )
476500 codegen .add_push_null (
477501 lambda : codegen .load_import_from (
478502 torch ._dynamo .utils .__name__ , "build_event"
@@ -481,6 +505,7 @@ def fn(index: int, codegen: "PyCodegen") -> None:
481505 codegen (args )
482506 codegen (kwargs )
483507 codegen .extend_output (create_call_function (2 , False ))
508+ codegen .extend_output (create_call_function (1 , False ))
484509
485510 return fn
486511
0 commit comments