|
19 | 19 |
|
20 | 20 | import logging |
21 | 21 | from dataclasses import dataclass, field |
22 | | -from typing import cast, List, Optional, Sequence, Set |
| 22 | +from typing import cast, List, Optional, Sequence, Set, Type |
23 | 23 |
|
24 | 24 | import torch |
25 | 25 | import torch.fx |
@@ -940,21 +940,28 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: |
940 | 940 | # in Jarvis. Currently, each function in this class iterates over each node of |
941 | 941 | # the graph module once. In future, we could consolidate them into a monolithic |
942 | 942 | # function. |
943 | | -class CadenceRemoveNops: |
944 | | - passes = [ |
| 943 | +class GenericRemoveNops: |
| 944 | + passes: List[Type[ExportPass]] = [ |
945 | 945 | SimplifySliceOpPass, |
946 | 946 | RemoveCloneOpsTransformImported, |
947 | 947 | RemoveToOpsPass, |
948 | 948 | RemoveNopRequantizeOpPass, |
949 | 949 | RemoveZeroSizedCatArgsPass, |
950 | 950 | RemoveNopSliceOrViewOpPass, |
951 | 951 | RemoveNopExpandOpPass, |
952 | | - RemoveZeroSizedConstantPadNd, |
953 | 952 | RemoveCloneOpPass, |
954 | 953 | RemoveContiguousOpPass, |
955 | 954 | RemoveAliasCopyOpPass, |
956 | 955 | RemoveNopMulOpPass, |
957 | 956 | RemoveNopAddOpPass, |
958 | 957 | RemoveNopLinalgVectorNormOpPass, |
959 | 958 | RemoveBranchedQuantDequant, |
| 959 | + RemoveNopSelectOpPass, |
| 960 | + RemovePermutesAroundElementwiseOps, |
| 961 | + RemoveSqueezeViewBeforeElementwiseOps, |
| 962 | + RemoveCatFromSliceCopyPass, |
960 | 963 | ] |
| 964 | + |
| 965 | +class CadenceRemoveNops: |
| 966 | + passes: List[Type[ExportPass]] = GenericRemoveNops.passes |
| 967 | + passes += [RemoveZeroSizedConstantPadNd] |
0 commit comments