@@ -295,7 +295,11 @@ def create_constant_nodes_and_return_specs(
295295 return name_to_spec_dict
296296
297297
298- def _update_output_node_and_specs (exported_program : ExportedProgram ) -> None :
298+ # add _skip_dim_order to ensure the introduced correct clone node for different dim order schema
299+ # TODO(gasoonjia): only relying on _clone_dim_order once we remove _skip_dim_order option in the EdgeCompileConfig
300+ def _update_output_node_and_specs (
301+ exported_program : ExportedProgram , _skip_dim_order : bool
302+ ) -> None :
299303 """
300304 Update the output node and output specs in the exported program.
301305 In case a constant node is used as output, we replace it with a clone of the constant node.
@@ -307,15 +311,19 @@ def _update_output_node_and_specs(exported_program: ExportedProgram) -> None:
307311 output_specs = exported_program .graph_signature .output_specs
308312 assert len (output_nodes ) == len (output_specs )
309313
314+ clone_op = (
315+ exir_ops .edge .aten .clone .default
316+ if _skip_dim_order
317+ else exir_ops .edge .dim_order_ops ._clone_dim_order .default
318+ )
319+
310320 for i in range (len (output_specs )):
311321 out_node = output_nodes [i ]
312322 if out_node not in updated_constant_placeholders :
313323 continue
314324
315325 with exported_program .graph .inserting_after (out_node ):
316- new_node = exported_program .graph .call_function (
317- exir_ops .edge .aten .clone .default , (out_node ,)
318- )
326+ new_node = exported_program .graph .call_function (clone_op , (out_node ,))
319327 assert "val" in out_node .meta
320328 new_node .meta ["val" ] = out_node .meta ["val" ]
321329 output_nodes [i ] = new_node
@@ -329,6 +337,7 @@ def _update_output_node_and_specs(exported_program: ExportedProgram) -> None:
329337def constant_prop_pass (
330338 exported_program : ExportedProgram ,
331339 custom_skip_targets : Optional [set [EdgeOpOverload ]] = None ,
340+ _skip_dim_order : bool = True ,
332341) -> ExportedProgram :
333342 """
334343 This pass is for constant propagation for Exported Program with lifted parameters,
@@ -376,7 +385,7 @@ def constant_prop_pass(
376385 new_input_specs .append (name_to_spec_dict [node .name ])
377386 exported_program .graph_signature .input_specs = new_input_specs
378387
379- _update_output_node_and_specs (exported_program )
388+ _update_output_node_and_specs (exported_program , _skip_dim_order = _skip_dim_order )
380389
381390 # Cleanup the graph.
382391 exported_program .graph .eliminate_dead_code ()
0 commit comments