1515    OutputKind ,
1616    OutputSpec ,
1717)
18+ from  torch .export .graph_signature  import  TensorArgument 
1819from  torch .utils  import  _pytree  as  pytree 
1920
2021
@@ -73,20 +74,21 @@ def insert_write_back_for_buffers_pass(
7374    ep : ExportedProgram ,
7475) ->  Tuple [torch .fx .GraphModule , ExportGraphSignature ]:
7576    gm : torch .fx .GraphModule  =  ep .graph_module 
76-     lifted_inputs : List [Optional [str ]] =  [
77-         (
78-             in_spec .target 
79-             if  in_spec .kind 
80-             in  (
81-                 InputKind .BUFFER ,
82-                 InputKind .CONSTANT_TENSOR ,
83-                 InputKind .PARAMETER ,
84-                 InputKind .CUSTOM_OBJ ,
85-             )
86-             else  None 
87-         )
88-         for  in_spec  in  ep .graph_signature .input_specs 
89-     ]
77+     lifted_inputs : List [Optional [str ]] =  []
78+     for  in_spec  in  ep .graph_signature .input_specs :
79+         if  in_spec .kind  in  (
80+             InputKind .BUFFER ,
81+             InputKind .CONSTANT_TENSOR ,
82+             InputKind .PARAMETER ,
83+             InputKind .CUSTOM_OBJ ,
84+         ):
85+             lifted_inputs .append (in_spec .target )
86+         elif  in_spec .kind  is  InputKind .USER_INPUT  and  isinstance (
87+             in_spec .arg , TensorArgument 
88+         ):
89+             lifted_inputs .append (in_spec .arg .name )
90+         else :
91+             lifted_inputs .append (None )
9092
9193    input_name_to_node : Dict [str , torch .fx .Node ] =  {}
9294
@@ -101,7 +103,8 @@ def insert_write_back_for_buffers_pass(
101103    mutated_outputs : List [Optional [str ]] =  [
102104        (
103105            out_spec .target 
104-             if  out_spec .kind  in  (OutputKind .BUFFER_MUTATION ,)
106+             if  out_spec .kind 
107+             in  (OutputKind .BUFFER_MUTATION , OutputKind .USER_INPUT_MUTATION )
105108            and  out_spec .arg .name 
106109            not  in   {
107110                val .name  for  val  in  input_name_to_node .values ()
@@ -121,7 +124,10 @@ def insert_write_back_for_buffers_pass(
121124    new_output_specs : List [OutputSpec ] =  []
122125    i  =  0 
123126    for  output_spec  in  ep .graph_signature .output_specs :
124-         if  output_spec .kind  ==  OutputKind .BUFFER_MUTATION :
127+         if  output_spec .kind  in  (
128+             OutputKind .BUFFER_MUTATION ,
129+             OutputKind .USER_INPUT_MUTATION ,
130+         ):
125131            output_spec .arg .name  =  buffer_output_nodes [i ].name 
126132            i  +=  1 
127133        new_output_specs .append (output_spec )
0 commit comments