@@ -295,6 +295,37 @@ def create_constant_nodes_and_return_specs(
295295 return name_to_spec_dict
296296
297297
298+ def _update_output_node_and_specs (exported_program : ExportedProgram ) -> None :
299+ """
300+ Update the output node and output specs in the exported program.
301+ In case a constant node is used as output, we replace it with a clone of the constant node.
302+ """
303+ # Dict [node.name -> InputSpec]
304+ updated_constant_placeholders = get_constant_placeholder_dict (exported_program )
305+ output = exported_program .graph .find_nodes (op = "output" )[0 ]
306+ output_nodes = cast (list [torch .fx .Node ], list (output .args [0 ]))
307+ output_specs = exported_program .graph_signature .output_specs
308+ assert len (output_nodes ) == len (output_specs )
309+
310+ for i in range (len (output_specs )):
311+ out_node = output_nodes [i ]
312+ if out_node not in updated_constant_placeholders :
313+ continue
314+
315+ with exported_program .graph .inserting_after (out_node ):
316+ new_node = exported_program .graph .call_function (
317+ exir_ops .edge .aten .clone .default , (out_node ,)
318+ )
319+ assert "val" in out_node .meta
320+ new_node .meta ["val" ] = out_node .meta ["val" ]
321+ output_nodes [i ] = new_node
322+
323+ # Update the constant-propagated output node.
324+ output_specs [i ].arg = TensorArgument (name = output_nodes [i ].name )
325+
326+ output .args = (output_nodes ,)
327+
328+
298329def constant_prop_pass (
299330 exported_program : ExportedProgram ,
300331 custom_skip_targets : Optional [set [EdgeOpOverload ]] = None ,
@@ -341,12 +372,12 @@ def constant_prop_pass(
341372
342373 # Generate new input spec.
343374 new_input_specs = []
344- for node in exported_program .graph .nodes :
345- if node .op != "placeholder" :
346- continue
375+ for node in exported_program .graph .find_nodes (op = "placeholder" ):
347376 new_input_specs .append (name_to_spec_dict [node .name ])
348377 exported_program .graph_signature .input_specs = new_input_specs
349378
379+ _update_output_node_and_specs (exported_program )
380+
350381 # Cleanup the graph.
351382 exported_program .graph .eliminate_dead_code ()
352383 exported_program .graph_module .recompile ()
0 commit comments