Skip to content

Commit a99b64f

Browse files
authored
Post memory planning passes.
Differential Revision: D77232764 Pull Request resolved: #13918
1 parent ae862f7 commit a99b64f

File tree

1 file changed

+36
-1
lines changed

1 file changed

+36
-1
lines changed

backends/cadence/aot/memory_planning.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import collections
1010
import itertools
1111
import logging
12-
from typing import Iterable, Optional, Sequence
12+
from typing import Callable, Iterable, Optional, Sequence, TypeAlias
1313

1414
import torch
1515
from executorch.backends.cadence.aot.memory_constraints import MemConstraints
@@ -26,6 +26,8 @@
2626

2727
from executorch.exir import ExecutorchProgramManager
2828
from executorch.exir.memory_planning import collect_specs_from_nodes, Verifier
29+
from executorch.exir.pass_base import PassBase
30+
from executorch.exir.pass_manager import PassManager
2931
from executorch.exir.passes import MemoryPlanningPass
3032
from executorch.exir.tensor import TensorSpec
3133
from tabulate import tabulate
@@ -359,6 +361,35 @@ def print_memory_planning_info(
359361
)
360362

361363

364+
class SimplifyIdmaOpsPass(PassBase):
    """Retarget idma_load and idma_store nodes to idma_copy.

    Every ``call_function`` node whose target is
    ``torch.ops.cadence.idma_load.out`` or ``torch.ops.cadence.idma_store.out``
    has its target switched to ``torch.ops.cadence.idma_copy.out``. Load nodes
    additionally lose their second positional argument; store nodes keep their
    arguments as they are.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> Optional[PassResult]:
        """Rewrite the iDMA nodes of ``graph_module`` in place.

        Dead-code elimination and recompilation always run, whether or not
        any node was rewritten.

        Returns:
            A ``PassResult`` wrapping the same ``graph_module``; its flag is
            True only when at least one load or store node was retargeted.
        """
        graph = graph_module.graph
        rewritten = 0

        # Loads: switch the target, then rebuild args without the element at
        # index 1 (first argument kept, everything from index 2 onward kept).
        for load_node in graph.find_nodes(
            op="call_function", target=torch.ops.cadence.idma_load.out
        ):
            previous_args = load_node.args
            load_node.target = torch.ops.cadence.idma_copy.out
            load_node.args = (previous_args[0], *previous_args[2:])
            rewritten += 1

        # Stores: only the target changes; the argument tuple is left alone.
        for store_node in graph.find_nodes(
            op="call_function", target=torch.ops.cadence.idma_store.out
        ):
            store_node.target = torch.ops.cadence.idma_copy.out
            rewritten += 1

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, rewritten > 0)
385+
386+
387+
# Signature of a constraint-generating pass factory: a callable that receives
# a MemConstraints object and returns a callable which takes a
# torch.fx.GraphModule and returns Optional[PassResult].
ConstraintGenPassType: TypeAlias = Callable[
    [MemConstraints], Callable[[torch.fx.GraphModule], Optional[PassResult]]
]
391+
392+
362393
class CadenceMemoryPlanning:
363394
def __init__(
364395
self,
@@ -431,4 +462,8 @@ def run(
431462
)
432463
mem_planning.run(graph_module, graph_signature)
433464

465+
graph_module = PassManager(passes=[SimplifyIdmaOpsPass()])(
466+
graph_module
467+
).graph_module
468+
434469
return PassResult(graph_module, True)

0 commit comments

Comments
 (0)