|
9 | 9 | import collections |
10 | 10 | import itertools |
11 | 11 | import logging |
12 | | -from typing import Iterable, Optional, Sequence |
| 12 | +from typing import Callable, Iterable, Optional, Sequence, TypeAlias |
13 | 13 |
|
14 | 14 | import torch |
15 | 15 | from executorch.backends.cadence.aot.memory_constraints import MemConstraints |
|
26 | 26 |
|
27 | 27 | from executorch.exir import ExecutorchProgramManager |
28 | 28 | from executorch.exir.memory_planning import collect_specs_from_nodes, Verifier |
| 29 | +from executorch.exir.pass_base import PassBase |
| 30 | +from executorch.exir.pass_manager import PassManager |
29 | 31 | from executorch.exir.passes import MemoryPlanningPass |
30 | 32 | from executorch.exir.tensor import TensorSpec |
31 | 33 | from tabulate import tabulate |
@@ -359,6 +361,35 @@ def print_memory_planning_info( |
359 | 361 | ) |
360 | 362 |
|
361 | 363 |
|
class SimplifyIdmaOpsPass(PassBase):
    """Retarget cadence ``idma_load``/``idma_store`` nodes to ``idma_copy``.

    Every ``call_function`` node whose target is ``idma_load.out`` or
    ``idma_store.out`` is switched to ``idma_copy.out``. Load nodes also lose
    their positional argument at index 1; store nodes keep their arguments
    untouched.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> Optional[PassResult]:
        """Rewrite the idma nodes of ``graph_module`` in place.

        Args:
            graph_module: The module whose graph is scanned and mutated.

        Returns:
            A ``PassResult`` wrapping the same ``graph_module``. Its flag is
            True only when at least one load or store node was retargeted.
        """
        fx_graph = graph_module.graph
        rewritten = 0

        # Loads: switch the target, then keep args[0] and everything from
        # index 2 onward (the positional argument at index 1 is discarded).
        load_nodes = fx_graph.find_nodes(
            op="call_function", target=torch.ops.cadence.idma_load.out
        )
        for load_node in load_nodes:
            load_node.target = torch.ops.cadence.idma_copy.out
            old_args = load_node.args
            load_node.args = (old_args[0],) + tuple(old_args[2:])
            rewritten += 1

        # Stores: only the target changes; the argument list is reused as-is.
        store_nodes = fx_graph.find_nodes(
            op="call_function", target=torch.ops.cadence.idma_store.out
        )
        for store_node in store_nodes:
            store_node.target = torch.ops.cadence.idma_copy.out
            rewritten += 1

        # Cleanup and recompilation run unconditionally, even when no node
        # matched either target.
        fx_graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, rewritten > 0)
| 385 | + |
| 386 | + |
# Signature of a constraint-generation pass factory: a callable that takes a
# MemConstraints object and returns a callable which accepts a
# torch.fx.GraphModule and yields an optional PassResult.
ConstraintGenPassType: TypeAlias = Callable[
    [MemConstraints], Callable[[torch.fx.GraphModule], Optional[PassResult]]
]
| 391 | + |
| 392 | + |
362 | 393 | class CadenceMemoryPlanning: |
363 | 394 | def __init__( |
364 | 395 | self, |
@@ -423,10 +454,16 @@ def run( |
423 | 454 | # True. |
424 | 455 | mem_planning = MemoryPlanningPass( |
425 | 456 | self.algo, |
426 | | - allow_lifetime_and_storage_overlap=(self.opt_level >= 2), |
| 457 | + # Always allow lifetime and storage overlap. |
| 458 | + # At opt level 0, we need overlap for idma wait. |
| 459 | + allow_lifetime_and_storage_overlap=True, |
427 | 460 | alloc_graph_input=self.alloc_graph_input, |
428 | 461 | alloc_graph_output=self.alloc_graph_output, |
429 | 462 | ) |
430 | 463 | mem_planning.run(graph_module, graph_signature) |
431 | 464 |
|
| 465 | + graph_module = PassManager(passes=[SimplifyIdmaOpsPass()])( |
| 466 | + graph_module |
| 467 | + ).graph_module |
| 468 | + |
432 | 469 | return PassResult(graph_module, True) |
0 commit comments