# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -36,8 +35,8 @@ def forward(self, x: torch.Tensor):
3635 return torch .full ((2 , 2 , 3 , 3 ), 4.5 , dtype = torch .float32 ) + x
3736
3837 class AddVariableFull (torch .nn .Module ):
39- sizes = [
40- (5 ),
38+ sizes : list [ tuple [ int , ...]] = [
39+ (5 , ),
4140 (5 , 5 ),
4241 (5 , 5 , 5 ),
4342 (1 , 5 , 5 , 5 ),
@@ -48,6 +47,21 @@ def forward(self, x: torch.Tensor, y):
4847 # Input + a full with the shape from the input and a given value 'y'.
4948 return x + torch .full (x .shape , y )
5049
class FullLike(torch.nn.Module):
    """Module exercising ``torch.full_like``.

    Since full_like is replaced with full, we only need to test on the
    reference model, not FVP.
    """

    # (input tensor, fill value) pairs: float32 and int32 input tensors,
    # each combined with a float and an int fill value.
    test_parameters = [
        ((torch.randn(2, 2, 2, 2) * 50, 3.2),),
        ((torch.randn(2, 2, 2, 2) * 50, 3),),
        (((torch.randn(2, 2, 2, 2) * 50).to(torch.int32), 3.2),),
        (((torch.randn(2, 2, 2, 2) * 50).to(torch.int32), 3),),
    ]

    def forward(self, input_tensor: torch.Tensor, value):
        # Our backend can't handle tensors without users, which input_tensor
        # doesn't have when the full_like is converted to a full. Therefore
        # involve it in the output.
        return input_tensor + torch.full_like(input_tensor, value)
64+
5165 def _test_full_tosa_MI_pipeline (
5266 self ,
5367 module : torch .nn .Module ,
@@ -63,9 +77,7 @@ def _test_full_tosa_MI_pipeline(
6377 compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" ),
6478 )
6579 .export ()
66- .check_count ({"torch.ops.aten.full.default" : 1 })
67- .to_edge ()
68- .partition ()
80+ .to_edge_transform_and_lower ()
6981 .check_not (["executorch_exir_dialects_edge__ops_aten_full_default" ])
7082 .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
7183 .to_executorch ()
@@ -85,9 +97,7 @@ def _test_full_tosa_BI_pipeline(
8597 )
8698 .quantize ()
8799 .export ()
88- .check_count ({"torch.ops.aten.full.default" : 1 })
89- .to_edge ()
90- .partition ()
100+ .to_edge_transform_and_lower ()
91101 .check_not (["executorch_exir_dialects_edge__ops_aten_full_default" ])
92102 .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
93103 .to_executorch ()
@@ -101,9 +111,7 @@ def _test_full_tosa_ethos_pipeline(
101111 ArmTester (module , example_inputs = test_data , compile_spec = compile_spec )
102112 .quantize ()
103113 .export ()
104- .check_count ({"torch.ops.aten.full.default" : 1 })
105- .to_edge ()
106- .partition ()
114+ .to_edge_transform_and_lower ()
107115 .check_not (["executorch_exir_dialects_edge__ops_aten_full_default" ])
108116 .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
109117 .to_executorch ()
@@ -129,6 +137,10 @@ def test_const_full_tosa_MI(self):
129137 _input = torch .rand ((2 , 2 , 3 , 3 )) * 10
130138 self ._test_full_tosa_MI_pipeline (self .AddConstFull (), (_input ,))
131139
140+ @parameterized .expand (FullLike .test_parameters )
141+ def test_full_like_tosa_MI (self , test_tensor : Tuple ):
142+ self ._test_full_tosa_MI_pipeline (self .FullLike (), test_tensor )
143+
132144 def test_const_full_nhwc_tosa_BI (self ):
133145 _input = torch .rand ((2 , 2 , 3 , 3 )) * 10
134146 self ._test_full_tosa_BI_pipeline (self .AddConstFull (), (_input ,))
@@ -143,6 +155,10 @@ def test_full_tosa_MI(self, test_tensor: Tuple):
143155 def test_full_tosa_BI (self , test_tensor : Tuple ):
144156 self ._test_full_tosa_BI_pipeline (self .AddVariableFull (), test_tensor )
145157
158+ @parameterized .expand (FullLike .test_parameters )
159+ def test_full_like_tosa_BI (self , test_tensor : Tuple ):
160+ self ._test_full_tosa_BI_pipeline (self .FullLike (), test_tensor )
161+
146162 @parameterized .expand (AddVariableFull .test_parameters )
147163 @pytest .mark .corstone_fvp
148164 def test_full_u55_BI (self , test_tensor : Tuple ):
0 commit comments