diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py index 98102a415e5..f7419ff25dc 100644 --- a/backends/cadence/aot/remove_ops.py +++ b/backends/cadence/aot/remove_ops.py @@ -38,7 +38,7 @@ class RemoveCloneOpsTransformImported(ExportPass): def call(self, graph_module: torch.fx.GraphModule) -> PassResult: finalize_passes: List[PassType] = [ - RemoveCloneOpsTransform(), + RemoveCloneOpsTransform(eliminate_quant_dequant_pairs=False), ] result = PassManager(passes=finalize_passes)(graph_module) dead_code_elimination_pass(result.graph_module) @@ -356,20 +356,6 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: return result -@register_cadence_pass(CadencePassAttribute(opt_level=1)) -class RemoveCloneOpPass(RemoveOrReplacePassInterface): - # If the op is a clone op, return the input and eliminate the op - @property - def targets(self) -> list[EdgeOpOverload]: - return [exir_ops.edge.aten.clone.default] - - def maybe_remove_or_replace(self, node: Node) -> bool: - input_node = node.args[0] - assert isinstance(input_node, Node) - node.replace_all_uses_with(input_node) - return True - - @register_cadence_pass(CadencePassAttribute(opt_level=1)) class RemoveContiguousOpPass(RemoveOrReplacePassInterface): """ @@ -925,7 +911,6 @@ def maybe_remove_or_replace(self, node: Node) -> bool: class CommonRemovePasses: passes: List[Type[ExportPass]] = [ - RemoveCloneOpPass, RemoveAliasCopyOpPass, RemoveNopExpandOpPass, RemoveNopSliceOrViewOpPass, @@ -934,13 +919,13 @@ class CommonRemovePasses: RemovePermutesAroundElementwiseOps, RemoveSqueezeViewBeforeElementwiseOps, RemoveCatFromSliceCopyPass, + RemoveCloneOpsTransformImported, ] class CadenceRemoveNops: passes: List[Type[ExportPass]] = CommonRemovePasses.passes + [ SimplifySliceOpPass, - RemoveCloneOpsTransformImported, RemoveNopRequantizeOpPass, RemoveZeroSizedConstantPadNd, RemoveContiguousOpPass,