|
6 | 6 |
|
7 | 7 | # pyre-strict |
8 | 8 |
|
9 | | - |
10 | 9 | import logging |
| 10 | +from abc import abstractmethod |
11 | 11 | from dataclasses import dataclass, field |
12 | 12 | from typing import cast, List, Optional, Sequence, Set, Type |
13 | 13 |
|
|
35 | 35 | from torch.fx.node import Argument, Node |
36 | 36 |
|
37 | 37 |
|
| 38 | +class RemovePassCommon(ExportPass):
| 39 | +    @property
| 40 | +    @abstractmethod
| 41 | +    def targets(self) -> list[EdgeOpOverload]:
| 42 | +        """
| 43 | +        The list of targets to potentially remove.
| 44 | +        """
| 45 | +        raise NotImplementedError("`targets` must be defined")
| 46 | +
| 47 | +    @abstractmethod
| 48 | +    def maybe_remove(self, node: Node) -> bool:
| 49 | +        """
| 50 | +        If the node should be removed, removes it from the graph. Returns
| 51 | +        True if the graph was modified, else False.
| 52 | +        """
| 53 | +        raise NotImplementedError("`maybe_remove` must be defined")
| 54 | +
| 55 | +    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
| 56 | +        """
| 57 | +        For each node whose target is in `targets`, removes the node from the
| 58 | +        graph if `maybe_remove` determines it should go. If any node was
| 59 | +        removed, returns the modified graph module with `modified` set to True;
| 60 | +        otherwise returns the original graph module with `modified` set to False.
| 61 | +        """
| 62 | +        changed = False
| 63 | +        for target in self.targets:
| 64 | +            for module in graph_module.modules():
| 65 | +                assert isinstance(module, torch.fx.GraphModule)
| 66 | +                for node in module.graph.find_nodes(op="call_function", target=target):
| 67 | +                    changed |= self.maybe_remove(node)
| 68 | +
| 69 | +        if changed:
| 70 | +            graph_module.graph.eliminate_dead_code()
| 71 | +            graph_module.recompile()
| 72 | +            return super().call(graph_module)
| 73 | +
| 74 | +        return PassResult(graph_module, False)
| 75 | + |
| 76 | + |
38 | 77 | @register_cadence_pass(CadencePassAttribute(opt_level=0)) |
39 | 78 | class RemoveCloneOpsTransformImported(ExportPass): |
40 | 79 |     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
@@ -186,33 +225,27 @@ def call_operator( |
186 | 225 |
|
187 | 226 |
|
188 | 227 | @register_cadence_pass(CadencePassAttribute(opt_level=1)) |
189 | | -class RemoveNopSliceOrViewOpPass(ExportPass): |
| 228 | +class RemoveNopSliceOrViewOpPass(RemovePassCommon): |
190 | 229 |     """
191 | 230 |     Remove slice ops that are more like views, and view ops that do not change the shape
192 | 231 |     """
193 | 232 |
|
194 | | -    def call_operator(
195 | | -        self,
196 | | -        op,  # pyre-ignore
197 | | -        args: tuple[Argument, ...],
198 | | -        kwargs: dict[str, Argument],
199 | | -        meta: NodeMetadata,
200 | | -    ) -> ProxyValue:
201 | | -        if op not in {
| 233 | +    @property
| 234 | +    def targets(self) -> list[EdgeOpOverload]:
| 235 | +        return [
202 | 236 |             exir_ops.edge.aten.slice_copy.Tensor,
203 | 237 |             exir_ops.edge.aten.view_copy.default,
204 | | -        }:
205 | | -            return super().call_operator(op, args, kwargs, meta)
| 238 | +        ]
206 | 239 |
|
207 | | -        arg0 = cast(ProxyValue, args[0])
208 | | -        out_shape = meta["val"].shape
| 240 | +    def maybe_remove(self, node: Node) -> bool:
| 241 | +        changed = False
| 242 | +        input_node = node.args[0]
| 243 | +        assert isinstance(input_node, Node)
| 244 | +        if input_node.meta["val"].shape == node.meta["val"].shape:
| 245 | +            node.replace_all_uses_with(input_node)
| 246 | +            changed = True
209 | 247 |
|
210 | | -        # If both arg_shape and out_shape are the same, this slice is a nop
211 | | -        return (
212 | | -            arg0
213 | | -            if arg0.to_tensor().shape == out_shape
214 | | -            else super().call_operator(op, args, kwargs, meta)
215 | | -        )
| 248 | +        return changed
216 | 249 |
|
217 | 250 |
|
218 | 251 | @register_cadence_pass(CadencePassAttribute(opt_level=1)) |
|
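For reviewers skimming the new base class: a subclass only supplies `targets` and `maybe_remove`; `RemovePassCommon.call` walks every matching call_function node, then runs dead-code elimination and recompiles if anything changed. Below is a minimal sketch of another pass built on it. It is illustrative only and not part of this diff: the class name and opt_level are invented, it reuses the imports already present in the file, and it assumes `exir_ops.edge.aten.permute_copy.default` is the edge overload for permute_copy.

# Hypothetical example, not in this PR: drop permute_copy ops whose
# permutation is the identity, reusing the RemovePassCommon scaffolding above.
@register_cadence_pass(CadencePassAttribute(opt_level=1))  # opt_level chosen arbitrarily
class RemoveNopPermuteOpPass(RemovePassCommon):
    @property
    def targets(self) -> list[EdgeOpOverload]:
        # Assumed edge overload; swap in whichever nop-prone op applies.
        return [exir_ops.edge.aten.permute_copy.default]

    def maybe_remove(self, node: Node) -> bool:
        input_node = node.args[0]
        assert isinstance(input_node, Node)
        dims = node.args[1]
        assert isinstance(dims, (list, tuple))
        # An identity permutation leaves the tensor unchanged, so the op is a nop.
        if list(dims) == list(range(len(dims))):
            node.replace_all_uses_with(input_node)
            return True
        return False

Like RemoveNopSliceOrViewOpPass above, its call (inherited from RemovePassCommon) returns a PassResult whose modified flag reflects whether anything was removed.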