from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.exir import memory
from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.memory_planning import collect_specs_from_nodes
from executorch.exir.tests.models import MultiLayerPerceptron


class TestMemPlanningPasses(unittest.TestCase):
-    def test_calculate_peak_memory_pass(self):
+    def test_calculate_peak_memory_pass(self) -> None:
        class PeakMemoryTestModel(torch.nn.Module):
            def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
                super().__init__()
@@ -32,7 +33,7 @@ def forward(self, x: torch.Tensor):
                x = self.linear2(x)
                return x

-        def calculate_aligned_num_bytes(num: int, alignment: int = 16):
+        def calculate_aligned_num_bytes(num: int, alignment: int = 16) -> int:
            return math.ceil(num / alignment) * alignment

        # model 1
@@ -86,7 +87,7 @@ def calculate_aligned_num_bytes(num: int, alignment: int = 16):
        )  # Align data on a 16 byte boundary
        self.assertEqual(peak_usage, expected_peak_usage)

-    def test_zero_memory_pass(self):
+    def test_zero_memory_pass(self) -> None:
        class ZeroMem(torch.nn.Module):
            def forward(self, x):
                return x[:, 2::3, ...]
@@ -188,7 +189,7 @@ def _verify_select_nop_memory_alloc(self, node: torch.fx.Node) -> None:
            f"{spec=} {arg_spec=}",
        )

-    def verify_nop_memory_alloc(self, graph_module):
+    def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None:
        for node in graph_module.graph.find_nodes(
            op="call_function", target=torch.ops.aten._cat_nop.out
        ):
@@ -204,7 +205,7 @@ def verify_nop_memory_alloc(self, graph_module):
        ):
            self._verify_select_nop_memory_alloc(node)

-    def test_optimize_cat_on_placeholders(self):
+    def test_optimize_cat_on_placeholders(self) -> None:
        class Cat(torch.nn.Module):
            def forward(self, x, y):
                return torch.ops.aten.cat((x, y))
@@ -228,7 +229,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_outermost(self):
+    def test_optimize_cat_outermost(self) -> None:
        class OptimizeCatFeasible1(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -255,7 +256,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_non_outermost(self):
+    def test_optimize_cat_non_outermost(self) -> None:
        class OptimizeCatFeasible2(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -282,7 +283,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_non_outermost(self):
+    def test_no_optimize_cat_non_outermost(self) -> None:
        class OptimizeCatInfeasible1(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -308,7 +309,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_non_outermost1(self):
+    def test_no_optimize_cat_non_outermost1(self) -> None:
        class OptimizeCatInfeasible2(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -335,7 +336,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_with_slice(self):
+    def test_optimize_cat_with_slice(self) -> None:
        class OptimizeCatSliceFeasible(torch.nn.Module):
            def forward(self, x):
                x1 = torch.add(x, 2.4, 3.1)
@@ -364,7 +365,7 @@ def forward(self, x):
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_with_slice_infeasible(self):
+    def test_optimize_cat_with_slice_infeasible(self) -> None:
        class OptimizeCatSliceInfeasible(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -390,7 +391,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_slice_Tensor(self):
+    def test_optimize_slice_Tensor(self) -> None:
        class SliceTensor(torch.nn.Module):
            def forward(self, x, y, z):
                x1 = torch.add(x, 2.4, 3.1)
@@ -452,7 +453,7 @@ def forward(self, x, y, z):
        )
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_select_Tensor(self):
+    def test_optimize_select_Tensor(self) -> None:
        class SelectTensor(torch.nn.Module):
            def forward(self, x, y, z):
                x1 = torch.add(x, 2.4, 3.1)
@@ -519,7 +520,7 @@ def forward(self, x, y, z):

    # TODO: Test fails due to memory planning
    @unittest.expectedFailure
-    def test_optimize_cat_with_param(self):
+    def test_optimize_cat_with_param(self) -> None:
        class CatWithPadding(torch.nn.Module):
            def __init__(self, padding_shape):
                super().__init__()
@@ -547,7 +548,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_then_slice_on_mutable_buffer(self):
+    def test_optimize_cat_then_slice_on_mutable_buffer(self) -> None:
        class CatWithPadding(torch.nn.Module):
            def __init__(self, padding_shape):
                super().__init__()
@@ -572,7 +573,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_with_view(self):
+    def test_optimize_cat_with_view(self) -> None:
        class CatViewFeasible(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -599,7 +600,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
        self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_with_repeated_args(self):
+    def test_no_optimize_cat_with_repeated_args(self) -> None:
        class CatViewInfeasible(torch.nn.Module):
            def forward(self, x):
                x1 = torch.add(x, 2.4, 3.1)
@@ -623,7 +624,7 @@ def forward(self, x):
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
        self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_with_placeholder(self):
+    def test_no_optimize_cat_with_placeholder(self) -> None:
        class CatViewInfeasible(torch.nn.Module):
            def forward(self, x, y):
                # Repeat will be decomposed into a cat. The cat cannot be optimized
@@ -741,7 +742,7 @@ def forward(self, x) -> torch.Tensor:
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
        self.verify_nop_memory_alloc(graph_module)

-    def test_view_for_unallocated_output(self):
+    def test_view_for_unallocated_output(self) -> None:
        class Model(torch.nn.Module):
            def __init__(self, padding_shape):
                super().__init__()
@@ -764,3 +765,40 @@ def forward(self, x, y):
        )
        self.assertEqual(count_node(graph_module, memory.view), 1)
        self.verify_nop_memory_alloc(graph_module)
+
+    def test_start_alignment_constraints(self) -> None:
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x: torch.Tensor, y: torch.Tensor):
+                add_0 = torch.add(x, y)
+                add_1 = torch.add(x, add_0)
+                add_2 = torch.add(add_0, add_1)
+                add_3 = torch.add(add_1, add_2)
+                return add_3
+
+        model = Model()
+        inputs = (torch.randn(4, 17), torch.randn(4, 17))
+        for mem_algo in range(0, 2):
+            graph_module = (
+                compiler.export_to_executorch_gen_etrecord(
+                    model,
+                    inputs,
+                    opt_level=1,
+                    mem_algo=mem_algo,
+                    alloc_graph_input=False,
+                    alloc_graph_output=False,
+                    mem_alignment=37,
+                )
+                .exported_program()
+                .graph_module
+            )
+            # Assert that all memory allocations start at an offset aligned to mem_alignment (37 bytes)
+            for spec in collect_specs_from_nodes(
+                graph_module.graph.nodes,
+                ignore_graph_input=True,
+                ignore_graph_output=True,
+            ):
+                if spec and spec.mem_offset:
+                    self.assertEqual(spec.mem_offset % 37, 0)
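
As a side note, a minimal standalone sketch of the alignment arithmetic these tests exercise (the calculate_aligned_num_bytes helper is the one from the diff above; the sample sizes and the 37-byte alignment are illustrative values only): buffer sizes are rounded up to the next multiple of the alignment, and the new test_start_alignment_constraints then checks that every planned mem_offset is a multiple of mem_alignment.

import math

def calculate_aligned_num_bytes(num: int, alignment: int = 16) -> int:
    # Round a byte count up to the next multiple of `alignment`.
    return math.ceil(num / alignment) * alignment

assert calculate_aligned_num_bytes(100) == 112  # 100 B rounded up to a 16 B boundary
assert calculate_aligned_num_bytes(112) == 112  # already aligned, left unchanged
assert calculate_aligned_num_bytes(40, alignment=37) == 74  # non-power-of-two alignments work the same way
assert 74 % 37 == 0  # the start-offset check made by test_start_alignment_constraints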