 # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

+# pyre-strict
+
 import logging
 import math
 import unittest
 from executorch.exir import memory
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.tests.models import MultiLayerPerceptron
+from executorch.exir.memory_planning import collect_specs_from_nodes


 class TestMemPlanningPasses(unittest.TestCase):
-    def test_calculate_peak_memory_pass(self):
+    def test_calculate_peak_memory_pass(self) -> None:
         class PeakMemoryTestModel(torch.nn.Module):
             def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
                 super().__init__()
@@ -30,7 +33,7 @@ def forward(self, x: torch.Tensor):
                 x = self.linear2(x)
                 return x

-        def calculate_aligned_num_bytes(num: int, alignment: int = 16):
+        def calculate_aligned_num_bytes(num: int, alignment: int = 16) -> int:
             return math.ceil(num / alignment) * alignment

         # model 1
@@ -84,7 +87,7 @@ def calculate_aligned_num_bytes(num: int, alignment: int = 16):
         )  # Align data on a 16 byte boundary
         self.assertEqual(peak_usage, expected_peak_usage)

-    def test_zero_memory_pass(self):
+    def test_zero_memory_pass(self) -> None:
         class ZeroMem(torch.nn.Module):
             def forward(self, x):
                 return x[:, 2::3, ...]
@@ -186,7 +189,7 @@ def _verify_select_nop_memory_alloc(self, node: torch.fx.Node) -> None:
                 f"{spec=} {arg_spec=}",
             )

-    def verify_nop_memory_alloc(self, graph_module):
+    def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None:
         for node in graph_module.graph.find_nodes(
             op="call_function", target=torch.ops.aten._cat_nop.out
         ):
@@ -202,7 +205,7 @@ def verify_nop_memory_alloc(self, graph_module):
         ):
             self._verify_select_nop_memory_alloc(node)

-    def test_optimize_cat_on_placeholders(self):
+    def test_optimize_cat_on_placeholders(self) -> None:
         class Cat(torch.nn.Module):
             def forward(self, x, y):
                 return torch.ops.aten.cat((x, y))
@@ -226,7 +229,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_outermost(self):
+    def test_optimize_cat_outermost(self) -> None:
         class OptimizeCatFeasible1(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -253,7 +256,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_non_outermost(self):
+    def test_optimize_cat_non_outermost(self) -> None:
         class OptimizeCatFeasible2(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -280,7 +283,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_non_outermost(self):
+    def test_no_optimize_cat_non_outermost(self) -> None:
         class OptimizeCatInfeasible1(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -306,7 +309,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_non_outermost1(self):
+    def test_no_optimize_cat_non_outermost1(self) -> None:
         class OptimizeCatInfeasible2(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -333,7 +336,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_with_slice(self):
+    def test_optimize_cat_with_slice(self) -> None:
         class OptimizeCatSliceFeasible(torch.nn.Module):
             def forward(self, x):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -362,7 +365,7 @@ def forward(self, x):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_with_slice_infeasible(self):
+    def test_optimize_cat_with_slice_infeasible(self) -> None:
         class OptimizeCatSliceInfeasible(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -388,7 +391,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_slice_Tensor(self):
+    def test_optimize_slice_Tensor(self) -> None:
         class SliceTensor(torch.nn.Module):
             def forward(self, x, y, z):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -450,7 +453,7 @@ def forward(self, x, y, z):
         )
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_select_Tensor(self):
+    def test_optimize_select_Tensor(self) -> None:
         class SelectTensor(torch.nn.Module):
             def forward(self, x, y, z):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -517,7 +520,7 @@ def forward(self, x, y, z):

     # TODO: Test fails due to memory planning
     @unittest.expectedFailure
-    def test_optimize_cat_with_param(self):
+    def test_optimize_cat_with_param(self) -> None:
         class CatWithPadding(torch.nn.Module):
             def __init__(self, padding_shape):
                 super().__init__()
@@ -545,7 +548,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_then_slice_on_mutable_buffer(self):
+    def test_optimize_cat_then_slice_on_mutable_buffer(self) -> None:
         class CatWithPadding(torch.nn.Module):
             def __init__(self, padding_shape):
                 super().__init__()
@@ -570,7 +573,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_with_view(self):
+    def test_optimize_cat_with_view(self) -> None:
         class CatViewFeasible(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -597,7 +600,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
         self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_with_repeated_args(self):
+    def test_no_optimize_cat_with_repeated_args(self) -> None:
         class CatViewInfeasible(torch.nn.Module):
             def forward(self, x):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -621,7 +624,7 @@ def forward(self, x):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
         self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_with_placeholder(self):
+    def test_no_optimize_cat_with_placeholder(self) -> None:
         class CatViewInfeasible(torch.nn.Module):
             def forward(self, x, y):
                 # Repeat will be decomposed into a cat. The cat cannot be optimized
@@ -739,7 +742,7 @@ def forward(self, x) -> torch.Tensor:
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
         self.verify_nop_memory_alloc(graph_module)

-    def test_view_for_unallocated_output(self):
+    def test_view_for_unallocated_output(self) -> None:
         class Model(torch.nn.Module):
             def __init__(self, padding_shape):
                 super().__init__()
@@ -762,3 +765,41 @@ def forward(self, x, y):
         )
         self.assertEqual(count_node(graph_module, memory.view), 1)
         self.verify_nop_memory_alloc(graph_module)
+
+    def test_start_alignment_constraints(self) -> None:
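+        """Check that memory planning honors a custom start-address alignment:
+        with mem_alignment=37, every planned tensor offset must be a multiple of 37."""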
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x: torch.Tensor, y: torch.Tensor):
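+                # Chained adds keep several intermediates live at once, so the
+                # planner must place multiple buffers at aligned offsets.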
+                add_0 = torch.add(x, y)
+                add_1 = torch.add(x, add_0)
+                add_2 = torch.add(add_0, add_1)
+                add_3 = torch.add(add_1, add_2)
+                return add_3
+
+        model = Model()
+        inputs = (torch.randn(4, 17), torch.randn(4, 17))
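+        # Exercise each of the two memory planning algorithms, selected by index.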
+        for mem_algo in range(0, 2):
+            graph_module = (
+                compiler.export_to_executorch_gen_etrecord(
+                    model,
+                    inputs,
+                    opt_level=1,
+                    mem_algo=mem_algo,
+                    alloc_graph_input=False,
+                    alloc_graph_output=False,
+                    mem_alignment=37,
+                )
+                .exported_program()
+                .graph_module
+            )
+            # Assert that all memory allocations start at addresses aligned to 37 bytes
+            for spec in collect_specs_from_nodes(
+                graph_module.graph.nodes,
+                ignore_graph_input=True,
+                ignore_graph_output=True,
+            ):
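+                # Skip missing specs and zero offsets; an offset of 0 is trivially aligned.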
+                if spec and spec.mem_offset:
+                    self.assertEqual(spec.mem_offset % 37, 0)