Skip to content

Commit 39f5e0e

Browse files
mlazos
authored and pytorchmergebot committed
[user-streams] Move user object bytecode generation after calling user compiler (pytorch#167704)
This move needs to occur in order to allow AOTAutograd to indicate if more streams/events need to be created for the backward.

Pull Request resolved: pytorch#167704
Approved by: https://github.com/anijain2305
ghstack dependencies: pytorch#167513
1 parent 6eb71ce commit 39f5e0e

File tree

1 file changed

+28
-32
lines changed

1 file changed

+28
-32
lines changed

torch/_dynamo/output_graph.py

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,37 +1528,6 @@ def compile_subgraph(
15281528

15291529
from .decorators import disable
15301530

1531-
if has_user_objects():
1532-
# NB: This is where we store possible user objects before running the graph
1533-
# index_to_user_object_weakref is the function used in the graph to translate
1534-
# the dynamo-generated index into the actual object passed to the compiled function.
1535-
# We generate bytecode to store all user objects at the proper index in the below
1536-
# call.
1537-
codegen = PyCodegen(
1538-
self.root_tx, root, overridden_sources=overridden_sources
1539-
)
1540-
codegen.add_push_null(
1541-
lambda: codegen.load_import_from(
1542-
torch._dynamo.graph_bytecode_inputs.__name__,
1543-
"store_user_object_weakrefs",
1544-
)
1545-
)
1546-
tmp_vars = []
1547-
for constructor in index_to_bytecode_constructor.values():
1548-
constructor(codegen)
1549-
var_name = (
1550-
self.new_var()
1551-
) # keep alive any temp objects for the rest of the frame
1552-
codegen.store(var_name)
1553-
tmp_vars.append(var_name)
1554-
1555-
for var_name in tmp_vars:
1556-
codegen.append_output(codegen.create_load(var_name))
1557-
1558-
codegen.call_function(len(index_to_bytecode_constructor), False)
1559-
codegen.pop_top()
1560-
self.add_output_instructions(codegen.get_instructions())
1561-
15621531
# to handle random calls
15631532
if len(self.random_calls) > 0:
15641533
random_calls_instructions = []
@@ -2343,6 +2312,33 @@ def specialized_dispatch(*args: Any, **kwargs: Any) -> Any:
23432312
assert self.root_tx is not None
23442313
cg = PyCodegen(self.root_tx)
23452314

2315+
if has_user_objects():
2316+
# NB: This is where we store possible user objects before running the graph
2317+
# index_to_user_object_weakref is the function used in the graph to translate
2318+
# the dynamo-generated index into the actual object passed to the compiled function.
2319+
# We generate bytecode to store all user objects at the proper index in the below
2320+
# call.
2321+
cg.add_push_null(
2322+
lambda: cg.load_import_from(
2323+
torch._dynamo.graph_bytecode_inputs.__name__,
2324+
"store_user_object_weakrefs",
2325+
)
2326+
)
2327+
tmp_vars = []
2328+
for constructor in index_to_bytecode_constructor.values():
2329+
constructor(cg)
2330+
var_name = (
2331+
self.new_var()
2332+
) # keep alive any temp objects for the rest of the frame
2333+
cg.store(var_name)
2334+
tmp_vars.append(var_name)
2335+
2336+
for var_name in tmp_vars:
2337+
cg.append_output(cg.create_load(var_name))
2338+
2339+
cg.call_function(len(index_to_bytecode_constructor), False)
2340+
cg.pop_top()
2341+
23462342
for idx, arg in enumerate(self.graphargs):
23472343
self.export_metadata.graph_input_idx_to_local_source[idx] = arg.source
23482344

@@ -3011,7 +3007,7 @@ def __init__(
30113007

30123008
self.tracked_tensor_or_symint_vt: OrderedSet[VariableTracker] = OrderedSet()
30133009

3014-
def record_tensor_or_symint_vt(self, vt):
3010+
def record_tensor_or_symint_vt(self, vt: VariableTracker):
30153011
self.tracked_tensor_or_symint_vt.add(vt)
30163012

30173013
# preserve original meta if it is available

0 commit comments

Comments (0)