@@ -654,6 +654,37 @@ def compute_slice_and_select_loc_constraints(
654654]
655655
656656
657+ @register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
658+ class GenerateIdmaConstraints (PassBase ):
659+ """Generate constraints for idma ops."""
660+
661+ def __init__ (self , constraint : MemConstraints ) -> None :
662+ self .constraint = constraint
663+
664+ def call (self , graph_module : torch .fx .GraphModule ) -> Optional [PassResult ]:
665+ for node in graph_module .graph .find_nodes (
666+ op = "call_function" , target = torch .ops .cadence .idma_wait .out
667+ ):
668+ # This is just an alias op.
669+ self .constraint .add_relative_placement_constraint (node .args [0 ], node )
670+
671+ for node in graph_module .graph .find_nodes (
672+ op = "call_function" , target = torch .ops .cadence .idma_load .out
673+ ):
674+ # TODO: set correct dtcm bank here.
675+ mem_id = 1
676+ self .constraint .add_absolute_placement_constraint (node , mem_id , None )
677+
678+ for node in graph_module .graph .find_nodes (
679+ op = "call_function" , target = torch .ops .cadence .idma_store .out
680+ ):
681+ # TODO: set correct dtcm bank here.
682+ mem_id = 1
683+ self .constraint .add_absolute_placement_constraint (
684+ node .args [0 ], mem_id , None
685+ )
686+
687+
657688# The class to generate all the constraints that will be passed on to the memory
658689# planning algorithm.
659690class GenerateMemConstraints :
@@ -671,6 +702,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
671702 constraint_gen_passes : Sequence [ConstraintsGenPass ] = cast (
672703 list [ConstraintsGenPass ],
673704 [
705+ GenerateIdmaConstraints ,
674706 GenerateMemoryViewConstraints ,
675707 GenerateSliceAndSelectNopConstraints ,
676708 GenerateCatNopConstraints ,
0 commit comments