
Commit ae9d1eb

use graph.output_node
Differential Revision: D77247219
Pull Request resolved: #12139
1 parent 22b9e59 commit ae9d1eb
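
The change is mechanical: every file below previously located the graph's output node by scanning graph.nodes for node.op == "output", and the commit replaces each scan with a direct call to torch.fx.Graph.output_node(). A minimal sketch of the before/after pattern, using a hypothetical traced module (not part of this commit) and assuming a PyTorch version that provides Graph.output_node():

import torch
from torch.fx import symbolic_trace


class AddMul(torch.nn.Module):  # hypothetical example module, not from this PR
    def forward(self, x, y):
        return x + y, x * y


graph = symbolic_trace(AddMul()).graph

# Old pattern (removed throughout this commit): scan every node for the output op.
output_node_old = None
for node in graph.nodes:
    if node.op == "output":
        output_node_old = node
        break
assert output_node_old is not None

# New pattern: Graph.output_node() returns the graph's output node directly.
output_node = graph.output_node()
assert output_node is output_node_old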

9 files changed: 19 additions & 69 deletions

backends/test/harness/tester.py

Lines changed: 1 addition & 6 deletions
@@ -416,12 +416,7 @@ def _calculate_reference_output(
         """
 
         # Locate the output node.
-        output_node = None
-        for node in program.graph.nodes:
-            if node.op == "output":
-                output_node = node
-                break
-        assert output_node is not None
+        output_node = program.graph.output_node()
 
         # Look for a dequantization node in the output node args. Returned values are found in the first
         # argument of the output node.

backends/transforms/test/test_create_delete_constant_placeholder.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ def _test_create_delete(kind: InputKind, persistent_buffer: bool = None):
         kwargs={},
     )
 
-    output_node = list(graph.nodes)[-1]
+    output_node = graph.output_node()
     output_node.replace_input_with(input_node, add_node)
 
     # We should now have four nodes: test_node, input, add, output

exir/backend/backend_api.py

Lines changed: 4 additions & 12 deletions
@@ -288,12 +288,8 @@ def _partition_and_lower_one_graph_module(
             tagged_graph_module, node_list, tag
         )
 
-        tagged_graph_module_output_node = [
-            node for node in tagged_graph_module.graph.nodes if node.op == "output"
-        ][0]
-        submodule_output_node = [
-            node for node in submodule.graph.nodes if node.op == "output"
-        ][0]
+        tagged_graph_module_output_node = tagged_graph_module.graph.output_node()
+        submodule_output_node = submodule.graph.output_node()
         # Copy the output node meta from the original output node, because
         # create_submodule_from_nodes doesn't cover the meta field
         submodule_output_node.meta = tagged_graph_module_output_node.meta
@@ -476,12 +472,8 @@ def _create_partitions_in_graph_module(
             tagged_graph_module, node_list, tag
        )
 
-        tagged_graph_module_output_node = [
-            node for node in tagged_graph_module.graph.nodes if node.op == "output"
-        ][0]
-        submodule_output_node = [
-            node for node in submodule.graph.nodes if node.op == "output"
-        ][0]
+        tagged_graph_module_output_node = tagged_graph_module.graph.output_node()
+        submodule_output_node = submodule.graph.output_node()
         # Copy the output node meta from the original output node, because
         # create_submodule_from_nodes doesn't cover the meta field
         submodule_output_node.meta = tagged_graph_module_output_node.meta

exir/emit/_emit_program.py

Lines changed: 1 addition & 5 deletions
@@ -57,11 +57,7 @@ class EmitterOutput:
 
 def _remove_non_user_outputs(exported_program: ExportedProgram) -> torch.fx.GraphModule:
     gm = exported_program.graph_module
-    output_node = None
-    for node in gm.graph.nodes:
-        if node.op == "output":
-            output_node = node
-    assert output_node is not None
+    output_node = gm.graph.output_node()
 
     mutated_outputs: List[Optional[str]] = [
         out_spec.target if out_spec.kind in (OutputKind.BUFFER_MUTATION,) else None

exir/lowered_backend_module.py

Lines changed: 6 additions & 15 deletions
@@ -233,14 +233,11 @@ def program(
             )
         ]
 
-        output_node = [
-            node for node in lowered_exported_program.graph.nodes if node.op == "output"
-        ]
-        assert len(output_node) == 1, "There should be only one output node"
+        output_node = lowered_exported_program.graph.output_node()
 
         # Step 1. Cleaning up the graph before inserting the call_delegate node
         # Remove the original output node
-        lowered_exported_program.graph.erase_node(output_node[0])
+        lowered_exported_program.graph.erase_node(output_node)
 
         # Remove all the everything else except the input
         for node in reversed(lowered_exported_program.graph.nodes):
@@ -269,11 +266,9 @@ def program(
         )
         # Get the output list. Since the output node is a tuple of list, like ([aten_mul_tensor, aten_add_tensor],)
         # We add some handling logic to get the list `[aten_mul_tensor, aten_add_tensor]` properly
-        original_output_nodes = [
-            node
-            for node in self._original_exported_program.graph.nodes
-            if node.op == "output"
-        ][0].args[0]
+        original_output_nodes = (
+            self._original_exported_program.graph.output_node().args[0]
+        )
 
         delegate_node.meta["spec"] = tuple(
             [make_spec(node.meta["val"]) for node in original_output_nodes]
@@ -927,11 +922,7 @@ def _unsafe_adjust_original_program( # noqa: C901
         raise RuntimeError(f"Invalid input spec {input_spec} received")
 
     # Delete buffer mutations from the output which were consumed by the delegate
-    toplevel_output_node = None
-    for node in reversed(original_program.graph.nodes):
-        if node.op == "output":
-            toplevel_output_node = node
-            break
+    toplevel_output_node = original_program.graph.output_node()
 
     assert toplevel_output_node is not None
     assert (
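
The "tuple of list" comment in the hunk above relies on the fact that an fx graph's output node wraps everything the module returns in a single argument, so output_node().args[0] yields the sequence of returned value nodes. A small sketch under the same assumptions as the earlier example (hypothetical AddMul module, PyTorch with Graph.output_node()):

import torch
from torch.fx import symbolic_trace


class AddMul(torch.nn.Module):  # hypothetical example module, not from this PR
    def forward(self, x, y):
        return x * y, x + y


output_node = symbolic_trace(AddMul()).graph.output_node()

# args is a one-element tuple wrapping the returned values, e.g. ((mul, add),),
# so args[0] recovers the sequence of value nodes the graph returns.
returned_values = output_node.args[0]
print(returned_values)  # (mul, add) as torch.fx Node objects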

exir/passes/insert_write_back_for_buffers_pass.py

Lines changed: 2 additions & 10 deletions
@@ -30,11 +30,7 @@ def _insert_copy(
     Find the all the buffers and inputs that were mutated and insert copy_
     operators to reflect mutations.
     """
-    output_node = None
-    for node in gm.graph.nodes:
-        if node.op == "output":
-            output_node = node
-            break
+    output_node = gm.graph.output_node()
     assert output_node is not None
     outputs = pytree.tree_flatten(output_node.args)[0]
     assert len(outputs) == len(mutated_outputs)
@@ -139,11 +135,7 @@ def insert_write_back_for_buffers_pass(
         if lifted_node is not None:
             input_name_to_node[lifted_node] = input_node
 
-    output_node = None
-    for node in gm.graph.nodes:
-        if node.op == "output":
-            output_node = node
-            break
+    output_node = gm.graph.output_node()
 
     # Grab the mutable buffer nodes in the outputs,
     mutated_outputs: List[Optional[str]] = []

exir/passes/quantize_io_pass.py

Lines changed: 1 addition & 4 deletions
@@ -145,11 +145,8 @@ def quantize_output(exported_program, output_index):
     output quantization.
     """
     graph = exported_program.graph_module.graph
-    outputs = [n for n in graph.nodes if n.op == "output"]
-    if len(outputs) != 1:
-        raise NotImplementedError("Only 1 output node is supported")
 
-    output_node = outputs[0]
+    output_node = graph.output_node()
     output_list = list(output_node.args[0])
     if output_index >= len(output_list):
         raise ValueError(
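
The guard deleted above checked that the graph had exactly one output node. A torch.fx graph carries a single output node regardless of how many values the module returns, which is why the check becomes redundant once output_node() is used. A quick illustration, again with a hypothetical module:

import torch
from torch.fx import symbolic_trace


class TwoOutputs(torch.nn.Module):  # hypothetical example module, not from this PR
    def forward(self, x):
        return x + 1, x * 2


graph = symbolic_trace(TwoOutputs()).graph

# Even when the module returns several values, the graph has one output node,
# so the removed "Only 1 output node is supported" guard never fired.
output_nodes = [n for n in graph.nodes if n.op == "output"]
assert len(output_nodes) == 1
assert graph.output_node() is output_nodes[0]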

exir/passes/weights_to_outputs_pass.py

Lines changed: 1 addition & 6 deletions
@@ -46,12 +46,7 @@ def weights_to_outputs_pass(
     inputs_to_params = gs.inputs_to_parameters
 
     # Get output node
-    output_node = None
-    for node in gm.graph.nodes:
-        if node.op == "output":
-            output_node = node
-            break
-    assert output_node is not None
+    output_node = gm.graph.output_node()
 
     # Get input nodes that are weights with an associated gradient
     placeholder_nodes = [

exir/tests/test_joint_graph.py

Lines changed: 2 additions & 10 deletions
@@ -42,11 +42,7 @@ def forward(self, x, y):
         joint_ep = _export_forward_backward(ep)
         edge = to_edge(joint_ep)
 
-        output_node = None
-        for node in edge.exported_program().graph.nodes:
-            if node.op == "output":
-                output_node = node
-                break
+        output_node = edge.exported_program().graph.output_node()
 
         orig_outputs = len(output_node.args[0])
 
@@ -58,11 +54,7 @@ def forward(self, x, y):
             if spec.kind == OutputKind.TOKEN
         ]
 
-        output_node = None
-        for node in et.exported_program().graph.nodes:
-            if node.op == "output":
-                output_node = node
-                break
+        output_node = et.exported_program().graph.output_node()
 
         weight_outputs = len(output_node.args[0])
 
