@@ -1,7 +1,9 @@
 # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

+import logging
 import math
 import unittest
+from typing import cast

 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
@@ -110,7 +112,89 @@ def forward(self, x):


 class TestMemTransform(unittest.TestCase):
-    def test_optimize_cat(self):
+    def verify_cat_nop_memory_alloc(self, cat_node: torch.fx.Node) -> None:
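+        # A nop cat is only feasible when the concat dim is the outermost
+        # non-trivial dim (all leading dims are 1): each input can then alias
+        # a contiguous byte range of the output buffer instead of being
+        # copied. Note that inner_dim_elements below is a byte count
+        # (elements * itemsize). As an illustration: for float32 inputs of
+        # shapes (3, 6) and (2, 6) concatenated on dim 0, the per-row stride
+        # is 6 * 4 = 24 bytes, so the second input must be placed
+        # 3 * 24 = 72 bytes past the output's mem_offset.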
+        spec = cat_node.meta.get("spec", None)
+        self.assertIsNotNone(spec)
+        dim: int = cast(int, cat_node.args[1]) if len(cat_node.args) > 1 else 0
+        cat_outer_size = math.prod(spec.shape[:dim])
+        self.assertEqual(
+            cat_outer_size,
+            1,
+            f"{cat_node=} has wrong outer size: {cat_outer_size=}, expected 1.",
+        )
+        inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize
+        dim_offset = 0
+        for arg in cast(list[torch.fx.Node], cat_node.args[0]):
+            arg_spec = arg.meta.get("spec", None)
+            self.assertEqual(arg_spec.mem_id, spec.mem_id)
+            self.assertEqual(
+                arg_spec.mem_offset,
+                spec.mem_offset + dim_offset * inner_dim_elements,
+                f"{arg=} for node {cat_node=} has wrong memory offset: {arg_spec.mem_offset=} {dim_offset=} for cat on {dim=}, but output has {spec.mem_offset=}",
+            )
+            dim_offset += arg_spec.shape[dim]
+
+    def verify_slice_nop_memory_alloc(self, slice_node: torch.fx.Node) -> None:
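+        # A nop slice is the inverse aliasing: the slice output points into
+        # its input buffer, so the output's mem_offset must equal the input's
+        # mem_offset plus start * (byte size of one step along the slice dim).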
+        spec = slice_node.meta.get("spec", None)
+        self.assertIsNotNone(spec)
+        dim: int = cast(int, slice_node.args[1]) if len(slice_node.args) > 1 else 0
+        slice_outer_size = math.prod(spec.shape[:dim])
+        self.assertEqual(
+            slice_outer_size,
+            1,
+            f"{slice_node=} has wrong outer size: {slice_outer_size=}, expected 1.",
+        )
+        inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize
+        start: int = (
+            cast(int, slice_node.args[2])
+            if (len(slice_node.args) > 2 and slice_node.args[2] is not None)
+            else 0
+        )
+        arg = cast(torch.fx.Node, slice_node.args[0])
+        arg_spec = arg.meta.get("spec", None)
+        self.assertEqual(arg_spec.mem_id, spec.mem_id)
+        self.assertEqual(
+            spec.mem_offset,
+            arg_spec.mem_offset + start * inner_dim_elements,
+            f"{arg=} for node {slice_node=} has wrong memory offset: {arg_spec.mem_offset=} {start=} for slice on {dim=}, but output has {spec.mem_offset=}",
+        )
+
+    def verify_nop_memory_alloc(self, graph_module) -> None:
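+        # Validate every nop cat/slice the optimization produced. In the
+        # "no optimize" tests no such nodes exist, so this passes trivially.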
+        for cat_node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten._cat_nop.out
+        ):
+            self.verify_cat_nop_memory_alloc(cat_node)
+
+        for slice_node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten._slice_copy_nop.Tensor_out
+        ):
+            self.verify_slice_nop_memory_alloc(slice_node)
+
+    def test_optimize_cat_on_placeholders(self):
+        class Cat(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.ops.aten.cat((x, y))
+
+        x = torch.ones(3, 6)
+        y = torch.ones(2, 6)
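+        # cat of (3, 6) and (2, 6) along the default dim 0 yields (5, 6);
+        # since dim 0 is outermost, memory planning can place x and y
+        # back-to-back inside the output buffer and elide the copy.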
+        # Optimizing cat ops only happens at opt_level 2+, and requires the
+        # memory planning pass to run:
+        graph_module = (
+            compiler.export_to_executorch_gen_etrecord(
+                Cat(), (x, y), opt_level=2, mem_algo=1
+            )
+            .exported_program()
+            .graph_module
+        )
+        logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}")
+        graph_module.graph.eliminate_dead_code()
+        # Assert that cat op is optimized away
+        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        # Assert that cat op is replaced by its nop version post optimization
+        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)
+
+    def test_optimize_cat_outermost(self):
         class OptimizeCatFeasible1(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -135,7 +219,9 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
         # Assert that cat op is replaced by its nop version post optimization
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

+    def test_optimize_cat_non_outermost(self):
         class OptimizeCatFeasible2(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -160,7 +246,9 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
         # Assert that cat op is replaced by its nop version post optimization
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

+    def test_no_optimize_cat_non_outermost(self):
         class OptimizeCatInfeasible1(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -184,7 +272,9 @@ def forward(self, x, y):
         # Assert that cat op is not optimized away, since the concat is not
         # along the outermost dim
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

+    def test_no_optimize_cat_non_outermost1(self):
         class OptimizeCatInfeasible2(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -209,6 +299,7 @@ def forward(self, x, y):
         # offsets are not a multiple of 8 bytes, and the cat is not the output
         # of the graph.
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_cat_with_slice(self):
         class OptimizeCatSliceFeasible(torch.nn.Module):
@@ -237,6 +328,7 @@ def forward(self, x):
         graph_module.graph.eliminate_dead_code()
         # Assert that cat op is optimized away
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_cat_with_slice_infeasible(self):
         class OptimizeCatSliceInfeasible(torch.nn.Module):
@@ -262,6 +354,7 @@ def forward(self, x, y):
         graph_module.graph.eliminate_dead_code()
         # Assert that cat op is not optimized away
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_slice_Tensor(self):
         class SliceTensor(torch.nn.Module):
@@ -323,6 +416,7 @@ def forward(self, x, y, z):
         self.assertEqual(
             count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 3
         )
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_select_Tensor(self):
         class SelectTensor(torch.nn.Module):
@@ -387,6 +481,7 @@ def forward(self, x, y, z):
         self.assertEqual(
             count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 3
         )
+        self.verify_nop_memory_alloc(graph_module)

     # TODO: Test fails due to memory planning
     @unittest.expectedFailure
@@ -416,6 +511,32 @@ def forward(self, x, y):
         graph_module.graph.eliminate_dead_code()
         # Assert that cat op is not optimized away
         self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)
+        self.verify_nop_memory_alloc(graph_module)
+
+    def test_optimize_cat_then_slice_on_mutable_buffer(self):
+        class CatWithPadding(torch.nn.Module):
+            def __init__(self, padding_shape):
+                super().__init__()
+                zeros = torch.zeros(padding_shape)
+                self.register_buffer("padding", zeros)
+
+            def forward(self, x, y):
+                x = x.view(3, 5)
+                cat = torch.ops.aten.cat((x, self.padding.clone()))
+                slice_copy = torch.ops.aten.slice(cat, dim=0, start=x.shape[0])
+                self.padding.copy_(slice_copy)
+                return cat.view(-1) + y
+
+        x = torch.ones(15)
+        y = torch.ones(1)
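+        # x.view(3, 5) concatenated with the (1, 5) padding buffer gives a
+        # (4, 5) output; the slice with start=3 on dim 0 selects its last
+        # row, which is copied back into the mutable buffer. The cat should
+        # still become a nop that aliases the planned output memory.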
+        et_prog_manager = compiler.export_to_executorch_gen_etrecord(
+            CatWithPadding((1, 5)), (x, y), opt_level=3
+        )
+        graph_module = et_prog_manager.exported_program().graph_module
+        logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}")
+        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_cat_with_view(self):
         class CatViewFeasible(torch.nn.Module):
@@ -442,6 +563,7 @@ def forward(self, x, y):
         # Assert that cat op is optimized away
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

     def test_no_optimize_cat_with_repeated_args(self):
         class CatViewInfeasible(torch.nn.Module):
@@ -465,6 +587,7 @@ def forward(self, x):
         # Assert that cat op is not optimized away
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

     def test_no_optimize_cat_with_placeholder(self):
         class CatViewInfeasible(torch.nn.Module):
@@ -492,6 +615,7 @@ def forward(self, x, y):
         # Assert that cat op is not optimized away
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

     def test_no_optimize_cat(self) -> None:
         class Model(torch.nn.Module):
@@ -522,6 +646,7 @@ def forward(self, x) -> torch.Tensor:
             count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 2
         )
         self.assertEqual(count_node(graph_module, memory.view), 2)
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_slice_copy(self) -> None:
         class Model(torch.nn.Module):
@@ -553,6 +678,7 @@ def forward(self, x) -> torch.Tensor:
             count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 0
         )
         self.assertEqual(count_node(graph_module, memory.view), 2)
+        self.verify_nop_memory_alloc(graph_module)

     def test_cat_then_cat(self) -> None:
         class Model(torch.nn.Module):
@@ -579,6 +705,7 @@ def forward(self, x) -> torch.Tensor:
         graph_module.print_readable()
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 2)
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

     def test_view_for_unallocated_output(self):
         class Model(torch.nn.Module):
@@ -602,3 +729,4 @@ def forward(self, x, y):
             .graph_module
         )
         self.assertEqual(count_node(graph_module, memory.view), 1)
+        self.verify_nop_memory_alloc(graph_module)