@@ -233,14 +233,11 @@ def program(
233233 )
234234 ]
235235
236- output_node = [
237- node for node in lowered_exported_program .graph .nodes if node .op == "output"
238- ]
239- assert len (output_node ) == 1 , "There should be only one output node"
236+ output_node = lowered_exported_program .graph .output_node ()
240237
241238 # Step 1. Cleaning up the graph before inserting the call_delegate node
242239 # Remove the original output node
243- lowered_exported_program .graph .erase_node (output_node [ 0 ] )
240+ lowered_exported_program .graph .erase_node (output_node )
244241
245242 # Remove all the everything else except the input
246243 for node in reversed (lowered_exported_program .graph .nodes ):
@@ -269,11 +266,9 @@ def program(
269266 )
270267 # Get the output list. Since the output node is a tuple of list, like ([aten_mul_tensor, aten_add_tensor],)
271268 # We add some handling logic to get the list `[aten_mul_tensor, aten_add_tensor]` properly
272- original_output_nodes = [
273- node
274- for node in self ._original_exported_program .graph .nodes
275- if node .op == "output"
276- ][0 ].args [0 ]
269+ original_output_nodes = (
270+ self ._original_exported_program .graph .output_node ().args [0 ]
271+ )
277272
278273 delegate_node .meta ["spec" ] = tuple (
279274 [make_spec (node .meta ["val" ]) for node in original_output_nodes ]
@@ -927,11 +922,7 @@ def _unsafe_adjust_original_program( # noqa: C901
927922 raise RuntimeError (f"Invalid input spec { input_spec } received" )
928923
929924 # Delete buffer mutations from the output which were consumed by the delegate
930- toplevel_output_node = None
931- for node in reversed (original_program .graph .nodes ):
932- if node .op == "output" :
933- toplevel_output_node = node
934- break
925+ toplevel_output_node = original_program .graph .output_node ()
935926
936927 assert toplevel_output_node is not None
937928 assert (
0 commit comments