|
9 | 9 | import collections
|
10 | 10 | import itertools
|
11 | 11 | import logging
|
12 |
| -from typing import Iterable, Optional, Sequence |
| 12 | +from typing import Callable, Iterable, Optional, Sequence, TypeAlias |
13 | 13 |
|
14 | 14 | import torch
|
15 | 15 | from executorch.backends.cadence.aot.memory_constraints import MemConstraints
|
|
26 | 26 |
|
27 | 27 | from executorch.exir import ExecutorchProgramManager
|
28 | 28 | from executorch.exir.memory_planning import collect_specs_from_nodes, Verifier
|
| 29 | +from executorch.exir.pass_base import PassBase |
| 30 | +from executorch.exir.pass_manager import PassManager |
29 | 31 | from executorch.exir.passes import MemoryPlanningPass
|
30 | 32 | from executorch.exir.tensor import TensorSpec
|
31 | 33 | from tabulate import tabulate
|
@@ -359,6 +361,35 @@ def print_memory_planning_info(
|
359 | 361 | )
|
360 | 362 |
|
361 | 363 |
|
| 364 | +class SimplifyIdmaOpsPass(PassBase): |
| 365 | + """Replace idma_load and idma_store with idma_copy.""" |
| 366 | + |
| 367 | + def call(self, graph_module: torch.fx.GraphModule) -> Optional[PassResult]: |
| 368 | + modified = False |
| 369 | + for node in graph_module.graph.find_nodes( |
| 370 | + op="call_function", target=torch.ops.cadence.idma_load.out |
| 371 | + ): |
| 372 | + modified = True |
| 373 | + node.target = torch.ops.cadence.idma_copy.out |
| 374 | + node.args = (node.args[0], *node.args[2:]) |
| 375 | + |
| 376 | + for node in graph_module.graph.find_nodes( |
| 377 | + op="call_function", target=torch.ops.cadence.idma_store.out |
| 378 | + ): |
| 379 | + modified = True |
| 380 | + node.target = torch.ops.cadence.idma_copy.out |
| 381 | + |
| 382 | + graph_module.graph.eliminate_dead_code() |
| 383 | + graph_module.recompile() |
| 384 | + return PassResult(graph_module, modified) |
| 385 | + |
| 386 | + |
| 387 | +ConstraintGenPassType: TypeAlias = Callable[ |
| 388 | + [MemConstraints], |
| 389 | + Callable[[torch.fx.GraphModule], Optional[PassResult]], |
| 390 | +] |
| 391 | + |
| 392 | + |
362 | 393 | class CadenceMemoryPlanning:
|
363 | 394 | def __init__(
|
364 | 395 | self,
|
@@ -431,4 +462,8 @@ def run(
|
431 | 462 | )
|
432 | 463 | mem_planning.run(graph_module, graph_signature)
|
433 | 464 |
|
| 465 | + graph_module = PassManager(passes=[SimplifyIdmaOpsPass()])( |
| 466 | + graph_module |
| 467 | + ).graph_module |
| 468 | + |
434 | 469 | return PassResult(graph_module, True)
|
0 commit comments