@@ -654,6 +654,37 @@ def compute_slice_and_select_loc_constraints(
654
654
]
655
655
656
656
657
+ @register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
658
+ class GenerateIdmaConstraints (PassBase ):
659
+ """Generate constraints for idma ops."""
660
+
661
+ def __init__ (self , constraint : MemConstraints ) -> None :
662
+ self .constraint = constraint
663
+
664
+ def call (self , graph_module : torch .fx .GraphModule ) -> Optional [PassResult ]:
665
+ for node in graph_module .graph .find_nodes (
666
+ op = "call_function" , target = torch .ops .cadence .idma_wait .out
667
+ ):
668
+ # This is just an alias op.
669
+ self .constraint .add_relative_placement_constraint (node .args [0 ], node )
670
+
671
+ for node in graph_module .graph .find_nodes (
672
+ op = "call_function" , target = torch .ops .cadence .idma_load .out
673
+ ):
674
+ # TODO: set correct dtcm bank here.
675
+ mem_id = 1
676
+ self .constraint .add_absolute_placement_constraint (node , mem_id , None )
677
+
678
+ for node in graph_module .graph .find_nodes (
679
+ op = "call_function" , target = torch .ops .cadence .idma_store .out
680
+ ):
681
+ # TODO: set correct dtcm bank here.
682
+ mem_id = 1
683
+ self .constraint .add_absolute_placement_constraint (
684
+ node .args [0 ], mem_id , None
685
+ )
686
+
687
+
657
688
# The class to generate all the constraints that will be passed on to the memory
658
689
# planning algorithm.
659
690
class GenerateMemConstraints :
@@ -671,6 +702,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
671
702
constraint_gen_passes : Sequence [ConstraintsGenPass ] = cast (
672
703
list [ConstraintsGenPass ],
673
704
[
705
+ GenerateIdmaConstraints ,
674
706
GenerateMemoryViewConstraints ,
675
707
GenerateSliceAndSelectNopConstraints ,
676
708
GenerateCatNopConstraints ,
0 commit comments