11# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
22
3+ import logging
34import math
45import unittest
6+ from typing import cast
57
68import executorch .backends .cadence .aot .ops_registrations # noqa
79import torch
@@ -110,7 +112,119 @@ def forward(self, x):
110112
111113
112114class TestMemTransform (unittest .TestCase ):
113- def test_optimize_cat (self ):
115+ def verify_cat_nop_memory_alloc (self , cat_node : torch .fx .Node ) -> None :
116+ spec = cat_node .meta .get ("spec" , None )
117+ self .assertIsNotNone (spec )
118+ dim : int = cast (int , cat_node .args [1 ]) if len (cat_node .args ) > 1 else 0
119+ cat_outer_size = math .prod (spec .shape [:dim ])
120+ self .assertEqual (
121+ cat_outer_size ,
122+ 1 ,
123+ f"{ cat_node = } has wrong outer size: { cat_outer_size = } , expected 1." ,
124+ )
125+ inner_dim_elements = math .prod (spec .shape [dim + 1 :]) * spec .dtype .itemsize
126+ dim_offset = 0
127+ for arg in cast (list [torch .fx .Node ], cat_node .args [0 ]):
128+ arg_spec = arg .meta .get ("spec" , None )
129+ self .assertEqual (arg_spec .mem_id , spec .mem_id )
130+ self .assertEqual (
131+ arg_spec .mem_offset ,
132+ spec .mem_offset + dim_offset * inner_dim_elements ,
133+ f"{ arg = } for node { cat_node = } has wrong memory offset: { arg_spec .mem_offset = } { dim_offset = } for cat on { dim = } , but output has { spec .mem_offset = } " ,
134+ )
135+ dim_offset += arg_spec .shape [dim ]
136+
137+ def verify_slice_nop_memory_alloc (self , slice_node : torch .fx .Node ) -> None :
138+ spec = slice_node .meta .get ("spec" , None )
139+ self .assertIsNotNone (spec )
140+ dim : int = cast (int , slice_node .args [1 ]) if len (slice_node .args ) > 1 else 0
141+ cat_outer_size = math .prod (spec .shape [:dim ])
142+ self .assertEqual (
143+ cat_outer_size ,
144+ 1 ,
145+ f"{ slice_node = } has wrong outer size: { cat_outer_size = } , expected 1." ,
146+ )
147+ inner_dim_elements = math .prod (spec .shape [dim + 1 :]) * spec .dtype .itemsize
148+ start : int = (
149+ cast (int , slice_node .args [2 ])
150+ if (len (slice_node .args ) > 2 and slice_node .args [2 ] is not None )
151+ else 0
152+ )
153+ arg = cast (torch .fx .Node , slice_node .args [0 ])
154+ arg_spec = arg .meta .get ("spec" , None )
155+ self .assertEqual (arg_spec .mem_id , spec .mem_id )
156+ self .assertEqual (
157+ spec .mem_offset ,
158+ arg_spec .mem_offset + start * inner_dim_elements ,
159+ f"{ arg = } for node { slice_node = } has wrong memory offset: { arg_spec .mem_offset = } { start = } for cat on { dim = } , but output has { spec .mem_offset = } " ,
160+ )
161+
162+ def verify_select_nop_memory_alloc (self , select_node : torch .fx .Node ) -> None :
163+ spec = select_node .meta .get ("spec" , None )
164+ self .assertIsNotNone (spec )
165+ dim : int = cast (int , select_node .args [1 ]) if len (select_node .args ) > 1 else 0
166+ cat_outer_size = math .prod (spec .shape [:dim ])
167+ self .assertEqual (
168+ cat_outer_size ,
169+ 1 ,
170+ f"{ select_node = } has wrong outer size: { cat_outer_size = } , expected 1." ,
171+ )
172+ inner_dim_elements = math .prod (spec .shape [dim + 1 :]) * spec .dtype .itemsize
173+ index : int = (
174+ cast (int , select_node .args [2 ])
175+ if (len (select_node .args ) > 2 and select_node .args [2 ] is not None )
176+ else 0
177+ )
178+ arg = cast (torch .fx .Node , select_node .args [0 ])
179+ arg_spec = arg .meta .get ("spec" , None )
180+ self .assertEqual (arg_spec .mem_id , spec .mem_id )
181+ self .assertEqual (
182+ spec .mem_offset ,
183+ arg_spec .mem_offset + index * inner_dim_elements ,
184+ f"{ arg = } for node { select_node = } has wrong memory offset: { arg_spec .mem_offset = } { start = } for cat on { dim = } , but output has { spec .mem_offset = } " ,
185+ )
186+
187+ def verify_nop_memory_alloc (self , graph_module ):
188+ for cat_node in graph_module .graph .find_nodes (
189+ op = "call_function" , target = torch .ops .aten ._cat_nop .out
190+ ):
191+ self .verify_cat_nop_memory_alloc (cat_node )
192+
193+ for slice_node in graph_module .graph .find_nodes (
194+ op = "call_function" , target = torch .ops .aten ._slice_copy_nop .Tensor_out
195+ ):
196+ self .verify_slice_nop_memory_alloc (slice_node )
197+
198+ for select_node in graph_module .graph .find_nodes (
199+ op = "call_function" , target = torch .ops .aten ._select_copy_nop .Tensor_out
200+ ):
201+ self .verify_select_nop_memory_alloc (slice_node )
202+
203+ def test_optimize_cat_on_placeholders (self ):
204+ class Cat (torch .nn .Module ):
205+ def forward (self , x , y ):
206+ return torch .ops .aten .cat ((x , y ))
207+
208+ x = torch .ones (3 , 6 )
209+ y = torch .ones (2 , 6 )
210+ # Optimizing cat ops is only at opt_level 2+, and requires the memory planning
211+ # pass to run:
212+ graph_module = (
213+ compiler .export_to_executorch_gen_etrecord (
214+ Cat (), (x , y ), opt_level = 2 , mem_algo = 1
215+ )
216+ .exported_program ()
217+ .graph_module
218+ )
219+ logging .info (f"graph_module: { graph_module .print_readable (print_output = False )} " )
220+ graph_module .graph .eliminate_dead_code ()
221+ # Assert that cat op is optimized away
222+ self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 0 )
223+ # Assert that cat op is replaced by its nop version post optimization
224+ self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 1 )
225+ self .verify_nop_memory_alloc (graph_module )
226+
227+ def test_optimize_cat_outermost (self ):
114228 class OptimizeCatFeasible1 (torch .nn .Module ):
115229 def forward (self , x , y ):
116230 x1 = torch .add (x , 2.4 , 3.1 )
@@ -135,7 +249,9 @@ def forward(self, x, y):
135249 self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 0 )
136250 # Assert that cat op is replaced by its nop version post optimization
137251 self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 1 )
252+ self .verify_nop_memory_alloc (graph_module )
138253
254+ def test_optimize_cat_non_outermost (self ):
139255 class OptimizeCatFeasible2 (torch .nn .Module ):
140256 def forward (self , x , y ):
141257 x1 = torch .add (x , 2.4 , 3.1 )
@@ -160,7 +276,9 @@ def forward(self, x, y):
160276 self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 0 )
161277 # Assert that cat op is replaced by its nop version post optimization
162278 self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 1 )
279+ self .verify_nop_memory_alloc (graph_module )
163280
281+ def test_no_optimize_cat_non_outermost (self ):
164282 class OptimizeCatInfeasible1 (torch .nn .Module ):
165283 def forward (self , x , y ):
166284 x1 = torch .add (x , 2.4 , 3.1 )
@@ -184,7 +302,9 @@ def forward(self, x, y):
184302 # Assert that cat op is not optimized away, since the concat is not
185303 # along the outermost dim
186304 self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 1 )
305+ self .verify_nop_memory_alloc (graph_module )
187306
307+ def test_no_optimize_cat_non_outermost1 (self ):
188308 class OptimizeCatInfeasible2 (torch .nn .Module ):
189309 def forward (self , x , y ):
190310 x1 = torch .add (x , 2.4 , 3.1 )
@@ -209,6 +329,7 @@ def forward(self, x, y):
209329 # offsets are not multiple of 8 bytes, and the cat is not the output
210330 # of the graph.
211331 self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 1 )
332+ self .verify_nop_memory_alloc (graph_module )
212333
213334 def test_optimize_cat_with_slice (self ):
214335 class OptimizeCatSliceFeasible (torch .nn .Module ):
@@ -237,6 +358,7 @@ def forward(self, x):
237358 graph_module .graph .eliminate_dead_code ()
238359 # Assert that cat op is optimized away
239360 self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 1 )
361+ self .verify_nop_memory_alloc (graph_module )
240362
241363 def test_optimize_cat_with_slice_infeasible (self ):
242364 class OptimizeCatSliceInfeasible (torch .nn .Module ):
@@ -262,6 +384,7 @@ def forward(self, x, y):
262384 graph_module .graph .eliminate_dead_code ()
263385 # Assert that cat op is not optimized away
264386 self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 1 )
387+ self .verify_nop_memory_alloc (graph_module )
265388
266389 def test_optimize_slice_Tensor (self ):
267390 class SliceTensor (torch .nn .Module ):
@@ -323,6 +446,7 @@ def forward(self, x, y, z):
323446 self .assertEqual (
324447 count_node (graph_module , torch .ops .aten ._slice_copy_nop .Tensor_out ), 3
325448 )
449+ self .verify_nop_memory_alloc (graph_module )
326450
327451 def test_optimize_select_Tensor (self ):
328452 class SelectTensor (torch .nn .Module ):
@@ -387,6 +511,7 @@ def forward(self, x, y, z):
387511 self .assertEqual (
388512 count_node (graph_module , torch .ops .aten ._select_copy_nop .int_out ), 3
389513 )
514+ self .verify_nop_memory_alloc (graph_module )
390515
391516 # TODO: Test fails due to memory planning
392517 @unittest .expectedFailure
@@ -416,6 +541,32 @@ def forward(self, x, y):
416541 graph_module .graph .eliminate_dead_code ()
417542 # Assert that cat op is not optimized away
418543 self .assertEqual (count_node (graph_module , exir_ops .edge .aten .cat .default ), 1 )
544+ self .verify_nop_memory_alloc (graph_module )
545+
546+ def test_optimize_cat_then_slice_on_mutable_buffer (self ):
547+ class CatWithPadding (torch .nn .Module ):
548+ def __init__ (self , padding_shape ):
549+ super ().__init__ ()
550+ zeros = torch .zeros (padding_shape )
551+ self .register_buffer ("padding" , zeros )
552+
553+ def forward (self , x , y ):
554+ x = x .view (3 , 5 )
555+ cat = torch .ops .aten .cat ((x , self .padding .clone ()))
556+ slice_copy = torch .ops .aten .slice (cat , dim = 0 , start = x .shape [0 ])
557+ self .padding .copy_ (slice_copy )
558+ return cat .view (- 1 ) + y
559+
560+ x = torch .ones (15 )
561+ y = torch .ones (1 )
562+ et_prog_manager = compiler .export_to_executorch_gen_etrecord (
563+ CatWithPadding ((1 , 5 )), (x , y ), opt_level = 3
564+ )
565+ graph_module = et_prog_manager .exported_program ().graph_module
566+ logging .info (f"graph_module: { graph_module .print_readable (print_output = False )} " )
567+ self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 0 )
568+ self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 1 )
569+ self .verify_nop_memory_alloc (graph_module )
419570
420571 def test_optimize_cat_with_view (self ):
421572 class CatViewFeasible (torch .nn .Module ):
@@ -442,6 +593,7 @@ def forward(self, x, y):
442593 # Assert that cat op is optimized away
443594 self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 1 )
444595 self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 0 )
596+ self .verify_nop_memory_alloc (graph_module )
445597
446598 def test_no_optimize_cat_with_repeated_args (self ):
447599 class CatViewInfeasible (torch .nn .Module ):
@@ -465,6 +617,7 @@ def forward(self, x):
465617 # Assert that cat op is not optimized away
466618 self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 1 )
467619 self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 0 )
620+ self .verify_nop_memory_alloc (graph_module )
468621
469622 def test_no_optimize_cat_with_placeholder (self ):
470623 class CatViewInfeasible (torch .nn .Module ):
@@ -492,6 +645,7 @@ def forward(self, x, y):
492645 # Assert that cat op is not optimized away
493646 self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 1 )
494647 self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 0 )
648+ self .verify_nop_memory_alloc (graph_module )
495649
496650 def test_no_optimize_cat (self ) -> None :
497651 class Model (torch .nn .Module ):
@@ -522,6 +676,7 @@ def forward(self, x) -> torch.Tensor:
522676 count_node (graph_module , torch .ops .aten ._slice_copy_nop .Tensor_out ), 2
523677 )
524678 self .assertEqual (count_node (graph_module , memory .view ), 2 )
679+ self .verify_nop_memory_alloc (graph_module )
525680
526681 def test_optimize_slice_copy (self ) -> None :
527682 class Model (torch .nn .Module ):
@@ -553,6 +708,7 @@ def forward(self, x) -> torch.Tensor:
553708 count_node (graph_module , torch .ops .aten ._slice_copy_nop .Tensor_out ), 0
554709 )
555710 self .assertEqual (count_node (graph_module , memory .view ), 2 )
711+ self .verify_nop_memory_alloc (graph_module )
556712
557713 def test_cat_then_cat (self ) -> None :
558714 class Model (torch .nn .Module ):
@@ -579,6 +735,7 @@ def forward(self, x) -> torch.Tensor:
579735 graph_module .print_readable ()
580736 self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 2 )
581737 self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 0 )
738+ self .verify_nop_memory_alloc (graph_module )
582739
583740 def test_view_for_unallocated_output (self ):
584741 class Model (torch .nn .Module ):
@@ -602,3 +759,4 @@ def forward(self, x, y):
602759 .graph_module
603760 )
604761 self .assertEqual (count_node (graph_module , memory .view ), 1 )
762+ self .verify_nop_memory_alloc (graph_module )
0 commit comments