@@ -1528,37 +1528,6 @@ def compile_subgraph(
15281528
15291529 from .decorators import disable
15301530
1531- if has_user_objects ():
1532- # NB: This is where we store possible user objects before running the graph
1533- # index_to_user_object_weakref is the function used in the graph to translate
1534- # the dynamo-generated index into the actual object passed to the compiled function.
1535- # We generate bytecode to store all user objects at the proper index in the below
1536- # call.
1537- codegen = PyCodegen (
1538- self .root_tx , root , overridden_sources = overridden_sources
1539- )
1540- codegen .add_push_null (
1541- lambda : codegen .load_import_from (
1542- torch ._dynamo .graph_bytecode_inputs .__name__ ,
1543- "store_user_object_weakrefs" ,
1544- )
1545- )
1546- tmp_vars = []
1547- for constructor in index_to_bytecode_constructor .values ():
1548- constructor (codegen )
1549- var_name = (
1550- self .new_var ()
1551- ) # keep alive any temp objects for the rest of the frame
1552- codegen .store (var_name )
1553- tmp_vars .append (var_name )
1554-
1555- for var_name in tmp_vars :
1556- codegen .append_output (codegen .create_load (var_name ))
1557-
1558- codegen .call_function (len (index_to_bytecode_constructor ), False )
1559- codegen .pop_top ()
1560- self .add_output_instructions (codegen .get_instructions ())
1561-
15621531 # to handle random calls
15631532 if len (self .random_calls ) > 0 :
15641533 random_calls_instructions = []
@@ -2343,6 +2312,33 @@ def specialized_dispatch(*args: Any, **kwargs: Any) -> Any:
23432312 assert self .root_tx is not None
23442313 cg = PyCodegen (self .root_tx )
23452314
2315+ if has_user_objects ():
2316+ # NB: This is where we store possible user objects before running the graph
2317+ # index_to_user_object_weakref is the function used in the graph to translate
2318+ # the dynamo-generated index into the actual object passed to the compiled function.
2319+ # We generate bytecode to store all user objects at the proper index in the below
2320+ # call.
2321+ cg .add_push_null (
2322+ lambda : cg .load_import_from (
2323+ torch ._dynamo .graph_bytecode_inputs .__name__ ,
2324+ "store_user_object_weakrefs" ,
2325+ )
2326+ )
2327+ tmp_vars = []
2328+ for constructor in index_to_bytecode_constructor .values ():
2329+ constructor (cg )
2330+ var_name = (
2331+ self .new_var ()
2332+ ) # keep alive any temp objects for the rest of the frame
2333+ cg .store (var_name )
2334+ tmp_vars .append (var_name )
2335+
2336+ for var_name in tmp_vars :
2337+ cg .append_output (cg .create_load (var_name ))
2338+
2339+ cg .call_function (len (index_to_bytecode_constructor ), False )
2340+ cg .pop_top ()
2341+
23462342 for idx , arg in enumerate (self .graphargs ):
23472343 self .export_metadata .graph_input_idx_to_local_source [idx ] = arg .source
23482344
@@ -3011,7 +3007,7 @@ def __init__(
30113007
30123008 self .tracked_tensor_or_symint_vt : OrderedSet [VariableTracker ] = OrderedSet ()
30133009
3014- def record_tensor_or_symint_vt (self , vt ):
3010+ def record_tensor_or_symint_vt (self , vt : VariableTracker ):
30153011 self .tracked_tensor_or_symint_vt .add (vt )
30163012
30173013 # preserve original meta if it is available
0 commit comments