
Commit 47b5c37

hsharma35 authored and facebook-github-bot committed
Add constraints for cadence idma ops. (#12597)
Summary: Add memory planning constraints for idma ops:
1. idma load: output needs to be in DTCM
2. idma store: input needs to be in DTCM
3. idma wait: output aliases the input

Reviewed By: zonglinpeng
Differential Revision: D77232760

3 files changed, +41 -1 lines


backends/cadence/aot/memory_constraints.py

Lines changed: 32 additions & 0 deletions

@@ -654,6 +654,37 @@ def compute_slice_and_select_loc_constraints(
         ]
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class GenerateIdmaConstraints(PassBase):
+    """Generate constraints for idma ops."""
+
+    def __init__(self, constraint: MemConstraints) -> None:
+        self.constraint = constraint
+
+    def call(self, graph_module: torch.fx.GraphModule) -> Optional[PassResult]:
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.cadence.idma_wait.out
+        ):
+            # This is just an alias op.
+            self.constraint.add_relative_placement_constraint(node.args[0], node)
+
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.cadence.idma_load.out
+        ):
+            # TODO: set correct dtcm bank here.
+            mem_id = 1
+            self.constraint.add_absolute_placement_constraint(node, mem_id, None)
+
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.cadence.idma_store.out
+        ):
+            # TODO: set correct dtcm bank here.
+            mem_id = 1
+            self.constraint.add_absolute_placement_constraint(
+                node.args[0], mem_id, None
+            )
+
+
 # The class to generate all the constraints that will be passed on to the memory
 # planning algorithm.
 class GenerateMemConstraints:

@@ -671,6 +702,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
         constraint_gen_passes: Sequence[ConstraintsGenPass] = cast(
             list[ConstraintsGenPass],
             [
+                GenerateIdmaConstraints,
                 GenerateMemoryViewConstraints,
                 GenerateSliceAndSelectNopConstraints,
                 GenerateCatNopConstraints,

backends/cadence/aot/memory_planning.py

Lines changed: 3 additions & 1 deletion

@@ -420,7 +420,9 @@ def run(
         # True.
         mem_planning = MemoryPlanningPass(
             self.algo,
-            allow_lifetime_and_storage_overlap=(self.opt_level >= 2),
+            # Always allow lifetime and storage overlap.
+            # At opt level 0, we need overlap for idma wait.
+            allow_lifetime_and_storage_overlap=True,
             alloc_graph_input=self.alloc_graph_input,
             alloc_graph_output=self.alloc_graph_output,
         )
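
Why the relaxation is needed: once idma_wait's output aliases its input, the two tensors overlap in both lifetime and storage, which a strict post-planning check would flag as a conflict. A minimal model of that condition (the Planned type and its fields are illustrative, not the actual memory_planning internals):

from dataclasses import dataclass

# Illustrative model of a planned buffer; not the actual ExecuTorch types.
@dataclass
class Planned:
    first_use: int
    last_use: int
    mem_id: int
    offset: int
    size: int

def storage_conflict(a: Planned, b: Planned) -> bool:
    # Lifetimes overlap if neither buffer dies before the other is born.
    lifetimes_overlap = a.first_use <= b.last_use and b.first_use <= a.last_use
    # Storage overlaps if the buffers share a bank and their byte ranges intersect.
    storage_overlaps = (
        a.mem_id == b.mem_id
        and a.offset < b.offset + b.size
        and b.offset < a.offset + a.size
    )
    return lifetimes_overlap and storage_overlaps

# idma_wait aliases its input: same bank, same offset, overlapping lifetimes.
dma_in = Planned(first_use=0, last_use=5, mem_id=1, offset=0, size=64)
wait_out = Planned(first_use=4, last_use=8, mem_id=1, offset=0, size=64)
assert storage_conflict(dma_in, wait_out)  # intentional, hence the True flag above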

backends/cadence/aot/ops_registrations.py

Lines changed: 6 additions & 0 deletions

@@ -178,7 +178,13 @@
 # Post memory planning, we check that outputs/inputs for the load/store are in
 # DTCM and replace idma_load/idma_store with idma_copy.
 lib.define("idma_load(Tensor src, int task_num=0, int channel=0) -> Tensor")
+lib.define(
+    "idma_load.out(Tensor src, int task_num=0, int channel=0, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define("idma_store(Tensor src, int task_num=0, int channel=0) -> Tensor")
+lib.define(
+    "idma_store.out(Tensor src, int task_num=0, int channel=0, *, Tensor(a!) out) -> Tensor(a!)"
+)
 
 # Non-blocking iDMA copy.
 lib.define("idma_copy(Tensor src, int task_num=0, int channel=0) -> Tensor")
