1414from  executorch .backends .cadence .aot .pass_utils  import  count_node 
1515from  executorch .exir  import  memory 
1616from  executorch .exir .dialects ._ops  import  ops  as  exir_ops 
17+ from  executorch .exir .memory_planning  import  collect_specs_from_nodes 
1718from  executorch .exir .tests .models  import  MultiLayerPerceptron 
1819
1920
2021class  TestMemPlanningPasses (unittest .TestCase ):
21-     def  test_calculate_peak_memory_pass (self ):
22+     def  test_calculate_peak_memory_pass (self )  ->   None :
2223        class  PeakMemoryTestModel (torch .nn .Module ):
2324            def  __init__ (self , input_dim : int , hidden_dim : int , output_dim : int ):
2425                super ().__init__ ()
@@ -32,7 +33,7 @@ def forward(self, x: torch.Tensor):
3233                x  =  self .linear2 (x )
3334                return  x 
3435
def calculate_aligned_num_bytes(num: int, alignment: int = 16) -> int:
    """Round ``num`` up to the nearest multiple of ``alignment``.

    Uses pure integer arithmetic: ``math.ceil(num / alignment)`` routes
    through float division, which silently loses precision once ``num``
    exceeds 2**53; ``(num + alignment - 1) // alignment`` is exact for
    all non-negative ints and needs no import.
    """
    return ((num + alignment - 1) // alignment) * alignment
3738
3839        # model 1 
@@ -86,7 +87,7 @@ def calculate_aligned_num_bytes(num: int, alignment: int = 16):
8687        )  # Align data on a 16 byte boundary 
8788        self .assertEqual (peak_usage , expected_peak_usage )
8889
89-     def  test_zero_memory_pass (self ):
90+     def  test_zero_memory_pass (self )  ->   None :
9091        class  ZeroMem (torch .nn .Module ):
9192            def  forward (self , x ):
9293                return  x [:, 2 ::3 , ...]
@@ -188,7 +189,7 @@ def _verify_select_nop_memory_alloc(self, node: torch.fx.Node) -> None:
188189            f"{ spec = }   { arg_spec = }  " ,
189190        )
190191
191-     def  verify_nop_memory_alloc (self , graph_module ) :
192+     def  verify_nop_memory_alloc (self , graph_module :  torch . fx . GraphModule )  ->   None :
192193        for  node  in  graph_module .graph .find_nodes (
193194            op = "call_function" , target = torch .ops .aten ._cat_nop .out 
194195        ):
@@ -204,7 +205,7 @@ def verify_nop_memory_alloc(self, graph_module):
204205        ):
205206            self ._verify_select_nop_memory_alloc (node )
206207
207-     def  test_optimize_cat_on_placeholders (self ):
208+     def  test_optimize_cat_on_placeholders (self )  ->   None :
208209        class  Cat (torch .nn .Module ):
209210            def  forward (self , x , y ):
210211                return  torch .ops .aten .cat ((x , y ))
@@ -228,7 +229,7 @@ def forward(self, x, y):
228229        self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 1 )
229230        self .verify_nop_memory_alloc (graph_module )
230231
231-     def  test_optimize_cat_outermost (self ):
232+     def  test_optimize_cat_outermost (self )  ->   None :
232233        class  OptimizeCatFeasible1 (torch .nn .Module ):
233234            def  forward (self , x , y ):
234235                x1  =  torch .add (x , 2.4 , 3.1 )
@@ -255,7 +256,7 @@ def forward(self, x, y):
255256        self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 1 )
256257        self .verify_nop_memory_alloc (graph_module )
257258
258-     def  test_optimize_cat_non_outermost (self ):
259+     def  test_optimize_cat_non_outermost (self )  ->   None :
259260        class  OptimizeCatFeasible2 (torch .nn .Module ):
260261            def  forward (self , x , y ):
261262                x1  =  torch .add (x , 2.4 , 3.1 )
@@ -282,7 +283,7 @@ def forward(self, x, y):
282283        self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 1 )
283284        self .verify_nop_memory_alloc (graph_module )
284285
285-     def  test_no_optimize_cat_non_outermost (self ):
286+     def  test_no_optimize_cat_non_outermost (self )  ->   None :
286287        class  OptimizeCatInfeasible1 (torch .nn .Module ):
287288            def  forward (self , x , y ):
288289                x1  =  torch .add (x , 2.4 , 3.1 )
@@ -308,7 +309,7 @@ def forward(self, x, y):
308309        self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 1 )
309310        self .verify_nop_memory_alloc (graph_module )
310311
311-     def  test_no_optimize_cat_non_outermost1 (self ):
312+     def  test_no_optimize_cat_non_outermost1 (self )  ->   None :
312313        class  OptimizeCatInfeasible2 (torch .nn .Module ):
313314            def  forward (self , x , y ):
314315                x1  =  torch .add (x , 2.4 , 3.1 )
@@ -335,7 +336,7 @@ def forward(self, x, y):
335336        self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 1 )
336337        self .verify_nop_memory_alloc (graph_module )
337338
338-     def  test_optimize_cat_with_slice (self ):
339+     def  test_optimize_cat_with_slice (self )  ->   None :
339340        class  OptimizeCatSliceFeasible (torch .nn .Module ):
340341            def  forward (self , x ):
341342                x1  =  torch .add (x , 2.4 , 3.1 )
@@ -364,7 +365,7 @@ def forward(self, x):
364365        self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 1 )
365366        self .verify_nop_memory_alloc (graph_module )
366367
367-     def  test_optimize_cat_with_slice_infeasible (self ):
368+     def  test_optimize_cat_with_slice_infeasible (self )  ->   None :
368369        class  OptimizeCatSliceInfeasible (torch .nn .Module ):
369370            def  forward (self , x , y ):
370371                x1  =  torch .add (x , 2.4 , 3.1 )
@@ -390,7 +391,7 @@ def forward(self, x, y):
390391        self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 1 )
391392        self .verify_nop_memory_alloc (graph_module )
392393
393-     def  test_optimize_slice_Tensor (self ):
394+     def  test_optimize_slice_Tensor (self )  ->   None :
394395        class  SliceTensor (torch .nn .Module ):
395396            def  forward (self , x , y , z ):
396397                x1  =  torch .add (x , 2.4 , 3.1 )
@@ -452,7 +453,7 @@ def forward(self, x, y, z):
452453        )
453454        self .verify_nop_memory_alloc (graph_module )
454455
455-     def  test_optimize_select_Tensor (self ):
456+     def  test_optimize_select_Tensor (self )  ->   None :
456457        class  SelectTensor (torch .nn .Module ):
457458            def  forward (self , x , y , z ):
458459                x1  =  torch .add (x , 2.4 , 3.1 )
@@ -519,7 +520,7 @@ def forward(self, x, y, z):
519520
520521    # TODO: Test fails due to memory planning 
521522    @unittest .expectedFailure  
522-     def  test_optimize_cat_with_param (self ):
523+     def  test_optimize_cat_with_param (self )  ->   None :
523524        class  CatWithPadding (torch .nn .Module ):
524525            def  __init__ (self , padding_shape ):
525526                super ().__init__ ()
@@ -547,7 +548,7 @@ def forward(self, x, y):
547548        self .assertEqual (count_node (graph_module , exir_ops .edge .aten .cat .default ), 1 )
548549        self .verify_nop_memory_alloc (graph_module )
549550
550-     def  test_optimize_cat_then_slice_on_mutable_buffer (self ):
551+     def  test_optimize_cat_then_slice_on_mutable_buffer (self )  ->   None :
551552        class  CatWithPadding (torch .nn .Module ):
552553            def  __init__ (self , padding_shape ):
553554                super ().__init__ ()
@@ -572,7 +573,7 @@ def forward(self, x, y):
572573        self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 1 )
573574        self .verify_nop_memory_alloc (graph_module )
574575
575-     def  test_optimize_cat_with_view (self ):
576+     def  test_optimize_cat_with_view (self )  ->   None :
576577        class  CatViewFeasible (torch .nn .Module ):
577578            def  forward (self , x , y ):
578579                x1  =  torch .add (x , 2.4 , 3.1 )
@@ -599,7 +600,7 @@ def forward(self, x, y):
599600        self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 0 )
600601        self .verify_nop_memory_alloc (graph_module )
601602
602-     def  test_no_optimize_cat_with_repeated_args (self ):
603+     def  test_no_optimize_cat_with_repeated_args (self )  ->   None :
603604        class  CatViewInfeasible (torch .nn .Module ):
604605            def  forward (self , x ):
605606                x1  =  torch .add (x , 2.4 , 3.1 )
@@ -623,7 +624,7 @@ def forward(self, x):
623624        self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 0 )
624625        self .verify_nop_memory_alloc (graph_module )
625626
626-     def  test_no_optimize_cat_with_placeholder (self ):
627+     def  test_no_optimize_cat_with_placeholder (self )  ->   None :
627628        class  CatViewInfeasible (torch .nn .Module ):
628629            def  forward (self , x , y ):
629630                # Repeat will be decomposed into a cat. The cat cannot be optimized 
@@ -741,7 +742,7 @@ def forward(self, x) -> torch.Tensor:
741742        self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 0 )
742743        self .verify_nop_memory_alloc (graph_module )
743744
744-     def  test_view_for_unallocated_output (self ):
745+     def  test_view_for_unallocated_output (self )  ->   None :
745746        class  Model (torch .nn .Module ):
746747            def  __init__ (self , padding_shape ):
747748                super ().__init__ ()
@@ -764,3 +765,40 @@ def forward(self, x, y):
764765        )
765766        self .assertEqual (count_node (graph_module , memory .view ), 1 )
766767        self .verify_nop_memory_alloc (graph_module )
768+ 
769+     def  test_start_alignment_constraints (self ) ->  None :
770+         class  Model (torch .nn .Module ):
771+             def  __init__ (self ):
772+                 super ().__init__ ()
773+ 
774+             def  forward (self , x : torch .Tensor , y : torch .Tensor ):
775+                 add_0  =  torch .add (x , y )
776+                 add_1  =  torch .add (x , add_0 )
777+                 add_2  =  torch .add (add_0 , add_1 )
778+                 add_3  =  torch .add (add_1 , add_2 )
779+                 return  add_3 
780+ 
781+         model  =  Model ()
782+         inputs  =  (torch .randn (4 , 17 ), torch .randn (4 , 17 ))
783+         for  mem_algo  in  range (0 , 2 ):
784+             graph_module  =  (
785+                 compiler .export_to_executorch_gen_etrecord (
786+                     model ,
787+                     inputs ,
788+                     opt_level = 1 ,
789+                     mem_algo = mem_algo ,
790+                     alloc_graph_input = False ,
791+                     alloc_graph_output = False ,
792+                     mem_alignment = 37 ,
793+                 )
794+                 .exported_program ()
795+                 .graph_module 
796+             )
797+             # Assert that all memory allocations are aligned to 32B start address 
798+             for  spec  in  collect_specs_from_nodes (
799+                 graph_module .graph .nodes ,
800+                 ignore_graph_input = True ,
801+                 ignore_graph_output = True ,
802+             ):
803+                 if  spec  and  spec .mem_offset :
804+                     self .assertEqual (spec .mem_offset  %  37 , 0 )
0 commit comments