@@ -947,6 +947,110 @@ def test_cat_then_cat(self) -> None:
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
         self.verify_nop_memory_alloc(graph_module)
 
+    def test_cat_with_duplicate_input_tensor(self) -> None:
+        """
+        Test that cat is NOT optimized when the same tensor appears multiple
+        times in the cat input list. This is because we cannot place the same
+        tensor at multiple different offsets relative to the output.
+        """
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.ones(3, 6, dtype=torch.float32))
+        to_add_to_x = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([3, 6], 123.0),
+            kwargs={"dtype": torch.float32},
+        )
+        add_x = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(x, to_add_to_x),
+        )
+        pre_created_output = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([6, 6], 0.0),
+            kwargs={"dtype": torch.float32},
+        )
+        # Same tensor (add_x) appears twice in the cat inputs
+        cat = builder.call_operator(
+            op=torch.ops.aten.cat.out,
+            args=([add_x, add_x],),
+            kwargs={"dim": 0, "out": pre_created_output},
+        )
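+        # Nop-ifying this cat would have to alias add_x at two spots in the
+        # output: offset 0 and offset 72 bytes (3 * 6 * 4-byte floats). A
+        # tensor has a single placement, so the copy must be kept.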
+        builder.output([cat])
+        original = builder.get_graph_module()
+        graph_module = self.run_memory_planning(original)
+        graph_module.graph.eliminate_dead_code()
+
+        # Assert that the cat op is NOT optimized away, since the same tensor
+        # appears multiple times in the input list
+        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
+        self.verify_nop_memory_alloc(graph_module)
+
+    def test_cat_with_tensor_having_existing_constraint(self) -> None:
+        """
+        Test that the second cat is NOT optimized when a tensor already has a
+        relative placement constraint from a previous cat operation.
+        """
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.ones(8, 8, dtype=torch.float32))
+        to_add = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([8, 8], 1.0),
+            kwargs={"dtype": torch.float32},
+        )
+        x1 = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(x, to_add),
+        )
+        x2 = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(x1, to_add),
+        )
+        x3 = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(x2, to_add),
+        )
+        # First cat: cat(x1, x2). This gives x1 and x2 relative placement
+        # constraints.
+        pre_created_output1 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([16, 8], 0.0),
+            kwargs={"dtype": torch.float32},
+        )
+        cat1 = builder.call_operator(
+            op=torch.ops.aten.cat.out,
+            args=([x1, x2],),
+            kwargs={"dim": 0, "out": pre_created_output1},
+        )
+        # Second cat: cat(x2, x3). x2 already has a constraint from cat1, so
+        # this cat cannot be optimized.
+        pre_created_output2 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([16, 8], 0.0),
+            kwargs={"dtype": torch.float32},
+        )
+        cat2 = builder.call_operator(
+            op=torch.ops.aten.cat.out,
+            args=([x2, x3],),
+            kwargs={"dim": 0, "out": pre_created_output2},
+        )
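+        # Nop-ifying cat2 as well would pin x2 at offset 256 bytes
+        # (8 * 8 * 4-byte floats) inside cat1's output and at offset 0 inside
+        # cat2's output; one tensor cannot satisfy both placements, so cat2
+        # keeps its copy.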
+        # Use both cat results to keep them alive
+        graph_output = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(cat1, cat2),
+        )
+        builder.output([graph_output])
+        original = builder.get_graph_module()
+        graph_module = self.run_memory_planning(
+            original, opt_level=3, alloc_graph_input=False
+        )
+        graph_module.graph.eliminate_dead_code()
+
+        # The first cat should be optimized to _cat_nop, but the second cat
+        # cannot be optimized because x2 already has a relative placement constraint
+        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.verify_nop_memory_alloc(graph_module)
+
     def test_view_for_unallocated_output(self) -> None:
         builder = GraphBuilder()
         x = builder.placeholder("x", torch.ones(3, 5, dtype=torch.float32))