@@ -130,87 +130,22 @@ def test_fp32_channels_last_tagged_reshape_pass(self):
130130 .run_method_and_compare_outputs ()
131131 )
132132
133- class LinearConv (torch .nn .Module ):
133+ class LinearConvDimSwap (torch .nn .Module ):
134134 def __init__ (self ):
135135 super ().__init__ ()
136136 self .conv1 = torch .nn .Conv2d (3 , 3 , 3 )
137137 self .linear1 = torch .nn .Linear (4 , 3 )
138138
139139 def forward (self , x ):
140140 y = self .linear1 (x )
141+ y = y .to (memory_format = torch .channels_last )
142+ y = y .to (memory_format = torch .contiguous_format )
141143 return self .conv1 (y )
142144
143- def test_conv_linear_dim_order_swaps_on_nhwc_input (self ):
144- tester = Tester (
145- self .LinearConv ().eval (),
146- (torch .randn (1 , 3 , 6 , 4 ).to (memory_format = torch .channels_last ),),
147- )
148-
149- tester .export ().to_edge_transform_and_lower ().to_executorch ().serialize ().run_method_and_compare_outputs ()
145+ LinearConvDimSwapModule = LinearConvDimSwap ()
150146
151- def test_conv_linear_dim_order_swaps_on_nchw_input (self ):
152- tester = Tester (
153- self .LinearConv ().eval (),
154- (torch .randn (1 , 3 , 6 , 4 ),),
155- )
156-
157- tester .export ().to_edge_transform_and_lower ().to_executorch ().serialize ().run_method_and_compare_outputs ()
158-
159- class ConvLinearConv (torch .nn .Module ):
160- def __init__ (self ):
161- super ().__init__ ()
162- self .conv1 = torch .nn .Conv2d (3 , 3 , 3 )
163- self .linear1 = torch .nn .Linear (4 , 4 )
164-
165- def forward (self , x ):
166- y = self .conv1 (x )
167- return self .linear1 (y )
168-
169- def test_linear_conv_dim_order_swaps_on_nhwc_input (self ):
170- tester = Tester (
171- self .ConvLinearConv ().eval (),
172- (torch .randn (1 , 3 , 6 , 6 ).to (memory_format = torch .channels_last ),),
173- )
174-
175- tester .export ().to_edge_transform_and_lower ().to_executorch ().serialize ().run_method_and_compare_outputs ()
176-
177- def test_linear_conv_dim_order_swaps_on_nchw_input (self ):
178- tester = Tester (
179- self .ConvLinearConv ().eval (),
180- (torch .randn (1 , 3 , 6 , 6 ),),
181- )
182-
183- tester .export ().to_edge_transform_and_lower ().to_executorch ().serialize ().run_method_and_compare_outputs ()
184-
185- class Bilinear (torch .nn .Module ):
186- def __init__ (self ):
187- super ().__init__ ()
188-
189- def forward (self , x ):
190- return torch .nn .functional .interpolate (
191- x , scale_factor = 2 , mode = "bilinear" , align_corners = True
192- )
193-
194- def test_nhwc_input_on_nhwc_op (self ):
195- tester = Tester (
196- self .Bilinear ().eval (),
197- (
198- torch .arange (8 )
199- .reshape (1 , 2 , 2 , 2 )
200- .to (torch .float32 )
201- .to (memory_format = torch .channels_last ),
202- ),
203- )
204-
205- tester .export ().to_edge_transform_and_lower ().to_executorch ().serialize ().run_method_and_compare_outputs ()
206-
207- def test_nchw_input_on_nhwc_op (self ):
208- tester = Tester (
209- self .Bilinear ().eval (),
210- (torch .arange (8 ).reshape (1 , 2 , 2 , 2 ).to (torch .float32 ),),
211- )
212-
213- tester .export ().to_edge_transform_and_lower ().to_executorch ().serialize ().run_method_and_compare_outputs ()
147+ def test_conv_linear_dim_order_swap_partitioner (self ):
148+ self .run_tester (self .LinearConvDimSwapModule , (torch .randn (1 , 3 , 6 , 4 ),))
214149
215150 def test_qs8_channels_last_tagged_reshape_pass (self ):
216151 for module , num_reshape in self .modules .items ():
0 commit comments