diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py index 663c5825e52..755692ec2ec 100644 --- a/backends/cadence/aot/remove_ops.py +++ b/backends/cadence/aot/remove_ops.py @@ -9,7 +9,7 @@ import logging from dataclasses import dataclass, field -from typing import cast, List, Optional, Sequence, Set +from typing import cast, List, Optional, Sequence, Set, Type import torch import torch.fx @@ -926,19 +926,25 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: return super().call(graph_module) +class CommonRemovePasses: + passes: List[Type[ExportPass]] = [ + RemoveCloneOpPass, + RemoveAliasCopyOpPass, + RemoveNopExpandOpPass, + RemoveNopSliceOrViewOpPass, + RemoveNopSelectOpPass, + RemoveToOpsPass, + RemoveZeroSizedCatArgsPass, + ] + + class CadenceRemoveNops: - passes = [ + passes: List[Type[ExportPass]] = CommonRemovePasses.passes + [ SimplifySliceOpPass, RemoveCloneOpsTransformImported, - RemoveToOpsPass, RemoveNopRequantizeOpPass, - RemoveZeroSizedCatArgsPass, - RemoveNopSliceOrViewOpPass, - RemoveNopExpandOpPass, RemoveZeroSizedConstantPadNd, - RemoveCloneOpPass, RemoveContiguousOpPass, - RemoveAliasCopyOpPass, RemoveNopMulOpPass, RemoveNopAddOpPass, RemoveNopLinalgVectorNormOpPass,