@@ -100,19 +100,19 @@ def insert_write_back_for_buffers_pass(
100100 input_name_to_node [lifted_node ] = input_node
101101
102102 # Grab the mutable buffer nodes in the outputs,
103- mutated_outputs : List [Optional [str ]] = [
104- (
105- out_spec . target
106- if out_spec . kind
107- in (OutputKind .BUFFER_MUTATION , OutputKind .USER_INPUT_MUTATION )
108- and out_spec . arg . name
109- not in {
110- val . name for val in input_name_to_node . values ()
111- } # if the output arg is the input value then all operations on it are in-place so theres no need to add a copy_ node
112- else None
113- )
114- for out_spec in ep . graph_signature . output_specs
115- ]
103+ mutated_outputs : List [Optional [str ]] = []
104+ for out_spec in ep . graph_signature . output_specs :
105+ # if the output arg is the input value then all operations on it are in-place
106+ # so there's no need to add a copy_ node
107+ if ( out_spec . kind in (OutputKind .BUFFER_MUTATION , OutputKind .USER_INPUT_MUTATION ) and
108+ # explicitly check if target exists (it should always be there)
109+ out_spec . target in input_name_to_node and
110+ # if the arg and target are not the same, we add a copy_ node.
111+ out_spec . arg . name != input_name_to_node [ out_spec . target ]. name
112+ ):
113+ mutated_outputs . append ( out_spec . target )
114+ else :
115+ mutated_outputs . append ( None )
116116
117117 # insert the copy ops and update the outputs
118118 buffer_output_nodes = _insert_copy (gm , mutated_outputs , input_name_to_node )
0 commit comments