Skip to content

Commit 7392106

Browse files
mlazos authored and pytorchmergebot committed
[user-streams] Stash graph created objects in keep_alive list for backwards (pytorch#167705)
Pull Request resolved: pytorch#167705
Approved by: https://github.com/williamwen42
1 parent 01f94d4 commit 7392106

File tree

3 files changed

+35
-3
lines changed

3 files changed

+35
-3
lines changed

torch/_dynamo/graph_bytecode_inputs.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,11 @@ def has_user_objects() -> bool:
2525
return bool(index_to_bytecode_constructor)
2626

2727

28+
def stash_graph_created_object(obj: Any) -> Any:
    """Append ``obj`` to the module-level ``keep_alive`` list and return it.

    The argument is handed back unchanged, so a call to this function can
    wrap an expression that creates an object without altering the value
    the caller receives.
    """
    keep_alive.append(obj)
    return obj
31+
32+
2833
def get_external_object_by_index(index: int) -> Any:
2934
assert index in index_to_external_object_weakref, (
3035
"Index not registered in index_to_user_object_weakref"

torch/_dynamo/output_graph.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2324,12 +2324,14 @@ def specialized_dispatch(*args: Any, **kwargs: Any) -> Any:
23242324
"store_user_object_weakrefs",
23252325
)
23262326
)
2327+
23272328
tmp_vars = []
23282329
for constructor in index_to_bytecode_constructor.values():
23292330
constructor(cg)
23302331
var_name = (
23312332
self.new_var()
2332-
) # keep alive any temp objects for the rest of the frame
2333+
) # keep alive any user objects for the rest of the frame
2334+
# TODO: we could omit this for objects we create but shouldn't be too much overhead for now
23332335
cg.store(var_name)
23342336
tmp_vars.append(var_name)
23352337

torch/_dynamo/variables/streams.py

Lines changed: 27 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -52,10 +52,21 @@ def new_stream(*args: tuple[Any], **kwargs: Any) -> int:
5252
)
5353

5454

55+
def _codegen_current_stream(device: torch.device, cg: "PyCodegen") -> None:
    """Emit bytecode that loads the current stream for ``device`` and stashes it.

    The emitted sequence is, in order:
      1. a load of ``stash_graph_created_object`` (through ``add_push_null``),
      2. the value produced by ``CurrentStreamSource(device)``,
      3. a call with one positional argument, i.e. the stash function
         applied to the stream.
    """
    stash_module = torch._dynamo.graph_bytecode_inputs.__name__  # type: ignore[implicit-imports]
    cg.add_push_null(
        lambda: cg.load_import_from(stash_module, "stash_graph_created_object")
    )

    # Push the single argument for the stash call.
    stream_source = CurrentStreamSource(device)
    cg(stream_source)

    # Call stash_graph_created_object(<current stream>).
    call_instructions = create_call_function(1, False)
    cg.extend_output(call_instructions)
64+
65+
5566
def get_current_stream(device: torch.device) -> int:
    """Register the accelerator's current stream for ``device`` as a graph-created object.

    The stream is looked up with ``torch.accelerator.current_stream(device)``
    and passed to ``register_graph_created_object`` together with a codegen
    callback that delegates to ``_codegen_current_stream``. The value
    returned by ``register_graph_created_object`` is returned as-is.
    """
    current = torch.accelerator.current_stream(device)

    def emit_stream(_: int, cg: "PyCodegen") -> None:
        # The first argument is unused; only the codegen object matters here.
        _codegen_current_stream(device, cg)

    return register_graph_created_object(current, emit_stream)
6071

6172

@@ -362,6 +373,12 @@ def make_construct_in_graph_stream_fn(
362373
args: TupleVariable, kwargs: ConstDictVariable
363374
) -> Callable[[int, "PyCodegen"], None]:
364375
def fn(index: int, codegen: "PyCodegen") -> None:
376+
codegen.add_push_null(
377+
lambda: codegen.load_import_from(
378+
torch._dynamo.graph_bytecode_inputs.__name__, # type: ignore[implicit-imports]
379+
"stash_graph_created_object",
380+
)
381+
)
365382
codegen.add_push_null(
366383
lambda: codegen.load_import_from(
367384
torch._dynamo.utils.__name__, "build_stream"
@@ -370,6 +387,7 @@ def fn(index: int, codegen: "PyCodegen") -> None:
370387
codegen(args)
371388
codegen(kwargs)
372389
codegen.extend_output(create_call_function(2, False))
390+
codegen.extend_output(create_call_function(1, False))
373391

374392
return fn
375393

@@ -473,6 +491,12 @@ def make_construct_in_graph_event_fn(
473491
args: TupleVariable, kwargs: ConstDictVariable
474492
) -> Callable[[int, "PyCodegen"], None]:
475493
def fn(index: int, codegen: "PyCodegen") -> None:
494+
codegen.add_push_null(
495+
lambda: codegen.load_import_from(
496+
torch._dynamo.graph_bytecode_inputs.__name__, # type: ignore[implicit-imports]
497+
"stash_graph_created_object",
498+
)
499+
)
476500
codegen.add_push_null(
477501
lambda: codegen.load_import_from(
478502
torch._dynamo.utils.__name__, "build_event"
@@ -481,6 +505,7 @@ def fn(index: int, codegen: "PyCodegen") -> None:
481505
codegen(args)
482506
codegen(kwargs)
483507
codegen.extend_output(create_call_function(2, False))
508+
codegen.extend_output(create_call_function(1, False))
484509

485510
return fn
486511

0 commit comments

Comments
 (0)