Skip to content

Commit 693cc29

Browse files
Fixed missing writeback copy opperation in insert_write_back_for_buffers_pass for the case of copying data directly from one input to another.
Also converted the list comprehension to a for loop for readablity.
1 parent fedb035 commit 693cc29

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

exir/passes/insert_write_back_for_buffers_pass.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -100,19 +100,19 @@ def insert_write_back_for_buffers_pass(
100100
input_name_to_node[lifted_node] = input_node
101101

102102
# Grab the mutable buffer nodes in the outputs,
103-
mutated_outputs: List[Optional[str]] = [
104-
(
105-
out_spec.target
106-
if out_spec.kind
107-
in (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION)
108-
and out_spec.arg.name
109-
not in {
110-
val.name for val in input_name_to_node.values()
111-
} # if the output arg is the input value then all operations on it are in-place so theres no need to add a copy_ node
112-
else None
113-
)
114-
for out_spec in ep.graph_signature.output_specs
115-
]
103+
mutated_outputs: List[Optional[str]] = []
104+
for out_spec in ep.graph_signature.output_specs:
105+
# if the output arg is the input value then all operations on it are in-place
106+
# so there's no need to add a copy_ node
107+
if (out_spec.kind in (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION) and
108+
# explicitly check if target exists (it should always be there)
109+
out_spec.target in input_name_to_node and
110+
# if the arg and target are not the same, we add a copy_ node.
111+
out_spec.arg.name != input_name_to_node[out_spec.target].name
112+
):
113+
mutated_outputs.append(out_spec.target)
114+
else:
115+
mutated_outputs.append(None)
116116

117117
# insert the copy ops and update the outputs
118118
buffer_output_nodes = _insert_copy(gm, mutated_outputs, input_name_to_node)

0 commit comments

Comments
 (0)