@@ -233,14 +233,11 @@ def program(
             )
         ]

-        output_node = [
-            node for node in lowered_exported_program.graph.nodes if node.op == "output"
-        ]
-        assert len(output_node) == 1, "There should be only one output node"
+        output_node = lowered_exported_program.graph.output_node()

         # Step 1. Cleaning up the graph before inserting the call_delegate node
         # Remove the original output node
-        lowered_exported_program.graph.erase_node(output_node[0])
+        lowered_exported_program.graph.erase_node(output_node)

         # Remove all the everything else except the input
         for node in reversed(lowered_exported_program.graph.nodes):
@@ -269,11 +266,9 @@ def program(
         )
         # Get the output list. Since the output node is a tuple of list, like ([aten_mul_tensor, aten_add_tensor],)
         # We add some handling logic to get the list `[aten_mul_tensor, aten_add_tensor]` properly
-        original_output_nodes = [
-            node
-            for node in self._original_exported_program.graph.nodes
-            if node.op == "output"
-        ][0].args[0]
+        original_output_nodes = (
+            self._original_exported_program.graph.output_node().args[0]
+        )

         delegate_node.meta["spec"] = tuple(
             [make_spec(node.meta["val"]) for node in original_output_nodes]
@@ -927,11 +922,7 @@ def _unsafe_adjust_original_program(  # noqa: C901
             raise RuntimeError(f"Invalid input spec {input_spec} received")

     # Delete buffer mutations from the output which were consumed by the delegate
-    toplevel_output_node = None
-    for node in reversed(original_program.graph.nodes):
-        if node.op == "output":
-            toplevel_output_node = node
-            break
+    toplevel_output_node = original_program.graph.output_node()

     assert toplevel_output_node is not None
     assert (
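All three hunks replace a manual scan for the graph's "output" node with torch.fx's Graph.output_node(). A minimal sketch of the equivalence, assuming a PyTorch version whose torch.fx.Graph exposes output_node(); the toy function below is illustrative only and not part of this PR:

# Minimal sketch (assumed API: torch.fx.Graph.output_node(), available in
# recent PyTorch releases). The toy function is illustrative only.
import torch
from torch.fx import symbolic_trace


def toy(x, y):
    return x * y, x + y


gm = symbolic_trace(toy)

# Old pattern removed by this PR: scan all nodes for op == "output".
manual = [n for n in gm.graph.nodes if n.op == "output"]
assert len(manual) == 1, "There should be only one output node"

# New pattern: output_node() returns that single output node directly.
out = gm.graph.output_node()
assert out is manual[0]

# As in the second hunk, args[0] holds the flat collection of graph outputs,
# e.g. (mul, add) for the toy function above.
print(out.args[0])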