@@ -43,6 +43,78 @@ def setUp(self):
4343 )
4444 dynamic_quant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_tensor"
4545
46+ def run_tester (self , module , inputs ):
47+ tester = Tester (
48+ module .eval (),
49+ inputs ,
50+ )
51+ tester .export ().to_edge_transform_and_lower ().to_executorch ().serialize ().run_method_and_compare_outputs ()
52+
53+ class LinearConv (torch .nn .Module ):
54+ def __init__ (self ):
55+ super ().__init__ ()
56+ self .conv1 = torch .nn .Conv2d (3 , 3 , 3 )
57+ self .linear1 = torch .nn .Linear (4 , 3 )
58+
59+ def forward (self , x ):
60+ y = self .linear1 (x )
61+ return self .conv1 (y )
62+
63+ LinearConvModule = LinearConv ()
64+
65+ class ConvLinearConv (torch .nn .Module ):
66+ def __init__ (self ):
67+ super ().__init__ ()
68+ self .conv1 = torch .nn .Conv2d (3 , 3 , 3 )
69+ self .linear1 = torch .nn .Linear (4 , 4 )
70+
71+ def forward (self , x ):
72+ y = self .conv1 (x )
73+ return self .linear1 (y )
74+
75+ ConvLinearConvModule = ConvLinearConv ()
76+
77+ class Bilinear (torch .nn .Module ):
78+ def __init__ (self ):
79+ super ().__init__ ()
80+
81+ def forward (self , x ):
82+ return torch .nn .functional .interpolate (
83+ x , scale_factor = 2 , mode = "bilinear" , align_corners = True
84+ )
85+
86+ BilinearModule = Bilinear ()
87+
88+ def test_conv_linear_dim_order_swaps (self ):
89+ self .run_tester (self .LinearConvModule , (torch .randn (1 , 3 , 6 , 4 ),))
90+ self .run_tester (
91+ self .LinearConvModule ,
92+ (torch .randn (1 , 3 , 6 , 4 ).to (memory_format = torch .channels_last ),),
93+ )
94+
95+ def test_linear_conv_dim_order_swaps (self ):
96+ self .run_tester (self .ConvLinearConvModule , (torch .randn (1 , 3 , 6 , 6 ),))
97+ self .run_tester (
98+ self .ConvLinearConvModule ,
99+ (torch .randn (1 , 3 , 6 , 6 ).to (memory_format = torch .channels_last ),),
100+ )
101+
102+ def test_nhwc_nchw_input_on_nhwc_op (self ):
103+ self .run_tester (
104+ self .BilinearModule ,
105+ (
106+ torch .arange (8 )
107+ .reshape (1 , 2 , 2 , 2 )
108+ .to (torch .float32 )
109+ .to (memory_format = torch .channels_last ),
110+ ),
111+ )
112+
113+ self .run_tester (
114+ self .BilinearModule ,
115+ (torch .arange (8 ).reshape (1 , 2 , 2 , 2 ).to (torch .float32 ),),
116+ )
117+
46118 def test_fp32_channels_last_tagged_reshape_pass (self ):
47119 for module , num_reshape in self .modules .items ():
48120 (
@@ -58,6 +130,88 @@ def test_fp32_channels_last_tagged_reshape_pass(self):
58130 .run_method_and_compare_outputs ()
59131 )
60132
133+ class LinearConv (torch .nn .Module ):
134+ def __init__ (self ):
135+ super ().__init__ ()
136+ self .conv1 = torch .nn .Conv2d (3 , 3 , 3 )
137+ self .linear1 = torch .nn .Linear (4 , 3 )
138+
139+ def forward (self , x ):
140+ y = self .linear1 (x )
141+ return self .conv1 (y )
142+
143+ def test_conv_linear_dim_order_swaps_on_nhwc_input (self ):
144+ tester = Tester (
145+ self .LinearConv ().eval (),
146+ (torch .randn (1 , 3 , 6 , 4 ).to (memory_format = torch .channels_last ),),
147+ )
148+
149+ tester .export ().to_edge_transform_and_lower ().to_executorch ().serialize ().run_method_and_compare_outputs ()
150+
151+ def test_conv_linear_dim_order_swaps_on_nchw_input (self ):
152+ tester = Tester (
153+ self .LinearConv ().eval (),
154+ (torch .randn (1 , 3 , 6 , 4 ),),
155+ )
156+
157+ tester .export ().to_edge_transform_and_lower ().to_executorch ().serialize ().run_method_and_compare_outputs ()
158+
159+ class ConvLinearConv (torch .nn .Module ):
160+ def __init__ (self ):
161+ super ().__init__ ()
162+ self .conv1 = torch .nn .Conv2d (3 , 3 , 3 )
163+ self .linear1 = torch .nn .Linear (4 , 4 )
164+
165+ def forward (self , x ):
166+ y = self .conv1 (x )
167+ return self .linear1 (y )
168+
169+ def test_linear_conv_dim_order_swaps_on_nhwc_input (self ):
170+ tester = Tester (
171+ self .ConvLinearConv ().eval (),
172+ (torch .randn (1 , 3 , 6 , 6 ).to (memory_format = torch .channels_last ),),
173+ )
174+
175+ tester .export ().to_edge_transform_and_lower ().to_executorch ().serialize ().run_method_and_compare_outputs ()
176+
177+ def test_linear_conv_dim_order_swaps_on_nchw_input (self ):
178+ tester = Tester (
179+ self .ConvLinearConv ().eval (),
180+ (torch .randn (1 , 3 , 6 , 6 ),),
181+ )
182+
183+ tester .export ().to_edge_transform_and_lower ().to_executorch ().serialize ().run_method_and_compare_outputs ()
184+
185+ class Bilinear (torch .nn .Module ):
186+ def __init__ (self ):
187+ super ().__init__ ()
188+
189+ def forward (self , x ):
190+ return torch .nn .functional .interpolate (
191+ x , scale_factor = 2 , mode = "bilinear" , align_corners = True
192+ )
193+
194+ def test_nhwc_input_on_nhwc_op (self ):
195+ tester = Tester (
196+ self .Bilinear ().eval (),
197+ (
198+ torch .arange (8 )
199+ .reshape (1 , 2 , 2 , 2 )
200+ .to (torch .float32 )
201+ .to (memory_format = torch .channels_last ),
202+ ),
203+ )
204+
205+ tester .export ().to_edge_transform_and_lower ().to_executorch ().serialize ().run_method_and_compare_outputs ()
206+
207+ def test_nchw_input_on_nhwc_op (self ):
208+ tester = Tester (
209+ self .Bilinear ().eval (),
210+ (torch .arange (8 ).reshape (1 , 2 , 2 , 2 ).to (torch .float32 ),),
211+ )
212+
213+ tester .export ().to_edge_transform_and_lower ().to_executorch ().serialize ().run_method_and_compare_outputs ()
214+
61215 def test_qs8_channels_last_tagged_reshape_pass (self ):
62216 for module , num_reshape in self .modules .items ():
63217 (
0 commit comments