@@ -295,7 +295,11 @@ def create_constant_nodes_and_return_specs(
295
295
return name_to_spec_dict
296
296
297
297
298
- def _update_output_node_and_specs (exported_program : ExportedProgram ) -> None :
298
+ # add _skip_dim_order to ensure the introduced correct clone node for different dim order schema
299
+ # TODO(gasoonjia): only relying on _clone_dim_order once we remove _skip_dim_order option in the EdgeCompileConfig
300
+ def _update_output_node_and_specs (
301
+ exported_program : ExportedProgram , _skip_dim_order : bool
302
+ ) -> None :
299
303
"""
300
304
Update the output node and output specs in the exported program.
301
305
In case a constant node is used as output, we replace it with a clone of the constant node.
@@ -307,15 +311,19 @@ def _update_output_node_and_specs(exported_program: ExportedProgram) -> None:
307
311
output_specs = exported_program .graph_signature .output_specs
308
312
assert len (output_nodes ) == len (output_specs )
309
313
314
+ clone_op = (
315
+ exir_ops .edge .aten .clone .default
316
+ if _skip_dim_order
317
+ else exir_ops .edge .dim_order_ops ._clone_dim_order .default
318
+ )
319
+
310
320
for i in range (len (output_specs )):
311
321
out_node = output_nodes [i ]
312
322
if out_node not in updated_constant_placeholders :
313
323
continue
314
324
315
325
with exported_program .graph .inserting_after (out_node ):
316
- new_node = exported_program .graph .call_function (
317
- exir_ops .edge .aten .clone .default , (out_node ,)
318
- )
326
+ new_node = exported_program .graph .call_function (clone_op , (out_node ,))
319
327
assert "val" in out_node .meta
320
328
new_node .meta ["val" ] = out_node .meta ["val" ]
321
329
output_nodes [i ] = new_node
@@ -329,6 +337,7 @@ def _update_output_node_and_specs(exported_program: ExportedProgram) -> None:
329
337
def constant_prop_pass (
330
338
exported_program : ExportedProgram ,
331
339
custom_skip_targets : Optional [set [EdgeOpOverload ]] = None ,
340
+ _skip_dim_order : bool = True ,
332
341
) -> ExportedProgram :
333
342
"""
334
343
This pass is for constant propagation for Exported Program with lifted parameters,
@@ -376,7 +385,7 @@ def constant_prop_pass(
376
385
new_input_specs .append (name_to_spec_dict [node .name ])
377
386
exported_program .graph_signature .input_specs = new_input_specs
378
387
379
- _update_output_node_and_specs (exported_program )
388
+ _update_output_node_and_specs (exported_program , _skip_dim_order = _skip_dim_order )
380
389
381
390
# Cleanup the graph.
382
391
exported_program .graph .eliminate_dead_code ()
0 commit comments